we resolvin'

This commit is contained in:
Ali Mohammad Pur 2025-03-09 17:30:25 +01:00
parent 1933b68b41
commit 8cb56326b4
7 changed files with 217 additions and 82 deletions

View file

@ -8,8 +8,9 @@
namespace AK {
CountingStream::CountingStream(MaybeOwned<Stream> stream)
CountingStream::CountingStream(MaybeOwned<Stream> stream, size_t offset)
: m_stream(move(stream))
, m_read_bytes(offset)
{
}

View file

@ -13,7 +13,7 @@ namespace AK {
class CountingStream : public Stream {
public:
CountingStream(MaybeOwned<Stream>);
CountingStream(MaybeOwned<Stream>, size_t offset = 0);
u64 read_bytes() const;

View file

@ -663,11 +663,17 @@ ErrorOr<DomainName> DomainName::from_raw(ParseContext& ctx)
constexpr static u8 OffsetMarkerMask = 0b11000000;
if ((length & OffsetMarkerMask) == OffsetMarkerMask) {
// This is a pointer to a prior domain name.
u16 const offset = static_cast<u16>(length & ~OffsetMarkerMask) << 8 | TRY(ctx.stream.read_value<u8>());
u16 offset = static_cast<u16>(length & ~OffsetMarkerMask) << 8 | TRY(ctx.stream.read_value<u8>());
if (auto it = ctx.pointers->find_largest_not_above_iterator(offset); !it.is_end()) {
auto labels = it->labels;
for (auto& entry : labels)
name.labels.append(entry);
size_t start_index = 0;
size_t start_entry_offset = offset - it.key();
while (start_entry_offset > 0 && start_index < labels.size()) {
start_entry_offset -= labels[start_index].length() + 1; // +1 for the length byte
start_index++;
}
for (size_t i = start_index; i < labels.size(); ++i)
name.labels.append(labels[i].substring_view(i == start_index ? start_entry_offset : 0));
break;
}
dbgln("Invalid domain name pointer in label, no prior domain name found around offset {}", offset);
@ -786,9 +792,10 @@ ErrorOr<ResourceRecord> ResourceRecord::from_raw(ParseContext& ctx)
ResourceType type;
Class class_;
u32 ttl;
size_t original_offset = ctx.stream.read_bytes();
{
RecordingStream rr_stream { ctx.stream };
CountingStream rr_counting_stream { MaybeOwned<Stream>(rr_stream) };
CountingStream rr_counting_stream { MaybeOwned<Stream>(rr_stream), original_offset };
ParseContext rr_ctx { rr_counting_stream, move(ctx.pointers) };
ScopeGuard guard([&] { ctx.pointers = move(rr_ctx.pointers); });
@ -809,13 +816,14 @@ ErrorOr<ResourceRecord> ResourceRecord::from_raw(ParseContext& ctx)
class_ = static_cast<Class>(static_cast<u16>(TRY(rr_ctx.stream.read_value<NetworkOrdered<u16>>())));
ttl = static_cast<u32>(TRY(rr_ctx.stream.read_value<NetworkOrdered<u32>>()));
auto rd_length = static_cast<u16>(TRY(rr_ctx.stream.read_value<NetworkOrdered<u16>>()));
original_offset = rr_ctx.stream.read_bytes();
TRY(rr_ctx.stream.read_until_filled(TRY(rdata.get_bytes_for_writing(rd_length))));
rr_raw_data = move(rr_stream).take_recorded_data();
}
FixedMemoryStream stream { rdata.bytes() };
CountingStream rdata_stream { MaybeOwned<Stream>(stream) };
CountingStream rdata_stream { MaybeOwned<Stream>(stream), original_offset };
ParseContext rdata_ctx { rdata_stream, move(ctx.pointers) };
ScopeGuard guard([&] { ctx.pointers = move(rdata_ctx.pointers); });
@ -1027,6 +1035,24 @@ ErrorOr<Records::SOA> Records::SOA::from_raw(ParseContext& ctx)
return Records::SOA { move(mname), move(rname), serial, refresh, retry, expire, minimum };
}
ErrorOr<void> Records::SOA::to_raw(ByteBuffer& buffer) const
{
TRY(mname.to_raw(buffer));
TRY(rname.to_raw(buffer));
auto const output_size = 5 * sizeof(u32);
FixedMemoryStream stream { TRY(buffer.get_bytes_for_writing(output_size)) };
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(serial)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(refresh)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(retry)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(expire)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(minimum)));
return {};
}
ErrorOr<Records::MX> Records::MX::from_raw(ParseContext& ctx)
{
// RFC 1035, 3.3.9. MX RDATA format.

View file

@ -98,6 +98,9 @@ struct DomainName {
copy.labels.take_first();
return copy;
}
bool operator==(DomainName const&) const& = default;
bool operator!=(DomainName const&) const& = default;
};
// Listing from IANA https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4.
@ -419,7 +422,7 @@ struct SOA {
static constexpr ResourceType type = ResourceType::SOA;
static ErrorOr<SOA> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: SOA::to_raw"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const
{
return String::formatted("SOA MName: '{}', RName: '{}', Serial: {}, Refresh: {}, Retry: {}, Expire: {}, Minimum: {}", mname.to_string(), rname.to_string(), serial, refresh, retry, expire, minimum);

View file

@ -28,6 +28,16 @@
#undef DNS_DEBUG
#define DNS_DEBUG 1
#define TRY_OR_REJECT_PROMISE(promise, expr) \
({ \
auto _result = (expr); \
if (_result.is_error()) { \
promise->reject(_result.release_error()); \
return promise; \
} \
_result.release_value(); \
})
namespace DNS {
class Resolver;
@ -92,6 +102,35 @@ public:
return result;
}
Vector<Messages::ResourceRecord> records(Messages::ResourceType type) const
{
Vector<Messages::ResourceRecord> result;
for (auto& re : m_cached_records) {
if (re.record.type == type)
result.append(re.record);
}
return result;
}
Messages::ResourceRecord const& record(Messages::ResourceType type) const
{
for (auto const& re : m_cached_records) {
if (re.record.type == type)
return re.record;
}
VERIFY_NOT_REACHED();
}
template<typename RR>
RR const& record() const
{
for (auto const& re : m_cached_records) {
if (re.record.type == RR::type)
return re.record.record.get<RR>();
}
VERIFY_NOT_REACHED();
}
bool has_record_of_type(Messages::ResourceType type, bool later = false) const
{
if (later && m_desired_types.contains(type))
@ -364,7 +403,8 @@ public:
dbgln_if(DNS_DEBUG, "DNS: Adding {} to cache", name);
auto ptr = make_ref_counted<LookupResult>(domain_name);
ptr->set_dnssec_validated(options.validate_dnssec_locally);
if (!ptr->is_dnssec_validated())
ptr->set_dnssec_validated(options.validate_dnssec_locally);
for (auto const& type : desired_types)
ptr->will_add_record_of_type(type);
cache.set(name, ptr);
@ -580,6 +620,66 @@ private:
Vector<Messages::Records::DNSKEY> dnskeys;
};
NonnullRefPtr<Core::Promise<bool>> validate_dnssec_chain_step(Messages::DomainName const& name, bool top_level = false)
{
dbgln_if(DNS_DEBUG, "DNS: Validating DNSSEC chain for {}", name.to_string());
auto promise = Core::Promise<bool>::construct();
// - If this is the root, we're done, just return true.
if (name.labels.size() == 0) {
promise->resolve(true);
return promise;
}
// - Lookup the SOA record for the domain.
auto soa_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::SOA }, { .validate_dnssec_locally = !top_level })->await()));
// - If we have no SOA record-
if (!soa_result->has_record_of_type(Messages::ResourceType::SOA)) {
dbgln_if(DNS_DEBUG, "DNS: No SOA record found for {}", name.to_string());
// - First, check for a DS record-
auto ds_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DS }, { .validate_dnssec_locally = !top_level })->await()));
// - If there's no DS record, check for an NS record-
if (!ds_result->has_record_of_type(Messages::ResourceType::DS)) {
dbgln_if(DNS_DEBUG, "DNS: No DS record found for {}", name.to_string());
// - If there's no DS record, check for an NS record-
auto ns_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::NS }, { .validate_dnssec_locally = !top_level })->await()));
if (ns_result->has_record_of_type(Messages::ResourceType::NS)) {
// - but if there _is_ an NS record, this is a broken delegation, so reject.
dbgln_if(DNS_DEBUG, "DNS: Found NS record for {}", name.to_string());
promise->resolve(false);
return promise;
}
dbgln_if(DNS_DEBUG, "DNS: No NS record found for {}", name.to_string());
// this is just part of the parent delegation, so go up one level.
return validate_dnssec_chain_step(name.parent());
}
// - If there is a DS record, this is a separate zone...but since we don't have an SOA record, this is a misconfigured zone.
// Let's just reject.
dbgln_if(DNS_DEBUG, "DNS: Found DS record for {}", name.to_string());
promise->resolve(false);
return promise;
}
// So we have an SOA record, there's much rejoicing and we can continue.
auto& soa = soa_result->record<Messages::Records::SOA>();
dbgln_if(DNS_DEBUG, "DNS: Found SOA record for {}: {}", name.to_string(), soa.mname.to_string());
if (soa.mname == name.parent()) {
// Just go up one level, all is well.
return validate_dnssec_chain_step(name.parent());
}
// So this is a separate zone, let's look up the DS record.
auto ds_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DS }, { .validate_dnssec_locally = false })->await()));
if (!ds_result->has_record_of_type(Messages::ResourceType::DS)) {
// If there's no DS record, this is a misconfigured zone.
dbgln_if(DNS_DEBUG, "DNS: No DS record found for {}", name.to_string());
promise->resolve(false);
return promise;
}
promise->resolve(true);
return promise;
}
ErrorOr<void> validate_dnssec(Messages::Message message, PendingLookup& lookup, NonnullRefPtr<LookupResult> result)
{
struct RecordAndRRSIG {
@ -588,7 +688,6 @@ private:
};
HashMap<Messages::ResourceType, RecordAndRRSIG> records_with_rrsigs;
for (auto& record : message.answers) {
dbgln("- {}", record.to_string());
if (record.type == Messages::ResourceType::RRSIG) {
auto& rrsig = record.record.get<Messages::Records::RRSIG>();
auto type = rrsig.type_covered;
@ -621,6 +720,15 @@ private:
auto is_root_zone = lookup.parsed_name.labels.size() == 0;
if (!is_root_zone) {
auto chain_valid_result = validate_dnssec_chain_step(name, true)->await();
if (chain_valid_result.is_error()) {
lookup.promise->reject(chain_valid_result.release_error());
return;
}
if (!chain_valid_result.value()) {
lookup.promise->reject(Error::from_string_literal("DNSSEC chain is invalid"));
return;
}
auto parent_result = this->lookup(lookup.parsed_name.parent().to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DNSKEY }, { .validate_dnssec_locally = true })
->await();
if (parent_result.is_error()) {
@ -633,7 +741,10 @@ private:
return;
}
parent_zone_keys = parent_result.release_value()->used_dnskeys();
parent_zone_keys = parent_result.value()->used_dnskeys();
for (auto& rr : parent_result.value()->records(Messages::ResourceType::DNSKEY))
parent_zone_keys.append(rr.record.get<Messages::Records::DNSKEY>());
dbgln("Found {} DNSKEYs for parent zone ({})", parent_zone_keys.size(), lookup.parsed_name.parent().to_string());
}
auto resolve_using_keys = [=, this, records_with_rrsigs = move(records_with_rrsigs)](Vector<Messages::Records::DNSKEY> keys) mutable {
@ -708,29 +819,38 @@ private:
lookup.promise = move(promise);
};
if (is_root_zone) {
return resolve_using_keys(Vector { Messages::Records::DNSKEY {
.flags = 256,
.protocol = 3,
.algorithm = Messages::DNSSEC::Algorithm::RSASHA256,
.public_key = MUST(decode_base64("AwEAAa96jeuknZlaeSrvyAJj6ZHv28hhOKkx3rLGXVaC6rXTsDc449/cidltpkyGwCJNnOAlFNKF2jBosZBU5eeHspaQWOmOElZsjICMQMC3aeHbGiShvZsx4wMYSjH8e7Vrhbu6irwCzVBApESjbUdpWWmEnhathWu1jo+siFUiRAAxm9qyJNg/wOZqqzL/dL/q8PkcRU5oUKEpUge71M3ej2/7CPqpdVwuMoTvoB+ZOT4YeGyxMvHmbrxlFzGOHOijtzN+u1TQNatX2XBuzZNQ1K+s2CXkPIZo7s6JgZyvaBevYtxPvYLw4z9mR7K2vaF18UYH9Z9GNUUeayffKC73PYc="sv)),
.calculated_key_tag = 38696,
} });
resolve_using_keys(Vector<Messages::Records::DNSKEY> {
{
.flags = 257,
.protocol = 3,
.algorithm = Messages::DNSSEC::Algorithm::RSASHA256,
.public_key = MUST(decode_base64("AwEAAaz/tAm8yTn4Mfeh5eyI96WSVexTBAvkMgJzkKTOiW1vkIbzxeF3+/4RgWOq7HrxRixHlFlExOLAJr5emLvN7SWXgnLh4+B5xQlNVz8Og8kvArMtNROxVQuCaSnIDdD5LKyWbRd2n9WGe2R8PzgCmr3EgVLrjyBxWezF0jLHwVN8efS3rCj/EWgvIWgb9tarpVUDK/b58Da+sqqls3eNbuv7pr+eoZG+SrDK6nWeL3c6H5Apxz7LjVc1uTIdsIXxuOLYA4/ilBmSVIzuDWfdRUfhHdY6+cn8HFRm+2hM8AnXGXws9555KrUB5qihylGa8subX2Nn6UwNR1AkUTV74bU="sv)),
.calculated_key_tag = 20326,
},
{
.flags = 256,
.protocol = 3,
.algorithm = Messages::DNSSEC::Algorithm::RSASHA256,
.public_key = MUST(decode_base64("AwEAAa96jeuknZlaeSrvyAJj6ZHv28hhOKkx3rLGXVaC6rXTsDc449/cidltpkyGwCJNnOAlFNKF2jBosZBU5eeHspaQWOmOElZsjICMQMC3aeHbGiShvZsx4wMYSjH8e7Vrhbu6irwCzVBApESjbUdpWWmEnhathWu1jo+siFUiRAAxm9qyJNg/wOZqqzL/dL/q8PkcRU5oUKEpUge71M3ej2/7CPqpdVwuMoTvoB+ZOT4YeGyxMvHmbrxlFzGOHOijtzN+u1TQNatX2XBuzZNQ1K+s2CXkPIZo7s6JgZyvaBevYtxPvYLw4z9mR7K2vaF18UYH9Z9GNUUeayffKC73PYc="sv)),
.calculated_key_tag = 38696,
},
});
return;
}
dbgln_if(DNS_DEBUG, "DNS: Starting DNSKEY lookup for {}", lookup.name);
this->lookup(lookup.name, Messages::Class::IN, { Messages::ResourceType::DNSKEY }, { .validate_dnssec_locally = false })
->when_resolved([=](NonnullRefPtr<LookupResult const>& dnskey_lookup_result) mutable {
dbgln_if(DNS_DEBUG, "DNSKEY for {}:", name.to_string());
for (auto& record : dnskey_lookup_result->records())
auto key_records = dnskey_lookup_result->records(Messages::ResourceType::DNSKEY);
for (auto& record : key_records)
dbgln_if(DNS_DEBUG, "- DNSKEY: {}", record.to_string());
Vector<Messages::Records::DNSKEY> keys;
keys.ensure_capacity(parent_zone_keys.size() + dnskey_lookup_result->records().size());
for (auto& record : parent_zone_keys)
keys.append(record);
for (auto& record : dnskey_lookup_result->records()) {
if (auto k = record.record.get_pointer<Messages::Records::DNSKEY>())
keys.append(move(*k));
}
for (auto& record : key_records)
keys.append(move(record.record).get<Messages::Records::DNSKEY>());
resolve_using_keys(move(keys));
})
.when_rejected([=](auto& error) mutable {
@ -755,16 +875,6 @@ private:
return nullptr;
}
#define TRY_OR_REJECT_PROMISE(promise, expr) \
({ \
auto _result = (expr); \
if (_result.is_error()) { \
promise->reject(_result.release_error()); \
return promise; \
} \
_result.release_value(); \
})
NonnullRefPtr<Core::Promise<Empty>> validate_rrset_with_rrsig(CanonicalizedRRSetWithRRSIG rrset_with_rrsig, NonnullRefPtr<LookupResult> result)
{
auto promise = Core::Promise<Empty>::construct();
@ -788,29 +898,25 @@ private:
for (auto& rr : canon_encoded_rrs)
canon_encoded.append(rr);
if (result->name().labels.size() == 1) {
dbgln_if(DNS_DEBUG, "Root zone, implicitly trusting DNSKEY for {}", result->name().to_string());
promise->resolve({});
return promise;
}
auto& dnskey = *find_dnskey(rrset_with_rrsig);
dbgln_if(DNS_DEBUG, "Validating RRSet with RRSIG for {}", result->name().to_string());
for (auto& rr : rrset_with_rrsig.rrset)
dbgln_if(DNS_DEBUG, "- RR {}", rr.to_string());
for (auto& canon : canon_encoded_rrs) {
FixedMemoryStream stream { canon.bytes() };
CountingStream rr_counting_stream { MaybeOwned<Stream>(stream) };
DNS::Messages::ParseContext rr_ctx { rr_counting_stream, make<RedBlackTree<u16, Messages::DomainName>>() };
auto maybe_decoded = Messages::ResourceRecord::from_raw(rr_ctx);
if (maybe_decoded.is_error())
dbgln("-- Failed to decode RR: {}", maybe_decoded.error());
else
dbgln("-- Canon encoded (decoded): {}", maybe_decoded.value().to_string());
if constexpr (DNS_DEBUG) {
dbgln("Validating RRSet with RRSIG for {}", result->name().to_string());
for (auto& rr : rrset_with_rrsig.rrset)
dbgln("- RR {}", rr.to_string());
for (auto& canon : canon_encoded_rrs) {
FixedMemoryStream stream { canon.bytes() };
CountingStream rr_counting_stream { MaybeOwned<Stream>(stream) };
DNS::Messages::ParseContext rr_ctx { rr_counting_stream, make<RedBlackTree<u16, Messages::DomainName>>() };
auto maybe_decoded = Messages::ResourceRecord::from_raw(rr_ctx);
if (maybe_decoded.is_error())
dbgln("-- Failed to decode RR: {}", maybe_decoded.error());
else
dbgln("-- Canon encoded (decoded): {}", maybe_decoded.value().to_string());
}
dbgln( "- DNSKEY {}", dnskey.to_string());
dbgln("- RRSIG {}", rrsig.to_string());
}
dbgln_if(DNS_DEBUG, "- DNSKEY {}", dnskey.to_string());
dbgln_if(DNS_DEBUG, "- RRSIG {}", rrsig.to_string());
ByteBuffer to_be_signed;
{
@ -852,15 +958,15 @@ private:
TRY_OR_REJECT_PROMISE(promise, rrsig.signers_name.to_raw(to_be_signed));
TRY_OR_REJECT_PROMISE(promise, to_be_signed.try_append(canon_encoded.data(), canon_encoded.size()));
dbgln("To be signed: {:hex-dump}", to_be_signed.bytes());
dbgln_if(DNS_DEBUG, "To be signed: {:hex-dump}", to_be_signed.bytes());
constexpr auto rsa_prefix_for = [](Crypto::Hash::HashKind kind) -> ReadonlyBytes {
if (kind == Crypto::Hash::HashKind::SHA256)
return "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"sv.bytes();
if (kind == Crypto::Hash::HashKind::SHA512)
return "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"sv.bytes();
VERIFY_NOT_REACHED();
};
// constexpr auto rsa_prefix_for = [](Crypto::Hash::HashKind kind) -> ReadonlyBytes {
// if (kind == Crypto::Hash::HashKind::SHA256)
// return "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"sv.bytes();
// if (kind == Crypto::Hash::HashKind::SHA512)
// return "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"sv.bytes();
// VERIFY_NOT_REACHED();
// };
switch (dnskey.algorithm) {
case Messages::DNSSEC::Algorithm::RSAMD5: {
@ -927,17 +1033,10 @@ private:
break;
}
case Messages::DNSSEC::Algorithm::RSASHA256: {
auto const prefix = rsa_prefix_for(Crypto::Hash::HashKind::SHA256);
auto n = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_modulus());
auto e = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_exponent());
Crypto::PK::RSA_PSS_EMSA rsa { Crypto::Hash::HashKind::SHA256, Crypto::PK::RSAPublicKey { move(n), move(e) } };
auto digest = Crypto::Hash::SHA256::hash(to_be_signed);
ByteBuffer prefixed_digest;
TRY_OR_REJECT_PROMISE(promise, prefixed_digest.try_ensure_capacity(prefix.size() + digest.data_length()));
prefixed_digest.append(prefix);
prefixed_digest.append(digest.bytes());
if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(prefixed_digest.bytes(), rrsig.signature)); !ok) {
Crypto::PK::RSA_PKCS1_EMSA rsa { Crypto::Hash::HashKind::SHA256, Crypto::PK::RSAPublicKey { move(n), move(e) } };
if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(to_be_signed, rrsig.signature)); !ok) {
promise->reject(Error::from_string_literal("RSA/SHA256 signature validation failed"));
return promise;
}
@ -953,14 +1052,10 @@ private:
}
case Messages::DNSSEC::Algorithm::DSA:
case Messages::DNSSEC::Algorithm::RSASHA1NSEC3SHA1:
// Not implemented here
dbgln("Not implemented: DNSSEC algorithm {}", to_string(dnskey.algorithm));
break;
// Not implemented yet.
case Messages::DNSSEC::Algorithm::Unknown:
dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}",
to_string(dnskey.algorithm));
promise->reject(
Error::from_string_literal("Unsupported algorithm for DNSSEC validation"));
dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}", to_string(dnskey.algorithm));
promise->reject(Error::from_string_literal("Unsupported algorithm for DNSSEC validation"));
break;
}
@ -977,8 +1072,6 @@ private:
return promise;
}
#undef TRY_OR_REJECT_PROMISE
bool has_connection(bool attempt_restart = true)
{
auto result = m_socket.with_read_locked(
@ -1040,3 +1133,5 @@ private:
Vector<NonnullRefPtr<Core::Promise<Empty>>> m_socket_ready_promises;
};
}
#undef TRY_OR_REJECT_PROMISE

View file

@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <LibCore/ElapsedTimer.h>
#include <LibCore/Promise.h>
#include <LibCrypto/OpenSSL.h>
#include <LibTLS/TLSv12.h>
@ -137,8 +138,17 @@ ErrorOr<bool> TLSv12::can_read_without_blocking(int timeout) const
if (SSL_has_pending(m_ssl))
return true;
if (timeout > 0)
return TRY(m_socket->can_read_without_blocking(timeout)) && SSL_has_pending(m_ssl);
while (timeout > 0) {
auto timer = Core::ElapsedTimer();
if (!TRY(m_socket->can_read_without_blocking(timeout)))
return SSL_has_pending(m_ssl);
if (SSL_has_pending(m_ssl))
return true;
auto elapsed = timer.elapsed_milliseconds();
if (elapsed >= timeout)
break;
timeout -= elapsed;
}
return false;
}

View file

@ -55,7 +55,7 @@ public:
virtual void close() override;
virtual ErrorOr<size_t> pending_bytes() const override;
virtual ErrorOr<bool> can_read_without_blocking(int = 0) const override;
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override;
virtual ErrorOr<void> set_blocking(bool block) override;
virtual ErrorOr<void> set_close_on_exec(bool enabled) override;