diff --git a/AK/CountingStream.cpp b/AK/CountingStream.cpp index c6074b25aa3..a7d9bc4a14e 100644 --- a/AK/CountingStream.cpp +++ b/AK/CountingStream.cpp @@ -8,8 +8,9 @@ namespace AK { -CountingStream::CountingStream(MaybeOwned stream) +CountingStream::CountingStream(MaybeOwned stream, size_t offset) : m_stream(move(stream)) + , m_read_bytes(offset) { } diff --git a/AK/CountingStream.h b/AK/CountingStream.h index 3afd1f4b48e..87cf4bec985 100644 --- a/AK/CountingStream.h +++ b/AK/CountingStream.h @@ -13,7 +13,7 @@ namespace AK { class CountingStream : public Stream { public: - CountingStream(MaybeOwned); + CountingStream(MaybeOwned, size_t offset = 0); u64 read_bytes() const; diff --git a/Libraries/LibDNS/Message.cpp b/Libraries/LibDNS/Message.cpp index c183b64e5e3..b5643f41bc9 100644 --- a/Libraries/LibDNS/Message.cpp +++ b/Libraries/LibDNS/Message.cpp @@ -663,11 +663,17 @@ ErrorOr DomainName::from_raw(ParseContext& ctx) constexpr static u8 OffsetMarkerMask = 0b11000000; if ((length & OffsetMarkerMask) == OffsetMarkerMask) { // This is a pointer to a prior domain name. - u16 const offset = static_cast(length & ~OffsetMarkerMask) << 8 | TRY(ctx.stream.read_value()); + u16 offset = static_cast(length & ~OffsetMarkerMask) << 8 | TRY(ctx.stream.read_value()); if (auto it = ctx.pointers->find_largest_not_above_iterator(offset); !it.is_end()) { auto labels = it->labels; - for (auto& entry : labels) - name.labels.append(entry); + size_t start_index = 0; + size_t start_entry_offset = offset - it.key(); + while (start_entry_offset > 0 && start_index < labels.size()) { + start_entry_offset -= labels[start_index].length() + 1; // +1 for the length byte + start_index++; + } + for (size_t i = start_index; i < labels.size(); ++i) + name.labels.append(labels[i].substring_view(i == start_index ? start_entry_offset : 0)); break; } dbgln("Invalid domain name pointer in label, no prior domain name found around offset {}", offset); @@ -786,9 +792,10 @@ ErrorOr ResourceRecord::from_raw(ParseContext& ctx) ResourceType type; Class class_; u32 ttl; + size_t original_offset = ctx.stream.read_bytes(); { RecordingStream rr_stream { ctx.stream }; - CountingStream rr_counting_stream { MaybeOwned(rr_stream) }; + CountingStream rr_counting_stream { MaybeOwned(rr_stream), original_offset }; ParseContext rr_ctx { rr_counting_stream, move(ctx.pointers) }; ScopeGuard guard([&] { ctx.pointers = move(rr_ctx.pointers); }); @@ -809,13 +816,14 @@ ErrorOr ResourceRecord::from_raw(ParseContext& ctx) class_ = static_cast(static_cast(TRY(rr_ctx.stream.read_value>()))); ttl = static_cast(TRY(rr_ctx.stream.read_value>())); auto rd_length = static_cast(TRY(rr_ctx.stream.read_value>())); + original_offset = rr_ctx.stream.read_bytes(); TRY(rr_ctx.stream.read_until_filled(TRY(rdata.get_bytes_for_writing(rd_length)))); rr_raw_data = move(rr_stream).take_recorded_data(); } FixedMemoryStream stream { rdata.bytes() }; - CountingStream rdata_stream { MaybeOwned(stream) }; + CountingStream rdata_stream { MaybeOwned(stream), original_offset }; ParseContext rdata_ctx { rdata_stream, move(ctx.pointers) }; ScopeGuard guard([&] { ctx.pointers = move(rdata_ctx.pointers); }); @@ -1027,6 +1035,24 @@ ErrorOr Records::SOA::from_raw(ParseContext& ctx) return Records::SOA { move(mname), move(rname), serial, refresh, retry, expire, minimum }; } +ErrorOr Records::SOA::to_raw(ByteBuffer& buffer) const +{ + TRY(mname.to_raw(buffer)); + TRY(rname.to_raw(buffer)); + + auto const output_size = 5 * sizeof(u32); + FixedMemoryStream stream { TRY(buffer.get_bytes_for_writing(output_size)) }; + + TRY(stream.write_value(static_cast>(serial))); + TRY(stream.write_value(static_cast>(refresh))); + TRY(stream.write_value(static_cast>(retry))); + TRY(stream.write_value(static_cast>(expire))); + TRY(stream.write_value(static_cast>(minimum))); + + return {}; +} + + ErrorOr Records::MX::from_raw(ParseContext& ctx) { // RFC 1035, 3.3.9. MX RDATA format. diff --git a/Libraries/LibDNS/Message.h b/Libraries/LibDNS/Message.h index 45fe1d1afd9..0f23fefe4b6 100644 --- a/Libraries/LibDNS/Message.h +++ b/Libraries/LibDNS/Message.h @@ -98,6 +98,9 @@ struct DomainName { copy.labels.take_first(); return copy; } + + bool operator==(DomainName const&) const& = default; + bool operator!=(DomainName const&) const& = default; }; // Listing from IANA https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4. @@ -419,7 +422,7 @@ struct SOA { static constexpr ResourceType type = ResourceType::SOA; static ErrorOr from_raw(ParseContext&); - ErrorOr to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: SOA::to_raw"); } + ErrorOr to_raw(ByteBuffer&) const; ErrorOr to_string() const { return String::formatted("SOA MName: '{}', RName: '{}', Serial: {}, Refresh: {}, Retry: {}, Expire: {}, Minimum: {}", mname.to_string(), rname.to_string(), serial, refresh, retry, expire, minimum); diff --git a/Libraries/LibDNS/Resolver.h b/Libraries/LibDNS/Resolver.h index 3d30eff128b..ab96d991d99 100644 --- a/Libraries/LibDNS/Resolver.h +++ b/Libraries/LibDNS/Resolver.h @@ -28,6 +28,16 @@ #undef DNS_DEBUG #define DNS_DEBUG 1 +#define TRY_OR_REJECT_PROMISE(promise, expr) \ + ({ \ + auto _result = (expr); \ + if (_result.is_error()) { \ + promise->reject(_result.release_error()); \ + return promise; \ + } \ + _result.release_value(); \ + }) + namespace DNS { class Resolver; @@ -92,6 +102,35 @@ public: return result; } + Vector records(Messages::ResourceType type) const + { + Vector result; + for (auto& re : m_cached_records) { + if (re.record.type == type) + result.append(re.record); + } + return result; + } + + Messages::ResourceRecord const& record(Messages::ResourceType type) const + { + for (auto const& re : m_cached_records) { + if (re.record.type == type) + return re.record; + } + VERIFY_NOT_REACHED(); + } + + template + RR const& record() const + { + for (auto const& re : m_cached_records) { + if (re.record.type == RR::type) + return re.record.record.get(); + } + VERIFY_NOT_REACHED(); + } + bool has_record_of_type(Messages::ResourceType type, bool later = false) const { if (later && m_desired_types.contains(type)) @@ -364,7 +403,8 @@ public: dbgln_if(DNS_DEBUG, "DNS: Adding {} to cache", name); auto ptr = make_ref_counted(domain_name); - ptr->set_dnssec_validated(options.validate_dnssec_locally); + if (!ptr->is_dnssec_validated()) + ptr->set_dnssec_validated(options.validate_dnssec_locally); for (auto const& type : desired_types) ptr->will_add_record_of_type(type); cache.set(name, ptr); @@ -580,6 +620,66 @@ private: Vector dnskeys; }; + NonnullRefPtr> validate_dnssec_chain_step(Messages::DomainName const& name, bool top_level = false) + { + dbgln_if(DNS_DEBUG, "DNS: Validating DNSSEC chain for {}", name.to_string()); + auto promise = Core::Promise::construct(); + // - If this is the root, we're done, just return true. + if (name.labels.size() == 0) { + promise->resolve(true); + return promise; + } + + // - Lookup the SOA record for the domain. + auto soa_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::SOA }, { .validate_dnssec_locally = !top_level })->await())); + // - If we have no SOA record- + if (!soa_result->has_record_of_type(Messages::ResourceType::SOA)) { + dbgln_if(DNS_DEBUG, "DNS: No SOA record found for {}", name.to_string()); + // - First, check for a DS record- + auto ds_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DS }, { .validate_dnssec_locally = !top_level })->await())); + // - If there's no DS record, check for an NS record- + if (!ds_result->has_record_of_type(Messages::ResourceType::DS)) { + dbgln_if(DNS_DEBUG, "DNS: No DS record found for {}", name.to_string()); + // - If there's no DS record, check for an NS record- + auto ns_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::NS }, { .validate_dnssec_locally = !top_level })->await())); + if (ns_result->has_record_of_type(Messages::ResourceType::NS)) { + // - but if there _is_ an NS record, this is a broken delegation, so reject. + dbgln_if(DNS_DEBUG, "DNS: Found NS record for {}", name.to_string()); + promise->resolve(false); + return promise; + } + dbgln_if(DNS_DEBUG, "DNS: No NS record found for {}", name.to_string()); + // this is just part of the parent delegation, so go up one level. + return validate_dnssec_chain_step(name.parent()); + } + // - If there is a DS record, this is a separate zone...but since we don't have an SOA record, this is a misconfigured zone. + // Let's just reject. + dbgln_if(DNS_DEBUG, "DNS: Found DS record for {}", name.to_string()); + promise->resolve(false); + return promise; + } + + // So we have an SOA record, there's much rejoicing and we can continue. + auto& soa = soa_result->record(); + dbgln_if(DNS_DEBUG, "DNS: Found SOA record for {}: {}", name.to_string(), soa.mname.to_string()); + if (soa.mname == name.parent()) { + // Just go up one level, all is well. + return validate_dnssec_chain_step(name.parent()); + } + + // So this is a separate zone, let's look up the DS record. + auto ds_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DS }, { .validate_dnssec_locally = false })->await())); + if (!ds_result->has_record_of_type(Messages::ResourceType::DS)) { + // If there's no DS record, this is a misconfigured zone. + dbgln_if(DNS_DEBUG, "DNS: No DS record found for {}", name.to_string()); + promise->resolve(false); + return promise; + } + + promise->resolve(true); + return promise; + } + ErrorOr validate_dnssec(Messages::Message message, PendingLookup& lookup, NonnullRefPtr result) { struct RecordAndRRSIG { @@ -588,7 +688,6 @@ private: }; HashMap records_with_rrsigs; for (auto& record : message.answers) { - dbgln("- {}", record.to_string()); if (record.type == Messages::ResourceType::RRSIG) { auto& rrsig = record.record.get(); auto type = rrsig.type_covered; @@ -621,6 +720,15 @@ private: auto is_root_zone = lookup.parsed_name.labels.size() == 0; if (!is_root_zone) { + auto chain_valid_result = validate_dnssec_chain_step(name, true)->await(); + if (chain_valid_result.is_error()) { + lookup.promise->reject(chain_valid_result.release_error()); + return; + } + if (!chain_valid_result.value()) { + lookup.promise->reject(Error::from_string_literal("DNSSEC chain is invalid")); + return; + } auto parent_result = this->lookup(lookup.parsed_name.parent().to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DNSKEY }, { .validate_dnssec_locally = true }) ->await(); if (parent_result.is_error()) { @@ -633,7 +741,10 @@ private: return; } - parent_zone_keys = parent_result.release_value()->used_dnskeys(); + parent_zone_keys = parent_result.value()->used_dnskeys(); + for (auto& rr : parent_result.value()->records(Messages::ResourceType::DNSKEY)) + parent_zone_keys.append(rr.record.get()); + dbgln("Found {} DNSKEYs for parent zone ({})", parent_zone_keys.size(), lookup.parsed_name.parent().to_string()); } auto resolve_using_keys = [=, this, records_with_rrsigs = move(records_with_rrsigs)](Vector keys) mutable { @@ -708,29 +819,38 @@ private: lookup.promise = move(promise); }; if (is_root_zone) { - return resolve_using_keys(Vector { Messages::Records::DNSKEY { - .flags = 256, - .protocol = 3, - .algorithm = Messages::DNSSEC::Algorithm::RSASHA256, - .public_key = MUST(decode_base64("AwEAAa96jeuknZlaeSrvyAJj6ZHv28hhOKkx3rLGXVaC6rXTsDc449/cidltpkyGwCJNnOAlFNKF2jBosZBU5eeHspaQWOmOElZsjICMQMC3aeHbGiShvZsx4wMYSjH8e7Vrhbu6irwCzVBApESjbUdpWWmEnhathWu1jo+siFUiRAAxm9qyJNg/wOZqqzL/dL/q8PkcRU5oUKEpUge71M3ej2/7CPqpdVwuMoTvoB+ZOT4YeGyxMvHmbrxlFzGOHOijtzN+u1TQNatX2XBuzZNQ1K+s2CXkPIZo7s6JgZyvaBevYtxPvYLw4z9mR7K2vaF18UYH9Z9GNUUeayffKC73PYc="sv)), - .calculated_key_tag = 38696, - } }); + resolve_using_keys(Vector { + { + .flags = 257, + .protocol = 3, + .algorithm = Messages::DNSSEC::Algorithm::RSASHA256, + .public_key = MUST(decode_base64("AwEAAaz/tAm8yTn4Mfeh5eyI96WSVexTBAvkMgJzkKTOiW1vkIbzxeF3+/4RgWOq7HrxRixHlFlExOLAJr5emLvN7SWXgnLh4+B5xQlNVz8Og8kvArMtNROxVQuCaSnIDdD5LKyWbRd2n9WGe2R8PzgCmr3EgVLrjyBxWezF0jLHwVN8efS3rCj/EWgvIWgb9tarpVUDK/b58Da+sqqls3eNbuv7pr+eoZG+SrDK6nWeL3c6H5Apxz7LjVc1uTIdsIXxuOLYA4/ilBmSVIzuDWfdRUfhHdY6+cn8HFRm+2hM8AnXGXws9555KrUB5qihylGa8subX2Nn6UwNR1AkUTV74bU="sv)), + .calculated_key_tag = 20326, + }, + { + .flags = 256, + .protocol = 3, + .algorithm = Messages::DNSSEC::Algorithm::RSASHA256, + .public_key = MUST(decode_base64("AwEAAa96jeuknZlaeSrvyAJj6ZHv28hhOKkx3rLGXVaC6rXTsDc449/cidltpkyGwCJNnOAlFNKF2jBosZBU5eeHspaQWOmOElZsjICMQMC3aeHbGiShvZsx4wMYSjH8e7Vrhbu6irwCzVBApESjbUdpWWmEnhathWu1jo+siFUiRAAxm9qyJNg/wOZqqzL/dL/q8PkcRU5oUKEpUge71M3ej2/7CPqpdVwuMoTvoB+ZOT4YeGyxMvHmbrxlFzGOHOijtzN+u1TQNatX2XBuzZNQ1K+s2CXkPIZo7s6JgZyvaBevYtxPvYLw4z9mR7K2vaF18UYH9Z9GNUUeayffKC73PYc="sv)), + .calculated_key_tag = 38696, + }, + }); + return; } dbgln_if(DNS_DEBUG, "DNS: Starting DNSKEY lookup for {}", lookup.name); this->lookup(lookup.name, Messages::Class::IN, { Messages::ResourceType::DNSKEY }, { .validate_dnssec_locally = false }) ->when_resolved([=](NonnullRefPtr& dnskey_lookup_result) mutable { dbgln_if(DNS_DEBUG, "DNSKEY for {}:", name.to_string()); - for (auto& record : dnskey_lookup_result->records()) + auto key_records = dnskey_lookup_result->records(Messages::ResourceType::DNSKEY); + for (auto& record : key_records) dbgln_if(DNS_DEBUG, "- DNSKEY: {}", record.to_string()); Vector keys; keys.ensure_capacity(parent_zone_keys.size() + dnskey_lookup_result->records().size()); for (auto& record : parent_zone_keys) keys.append(record); - for (auto& record : dnskey_lookup_result->records()) { - if (auto k = record.record.get_pointer()) - keys.append(move(*k)); - } + for (auto& record : key_records) + keys.append(move(record.record).get()); resolve_using_keys(move(keys)); }) .when_rejected([=](auto& error) mutable { @@ -755,16 +875,6 @@ private: return nullptr; } -#define TRY_OR_REJECT_PROMISE(promise, expr) \ - ({ \ - auto _result = (expr); \ - if (_result.is_error()) { \ - promise->reject(_result.release_error()); \ - return promise; \ - } \ - _result.release_value(); \ - }) - NonnullRefPtr> validate_rrset_with_rrsig(CanonicalizedRRSetWithRRSIG rrset_with_rrsig, NonnullRefPtr result) { auto promise = Core::Promise::construct(); @@ -788,29 +898,25 @@ private: for (auto& rr : canon_encoded_rrs) canon_encoded.append(rr); - if (result->name().labels.size() == 1) { - dbgln_if(DNS_DEBUG, "Root zone, implicitly trusting DNSKEY for {}", result->name().to_string()); - promise->resolve({}); - return promise; - } - auto& dnskey = *find_dnskey(rrset_with_rrsig); - dbgln_if(DNS_DEBUG, "Validating RRSet with RRSIG for {}", result->name().to_string()); - for (auto& rr : rrset_with_rrsig.rrset) - dbgln_if(DNS_DEBUG, "- RR {}", rr.to_string()); - for (auto& canon : canon_encoded_rrs) { - FixedMemoryStream stream { canon.bytes() }; - CountingStream rr_counting_stream { MaybeOwned(stream) }; - DNS::Messages::ParseContext rr_ctx { rr_counting_stream, make>() }; - auto maybe_decoded = Messages::ResourceRecord::from_raw(rr_ctx); - if (maybe_decoded.is_error()) - dbgln("-- Failed to decode RR: {}", maybe_decoded.error()); - else - dbgln("-- Canon encoded (decoded): {}", maybe_decoded.value().to_string()); + if constexpr (DNS_DEBUG) { + dbgln("Validating RRSet with RRSIG for {}", result->name().to_string()); + for (auto& rr : rrset_with_rrsig.rrset) + dbgln("- RR {}", rr.to_string()); + for (auto& canon : canon_encoded_rrs) { + FixedMemoryStream stream { canon.bytes() }; + CountingStream rr_counting_stream { MaybeOwned(stream) }; + DNS::Messages::ParseContext rr_ctx { rr_counting_stream, make>() }; + auto maybe_decoded = Messages::ResourceRecord::from_raw(rr_ctx); + if (maybe_decoded.is_error()) + dbgln("-- Failed to decode RR: {}", maybe_decoded.error()); + else + dbgln("-- Canon encoded (decoded): {}", maybe_decoded.value().to_string()); + } + dbgln( "- DNSKEY {}", dnskey.to_string()); + dbgln("- RRSIG {}", rrsig.to_string()); } - dbgln_if(DNS_DEBUG, "- DNSKEY {}", dnskey.to_string()); - dbgln_if(DNS_DEBUG, "- RRSIG {}", rrsig.to_string()); ByteBuffer to_be_signed; { @@ -852,15 +958,15 @@ private: TRY_OR_REJECT_PROMISE(promise, rrsig.signers_name.to_raw(to_be_signed)); TRY_OR_REJECT_PROMISE(promise, to_be_signed.try_append(canon_encoded.data(), canon_encoded.size())); - dbgln("To be signed: {:hex-dump}", to_be_signed.bytes()); + dbgln_if(DNS_DEBUG, "To be signed: {:hex-dump}", to_be_signed.bytes()); - constexpr auto rsa_prefix_for = [](Crypto::Hash::HashKind kind) -> ReadonlyBytes { - if (kind == Crypto::Hash::HashKind::SHA256) - return "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"sv.bytes(); - if (kind == Crypto::Hash::HashKind::SHA512) - return "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"sv.bytes(); - VERIFY_NOT_REACHED(); - }; + // constexpr auto rsa_prefix_for = [](Crypto::Hash::HashKind kind) -> ReadonlyBytes { + // if (kind == Crypto::Hash::HashKind::SHA256) + // return "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"sv.bytes(); + // if (kind == Crypto::Hash::HashKind::SHA512) + // return "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"sv.bytes(); + // VERIFY_NOT_REACHED(); + // }; switch (dnskey.algorithm) { case Messages::DNSSEC::Algorithm::RSAMD5: { @@ -927,17 +1033,10 @@ private: break; } case Messages::DNSSEC::Algorithm::RSASHA256: { - auto const prefix = rsa_prefix_for(Crypto::Hash::HashKind::SHA256); auto n = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_modulus()); auto e = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_exponent()); - Crypto::PK::RSA_PSS_EMSA rsa { Crypto::Hash::HashKind::SHA256, Crypto::PK::RSAPublicKey { move(n), move(e) } }; - auto digest = Crypto::Hash::SHA256::hash(to_be_signed); - ByteBuffer prefixed_digest; - TRY_OR_REJECT_PROMISE(promise, prefixed_digest.try_ensure_capacity(prefix.size() + digest.data_length())); - prefixed_digest.append(prefix); - prefixed_digest.append(digest.bytes()); - - if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(prefixed_digest.bytes(), rrsig.signature)); !ok) { + Crypto::PK::RSA_PKCS1_EMSA rsa { Crypto::Hash::HashKind::SHA256, Crypto::PK::RSAPublicKey { move(n), move(e) } }; + if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(to_be_signed, rrsig.signature)); !ok) { promise->reject(Error::from_string_literal("RSA/SHA256 signature validation failed")); return promise; } @@ -953,14 +1052,10 @@ private: } case Messages::DNSSEC::Algorithm::DSA: case Messages::DNSSEC::Algorithm::RSASHA1NSEC3SHA1: - // Not implemented here - dbgln("Not implemented: DNSSEC algorithm {}", to_string(dnskey.algorithm)); - break; + // Not implemented yet. case Messages::DNSSEC::Algorithm::Unknown: - dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}", - to_string(dnskey.algorithm)); - promise->reject( - Error::from_string_literal("Unsupported algorithm for DNSSEC validation")); + dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}", to_string(dnskey.algorithm)); + promise->reject(Error::from_string_literal("Unsupported algorithm for DNSSEC validation")); break; } @@ -977,8 +1072,6 @@ private: return promise; } -#undef TRY_OR_REJECT_PROMISE - bool has_connection(bool attempt_restart = true) { auto result = m_socket.with_read_locked( @@ -1040,3 +1133,5 @@ private: Vector>> m_socket_ready_promises; }; } + +#undef TRY_OR_REJECT_PROMISE diff --git a/Libraries/LibTLS/TLSv12.cpp b/Libraries/LibTLS/TLSv12.cpp index b263cdf4a13..91b5c2c73db 100644 --- a/Libraries/LibTLS/TLSv12.cpp +++ b/Libraries/LibTLS/TLSv12.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -137,8 +138,17 @@ ErrorOr TLSv12::can_read_without_blocking(int timeout) const if (SSL_has_pending(m_ssl)) return true; - if (timeout > 0) - return TRY(m_socket->can_read_without_blocking(timeout)) && SSL_has_pending(m_ssl); + while (timeout > 0) { + auto timer = Core::ElapsedTimer(); + if (!TRY(m_socket->can_read_without_blocking(timeout))) + return SSL_has_pending(m_ssl); + if (SSL_has_pending(m_ssl)) + return true; + auto elapsed = timer.elapsed_milliseconds(); + if (elapsed >= timeout) + break; + timeout -= elapsed; + } return false; } diff --git a/Libraries/LibTLS/TLSv12.h b/Libraries/LibTLS/TLSv12.h index fdd7847d417..b8553c45007 100644 --- a/Libraries/LibTLS/TLSv12.h +++ b/Libraries/LibTLS/TLSv12.h @@ -55,7 +55,7 @@ public: virtual void close() override; virtual ErrorOr pending_bytes() const override; - virtual ErrorOr can_read_without_blocking(int = 0) const override; + virtual ErrorOr can_read_without_blocking(int timeout = 0) const override; virtual ErrorOr set_blocking(bool block) override; virtual ErrorOr set_close_on_exec(bool enabled) override;