This commit is contained in:
Ali Mohammad Pur 2025-04-05 17:20:46 +00:00 committed by GitHub
commit df07e14fe9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 957 additions and 80 deletions

View file

@ -8,8 +8,9 @@
namespace AK {
CountingStream::CountingStream(MaybeOwned<Stream> stream)
CountingStream::CountingStream(MaybeOwned<Stream> stream, size_t offset)
: m_stream(move(stream))
, m_read_bytes(offset)
{
}

View file

@ -13,7 +13,7 @@ namespace AK {
class CountingStream : public Stream {
public:
CountingStream(MaybeOwned<Stream>);
CountingStream(MaybeOwned<Stream>, size_t offset = 0);
u64 read_bytes() const;

View file

@ -24,6 +24,42 @@ public:
Function<ErrorOr<void>(Result&)> on_resolution;
Function<void(ErrorType&)> on_rejection;
static NonnullRefPtr<Promise> after(Vector<NonnullRefPtr<Promise>>&& promises)
{
auto promise = Promise::construct();
struct Resolved : RefCounted<Resolved> {
explicit Resolved(size_t n)
: needed(n)
{
}
size_t count { 0 };
size_t needed { 0 };
};
auto resolved = make_ref_counted<Resolved>(promises.size());
auto weak_promise = promise->template make_weak_ptr<Promise>();
for (auto p : promises) {
p->when_resolved([weak_promise, resolved](Result&) -> ErrorOr<void> {
if (weak_promise->is_rejected())
return {};
if (++resolved->count == resolved->needed)
weak_promise->resolve({});
return {};
});
p->when_rejected([weak_promise, resolved](ErrorType& error) {
++resolved->count;
weak_promise->reject(move(error));
});
promise->add_child(*p);
}
return promise;
}
void resolve(Result&& result)
{
m_result_or_rejection = move(result);

View file

@ -110,6 +110,29 @@ struct SECPxxxr1Signature {
return SECPxxxr1Signature { r_big_int, s_big_int, scalar_size };
}
static ErrorOr<SECPxxxr1Signature> from_raw(Span<int const> curve_oid, ReadonlyBytes signature)
{
size_t scalar_size;
if (curve_oid == ASN1::secp256r1_oid) {
scalar_size = ceil_div(256, 8);
} else if (curve_oid == ASN1::secp384r1_oid) {
scalar_size = ceil_div(384, 8);
} else if (curve_oid == ASN1::secp521r1_oid) {
scalar_size = ceil_div(521, 8);
} else {
return Error::from_string_literal("Unknown SECPxxxr1 curve");
}
if (signature.size() != scalar_size * 2)
return Error::from_string_literal("Invalid SECPxxxr1 signature");
return SECPxxxr1Signature {
UnsignedBigInteger::import_data(signature.slice(0, scalar_size)),
UnsignedBigInteger::import_data(signature.slice(scalar_size, scalar_size)),
scalar_size,
};
}
ErrorOr<ByteBuffer> r_bytes() const
{
return SECPxxxr1Point::scalar_to_bytes(r, size);

View file

@ -61,9 +61,19 @@ static ErrorOr<ECPublicKey<>> read_ec_public_key(ReadonlyBytes bytes, Vector<Str
UnsignedBigInteger::import_data(bytes.slice(1 + half_size, half_size)),
half_size,
};
} else {
ERROR_WITH_SCOPE("Unsupported public key format");
}
if (bytes.size() % 2 == 0) {
// Raw public key, without the 0x04 prefix
auto half_size = bytes.size() / 2;
return ::Crypto::PK::ECPublicKey<> {
UnsignedBigInteger::import_data(bytes.slice(0, half_size)),
UnsignedBigInteger::import_data(bytes.slice(half_size, half_size)),
half_size,
};
}
ERROR_WITH_SCOPE("Unsupported public key format");
}
// https://www.rfc-editor.org/rfc/rfc5915#section-3

View file

@ -3,4 +3,4 @@ set(SOURCES
)
serenity_lib(LibDNS dns)
target_link_libraries(LibDNS PRIVATE LibCore)
target_link_libraries(LibDNS PRIVATE LibCore PUBLIC LibCrypto)

View file

@ -4,6 +4,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/ByteReader.h>
#include <AK/CountingStream.h>
#include <AK/MemoryStream.h>
#include <AK/Stream.h>
@ -662,11 +663,17 @@ ErrorOr<DomainName> DomainName::from_raw(ParseContext& ctx)
constexpr static u8 OffsetMarkerMask = 0b11000000;
if ((length & OffsetMarkerMask) == OffsetMarkerMask) {
// This is a pointer to a prior domain name.
u16 const offset = static_cast<u16>(length & ~OffsetMarkerMask) << 8 | TRY(ctx.stream.read_value<u8>());
u16 offset = static_cast<u16>(length & ~OffsetMarkerMask) << 8 | TRY(ctx.stream.read_value<u8>());
if (auto it = ctx.pointers->find_largest_not_above_iterator(offset); !it.is_end()) {
auto labels = it->labels;
for (auto& entry : labels)
name.labels.append(entry);
size_t start_index = 0;
size_t start_entry_offset = offset - it.key();
while (start_entry_offset > 0 && start_index < labels.size()) {
start_entry_offset -= labels[start_index].length() + 1; // +1 for the length byte
start_index++;
}
for (size_t i = start_index; i < labels.size(); ++i)
name.labels.append(labels[i].substring_view(i == start_index ? start_entry_offset : 0));
break;
}
dbgln("Invalid domain name pointer in label, no prior domain name found around offset {}", offset);
@ -702,6 +709,9 @@ ErrorOr<void> DomainName::to_raw(ByteBuffer& out) const
String DomainName::to_string() const
{
if (labels.is_empty())
return "."_string;
StringBuilder builder;
for (size_t i = 0; i < labels.size(); ++i) {
builder.append(labels[i]);
@ -711,6 +721,26 @@ String DomainName::to_string() const
return MUST(builder.to_string());
}
String DomainName::to_canonical_string() const
{
if (labels.is_empty())
return "."_string;
StringBuilder builder;
for (size_t i = 0; i < labels.size(); ++i) {
auto& label = labels[i];
for (size_t j = 0; j < label.length(); ++j) {
auto ch = label[j];
if (ch >= 'A' && ch <= 'Z')
ch = to_ascii_lowercase(ch);
builder.append(ch);
}
builder.append('.');
}
return MUST(builder.to_string());
}
class RecordingStream final : public Stream {
public:
explicit RecordingStream(Stream& stream)
@ -762,9 +792,10 @@ ErrorOr<ResourceRecord> ResourceRecord::from_raw(ParseContext& ctx)
ResourceType type;
Class class_;
u32 ttl;
size_t original_offset = ctx.stream.read_bytes();
{
RecordingStream rr_stream { ctx.stream };
CountingStream rr_counting_stream { MaybeOwned<Stream>(rr_stream) };
CountingStream rr_counting_stream { MaybeOwned<Stream>(rr_stream), original_offset };
ParseContext rr_ctx { rr_counting_stream, move(ctx.pointers) };
ScopeGuard guard([&] { ctx.pointers = move(rr_ctx.pointers); });
@ -785,13 +816,14 @@ ErrorOr<ResourceRecord> ResourceRecord::from_raw(ParseContext& ctx)
class_ = static_cast<Class>(static_cast<u16>(TRY(rr_ctx.stream.read_value<NetworkOrdered<u16>>())));
ttl = static_cast<u32>(TRY(rr_ctx.stream.read_value<NetworkOrdered<u32>>()));
auto rd_length = static_cast<u16>(TRY(rr_ctx.stream.read_value<NetworkOrdered<u16>>()));
original_offset = rr_ctx.stream.read_bytes();
TRY(rr_ctx.stream.read_until_filled(TRY(rdata.get_bytes_for_writing(rd_length))));
rr_raw_data = move(rr_stream).take_recorded_data();
}
FixedMemoryStream stream { rdata.bytes() };
CountingStream rdata_stream { MaybeOwned<Stream>(stream) };
CountingStream rdata_stream { MaybeOwned<Stream>(stream), original_offset };
ParseContext rdata_ctx { rdata_stream, move(ctx.pointers) };
ScopeGuard guard([&] { ctx.pointers = move(rdata_ctx.pointers); });
@ -887,9 +919,11 @@ ErrorOr<void> ResourceRecord::to_raw(ByteBuffer& buffer) const
ErrorOr<String> ResourceRecord::to_string() const
{
StringBuilder builder;
builder.appendff("[{} {} ", Messages::to_string(class_), Messages::to_string(type));
record.visit(
[&](auto const& record) { builder.appendff("{}", MUST(record.to_string())); },
[&](ByteBuffer const& raw) { builder.appendff("{:hex-dump}", raw.bytes()); });
builder.appendff(" | ttl={}, name={}]", ttl, name.to_string());
return builder.to_string();
}
@ -902,6 +936,16 @@ ErrorOr<Records::A> Records::A::from_raw(ParseContext& ctx)
return Records::A { IPv4Address { address } };
}
ErrorOr<void> Records::A::to_raw(ByteBuffer& buffer) const
{
auto const address = this->address.to_u32();
auto const net_address = bit_cast<NetworkOrdered<u32>>(address);
auto bytes = TRY(buffer.get_bytes_for_writing(sizeof(net_address)));
bytes.overwrite(0, &net_address, sizeof(net_address));
return {};
}
ErrorOr<Records::AAAA> Records::AAAA::from_raw(ParseContext& ctx)
{
// RFC 3596, 2.2. AAAA RDATA format.
@ -911,6 +955,18 @@ ErrorOr<Records::AAAA> Records::AAAA::from_raw(ParseContext& ctx)
return Records::AAAA { IPv6Address { bit_cast<Array<u8, 16>>(address) } };
}
ErrorOr<void> Records::AAAA::to_raw(ByteBuffer& buffer) const
{
auto const* const address_bytes = this->address.to_in6_addr_t();
u128 address {};
memcpy(&address, address_bytes, sizeof(address));
auto const net_address = bit_cast<NetworkOrdered<u128>>(address);
auto bytes = TRY(buffer.get_bytes_for_writing(sizeof(net_address)));
bytes.overwrite(0, &net_address, sizeof(net_address));
return {};
}
ErrorOr<Records::TXT> Records::TXT::from_raw(ParseContext& ctx)
{
// RFC 1035, 3.3.14. TXT RDATA format.
@ -922,6 +978,18 @@ ErrorOr<Records::TXT> Records::TXT::from_raw(ParseContext& ctx)
return Records::TXT { ByteString::copy(content) };
}
ErrorOr<void> Records::TXT::to_raw(ByteBuffer& buffer) const
{
auto const length = static_cast<u8>(content.length());
auto length_bytes = TRY(buffer.get_bytes_for_writing(1));
memcpy(length_bytes.data(), &length, 1);
auto content_bytes = TRY(buffer.get_bytes_for_writing(length));
memcpy(content_bytes.data(), content.characters(), length);
return {};
}
ErrorOr<Records::CNAME> Records::CNAME::from_raw(ParseContext& ctx)
{
// RFC 1035, 3.3.1. CNAME RDATA format.
@ -931,6 +999,11 @@ ErrorOr<Records::CNAME> Records::CNAME::from_raw(ParseContext& ctx)
return Records::CNAME { move(name) };
}
ErrorOr<void> Records::CNAME::to_raw(ByteBuffer& buffer) const
{
return names.to_raw(buffer);
}
ErrorOr<Records::NS> Records::NS::from_raw(ParseContext& ctx)
{
// RFC 1035, 3.3.11. NS RDATA format.
@ -962,6 +1035,24 @@ ErrorOr<Records::SOA> Records::SOA::from_raw(ParseContext& ctx)
return Records::SOA { move(mname), move(rname), serial, refresh, retry, expire, minimum };
}
ErrorOr<void> Records::SOA::to_raw(ByteBuffer& buffer) const
{
TRY(mname.to_raw(buffer));
TRY(rname.to_raw(buffer));
auto const output_size = 5 * sizeof(u32);
FixedMemoryStream stream { TRY(buffer.get_bytes_for_writing(output_size)) };
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(serial)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(refresh)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(retry)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(expire)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(minimum)));
return {};
}
ErrorOr<Records::MX> Records::MX::from_raw(ParseContext& ctx)
{
// RFC 1035, 3.3.9. MX RDATA format.
@ -1005,11 +1096,36 @@ ErrorOr<Records::DNSKEY> Records::DNSKEY::from_raw(ParseContext& ctx)
// | ALGORITHM| an 8-bit value that identifies the public key's cryptographic algorithm.
// | PUBLICKEY| the public key material.
u32 key_tag = 0;
auto flags = static_cast<u16>(TRY(ctx.stream.read_value<NetworkOrdered<u16>>()));
key_tag += (bit_cast<u16>(NetworkOrdered<u16>(flags)) & 0xff) << 8;
key_tag += (bit_cast<u16>(NetworkOrdered<u16>(flags)) >> 8) & 0xff;
auto protocol = TRY(ctx.stream.read_value<u8>());
key_tag += static_cast<u16>(protocol) << 8;
auto algorithm = static_cast<DNSSEC::Algorithm>(static_cast<u8>(TRY(ctx.stream.read_value<u8>())));
key_tag += static_cast<u16>(algorithm);
auto public_key = TRY(ctx.stream.read_until_eof());
return Records::DNSKEY { flags, protocol, algorithm, move(public_key) };
for (size_t i = 0; i < public_key.size(); ++i) {
key_tag += (i & 1) ? static_cast<u16>(public_key[i]) : static_cast<u16>(public_key[i]) << 8;
}
key_tag += (key_tag >> 16) & 0xffff;
if (public_key.is_empty())
return Error::from_string_literal("Empty public key in DNSKEY record");
return Records::DNSKEY { flags, protocol, algorithm, move(public_key), static_cast<u16>(key_tag & 0xffff) };
}
ErrorOr<void> Records::DNSKEY::to_raw(ByteBuffer& buffer) const
{
auto const output_size = 2 + 1 + 1 + public_key.size();
FixedMemoryStream stream { TRY(buffer.get_bytes_for_writing(output_size)) };
TRY(stream.write_value(static_cast<u16>(bit_cast<NetworkOrdered<u16>>(flags))));
TRY(stream.write_value(protocol));
TRY(stream.write_value(to_underlying(algorithm)));
TRY(stream.write_until_depleted(public_key.bytes()));
return {};
}
ErrorOr<Records::DS> Records::DS::from_raw(ParseContext& ctx)
@ -1051,6 +1167,19 @@ ErrorOr<Records::DS> Records::DS::from_raw(ParseContext& ctx)
return Records::DS { key_tag, algorithm, digest_type, move(digest) };
}
ErrorOr<void> Records::DS::to_raw(ByteBuffer& buffer) const
{
auto const output_size = 2 + 1 + 1 + digest.size();
FixedMemoryStream stream { TRY(buffer.get_bytes_for_writing(output_size)) };
TRY(stream.write_value(static_cast<NetworkOrdered<u16>>(key_tag)));
TRY(stream.write_value(static_cast<u8>(algorithm)));
TRY(stream.write_value(static_cast<u8>(digest_type)));
TRY(stream.write_until_depleted(digest.bytes()));
return {};
}
ErrorOr<Records::SIG> Records::SIG::from_raw(ParseContext& ctx)
{
// RFC 4034, 2.2. The SIG Resource Record.
@ -1077,6 +1206,30 @@ ErrorOr<Records::SIG> Records::SIG::from_raw(ParseContext& ctx)
return Records::SIG { type_covered, algorithm, labels, original_ttl, UnixDateTime::from_seconds_since_epoch(signature_expiration), UnixDateTime::from_seconds_since_epoch(signature_inception), key_tag, move(signer_name), move(signature) };
}
ErrorOr<void> Records::SIG::to_raw_excluding_signature(ByteBuffer& buffer) const
{
AllocatingMemoryStream stream;
TRY(stream.write_value(static_cast<NetworkOrdered<u16>>(to_underlying(type_covered))));
TRY(stream.write_value(static_cast<u8>(algorithm)));
TRY(stream.write_value(label_count));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(original_ttl)));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(expiration.seconds_since_epoch())));
TRY(stream.write_value(static_cast<NetworkOrdered<u32>>(inception.seconds_since_epoch())));
TRY(stream.write_value(static_cast<NetworkOrdered<u16>>(key_tag)));
TRY(stream.read_until_filled(TRY(buffer.get_bytes_for_writing(stream.used_buffer_size()))));
TRY(signers_name.to_raw(buffer));
return {};
}
ErrorOr<void> Records::SIG::to_raw(ByteBuffer& buffer) const
{
TRY(to_raw_excluding_signature(buffer));
TRY(buffer.try_append(signature));
return {};
}
ErrorOr<String> Records::SIG::to_string() const
{
// Single line:
@ -1110,6 +1263,18 @@ ErrorOr<Records::HINFO> Records::HINFO::from_raw(ParseContext& ctx)
return Records::HINFO { ByteString::copy(cpu), ByteString::copy(os) };
}
ErrorOr<void> Records::HINFO::to_raw(ByteBuffer& buffer) const
{
auto allocated_length = cpu.length() + os.length() + 2;
auto bytes = TRY(buffer.get_bytes_for_writing(allocated_length));
FixedMemoryStream stream { bytes };
TRY(stream.write_value(static_cast<u8>(cpu.length())));
TRY(stream.write_until_depleted(cpu.bytes()));
TRY(stream.write_value(static_cast<u8>(os.length())));
TRY(stream.write_until_depleted(os.bytes()));
return {};
}
ErrorOr<Records::OPT> Records::OPT::from_raw(ParseContext& ctx)
{
// RFC 6891, 6.1. The OPT pseudo-RR.

View file

@ -91,6 +91,16 @@ struct DomainName {
static ErrorOr<DomainName> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const;
String to_string() const;
String to_canonical_string() const;
DomainName parent() const
{
auto copy = *this;
copy.labels.take_first();
return copy;
}
bool operator==(DomainName const&) const& = default;
bool operator!=(DomainName const&) const& = default;
};
// Listing from IANA https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4.
@ -278,6 +288,8 @@ static inline StringView to_string(Algorithm algorithm)
return "ED25519"sv;
case Algorithm::Unknown:
return "Unknown"sv;
default:
return "Invalid"sv;
}
VERIFY_NOT_REACHED();
}
@ -364,7 +376,7 @@ struct A {
static constexpr ResourceType type = ResourceType::A;
static ErrorOr<A> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const { return address.to_string(); }
};
struct AAAA {
@ -372,7 +384,7 @@ struct AAAA {
static constexpr ResourceType type = ResourceType::AAAA;
static ErrorOr<AAAA> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const { return address.to_string(); }
};
struct TXT {
@ -380,7 +392,7 @@ struct TXT {
static constexpr ResourceType type = ResourceType::TXT;
static ErrorOr<TXT> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const { return String::formatted("Text: '{}'", StringView { content }); }
};
struct CNAME {
@ -388,7 +400,7 @@ struct CNAME {
static constexpr ResourceType type = ResourceType::CNAME;
static ErrorOr<CNAME> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const { return names.to_string(); }
};
struct NS {
@ -396,7 +408,7 @@ struct NS {
static constexpr ResourceType type = ResourceType::NS;
static ErrorOr<NS> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: NS::to_raw"); }
ErrorOr<String> to_string() const { return name.to_string(); }
};
struct SOA {
@ -410,7 +422,7 @@ struct SOA {
static constexpr ResourceType type = ResourceType::SOA;
static ErrorOr<SOA> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const
{
return String::formatted("SOA MName: '{}', RName: '{}', Serial: {}, Refresh: {}, Retry: {}, Expire: {}, Minimum: {}", mname.to_string(), rname.to_string(), serial, refresh, retry, expire, minimum);
@ -422,7 +434,7 @@ struct MX {
static constexpr ResourceType type = ResourceType::MX;
static ErrorOr<MX> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: MX::to_raw"); }
ErrorOr<String> to_string() const { return String::formatted("MX Preference: {}, Exchange: '{}'", preference, exchange.to_string()); }
};
struct PTR {
@ -430,7 +442,7 @@ struct PTR {
static constexpr ResourceType type = ResourceType::PTR;
static ErrorOr<PTR> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: PTR::to_raw"); }
ErrorOr<String> to_string() const { return name.to_string(); }
};
struct SRV {
@ -441,7 +453,7 @@ struct SRV {
static constexpr ResourceType type = ResourceType::SRV;
static ErrorOr<SRV> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: SRV::to_raw"); }
ErrorOr<String> to_string() const { return String::formatted("SRV Priority: {}, Weight: {}, Port: {}, Target: '{}'", priority, weight, port, target.to_string()); }
};
struct DNSKEY {
@ -449,6 +461,17 @@ struct DNSKEY {
u8 protocol;
DNSSEC::Algorithm algorithm;
ByteBuffer public_key;
// Extra: calculated key tag
u16 calculated_key_tag;
// Extra: public key components (pointing into public_key) ONLY for RSA.
u16 public_key_rsa_exponent_length() const
{
if (public_key[0] != 0)
return public_key[0];
return static_cast<u16>(public_key[1]) | static_cast<u16>(public_key[2]) << 8;
}
ReadonlyBytes public_key_rsa_exponent() const { return public_key.bytes().slice(1, public_key_rsa_exponent_length()); }
ReadonlyBytes public_key_rsa_modulus() const { return public_key.bytes().slice(1 + public_key_rsa_exponent_length()); }
constexpr static inline u16 FlagSecureEntryPoint = 0b1000000000000000;
constexpr static inline u16 FlagZoneKey = 0b0100000000000000;
@ -461,10 +484,10 @@ struct DNSKEY {
static constexpr ResourceType type = ResourceType::DNSKEY;
static ErrorOr<DNSKEY> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const
{
return String::formatted("DNSKEY Flags: {}{}{}{}({}), Protocol: {}, Algorithm: {}, Public Key: {}",
return String::formatted("DNSKEY Flags: {}{}{}{}({}), Protocol: {}, Algorithm: {}, Public Key: {}, Tag: {}",
is_secure_entry_point() ? "sep "sv : ""sv,
is_zone_key() ? "zone "sv : ""sv,
is_revoked() ? "revoked "sv : ""sv,
@ -472,7 +495,8 @@ struct DNSKEY {
flags,
protocol,
DNSSEC::to_string(algorithm),
TRY(encode_base64(public_key)));
TRY(encode_base64(public_key)),
calculated_key_tag);
}
};
struct CDNSKEY : public DNSKEY {
@ -493,8 +517,15 @@ struct DS {
static constexpr ResourceType type = ResourceType::DS;
static ErrorOr<DS> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<String> to_string() const { return "DS"_string; }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const
{
return String::formatted("DS Key Tag: {}, Algorithm: {}, Digest Type: {}, Digest: {}",
key_tag,
DNSSEC::to_string(algorithm),
DNSSEC::to_string(digest_type),
TRY(encode_base64(digest)));
}
};
struct CDS : public DS {
template<typename... Ts>
@ -518,7 +549,8 @@ struct SIG {
static constexpr ResourceType type = ResourceType::SIG;
static ErrorOr<SIG> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<void> to_raw_excluding_signature(ByteBuffer&) const;
ErrorOr<String> to_string() const;
};
struct RRSIG : public SIG {
@ -530,6 +562,7 @@ struct RRSIG : public SIG {
static constexpr ResourceType type = ResourceType::RRSIG;
static ErrorOr<RRSIG> from_raw(ParseContext& raw) { return SIG::from_raw(raw); }
ErrorOr<void> to_raw_excluding_signature(ByteBuffer& buffer) const { return SIG::to_raw_excluding_signature(buffer); }
};
struct NSEC {
DomainName next_domain_name;
@ -537,7 +570,7 @@ struct NSEC {
static constexpr ResourceType type = ResourceType::NSEC;
static ErrorOr<NSEC> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: NSC::to_raw"); }
ErrorOr<String> to_string() const { return "NSEC"_string; }
};
struct NSEC3 {
@ -550,7 +583,7 @@ struct NSEC3 {
static constexpr ResourceType type = ResourceType::NSEC3;
static ErrorOr<NSEC3> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: NSEC3::to_raw"); }
ErrorOr<String> to_string() const { return "NSEC3"_string; }
};
struct NSEC3PARAM {
@ -565,7 +598,7 @@ struct NSEC3PARAM {
static constexpr ResourceType type = ResourceType::NSEC3PARAM;
static ErrorOr<NSEC3PARAM> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: NSEC3PARAM::to_raw"); }
ErrorOr<String> to_string() const { return "NSEC3PARAM"_string; }
};
struct TLSA {
@ -575,7 +608,7 @@ struct TLSA {
ByteBuffer certificate_association_data;
static ErrorOr<TLSA> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented: TLSA::to_raw"); }
ErrorOr<String> to_string() const { return "TLSA"_string; }
};
struct HINFO {
@ -584,7 +617,7 @@ struct HINFO {
static constexpr ResourceType type = ResourceType::HINFO;
static ErrorOr<HINFO> from_raw(ParseContext&);
ErrorOr<void> to_raw(ByteBuffer&) const { return Error::from_string_literal("Not implemented"); }
ErrorOr<void> to_raw(ByteBuffer&) const;
ErrorOr<String> to_string() const { return String::formatted("HINFO CPU: '{}', OS: '{}'", StringView { cpu }, StringView { os }); }
};
struct OPT {

View file

@ -7,9 +7,11 @@
#pragma once
#include <AK/AtomicRefCounted.h>
#include <AK/CountingStream.h>
#include <AK/HashTable.h>
#include <AK/MaybeOwned.h>
#include <AK/MemoryStream.h>
#include <AK/QuickSort.h>
#include <AK/Random.h>
#include <AK/StringView.h>
#include <AK/TemporaryChange.h>
@ -17,10 +19,25 @@
#include <LibCore/Promise.h>
#include <LibCore/Socket.h>
#include <LibCore/Timer.h>
#include <LibCrypto/Certificate/Certificate.h>
#include <LibCrypto/Curves/EdwardsCurve.h>
#include <LibCrypto/PK/RSA.h>
#include <LibDNS/Message.h>
#include <LibThreading/MutexProtected.h>
#include <LibThreading/RWLockProtected.h>
#undef DNS_DEBUG
#define DNS_DEBUG 1
#define TRY_OR_REJECT_PROMISE(promise, expr) \
({ \
auto _result = (expr); \
if (_result.is_error()) { \
promise->reject(_result.release_error()); \
return promise; \
} \
_result.release_value(); \
})
namespace DNS {
class Resolver;
@ -39,7 +56,8 @@ public:
re.record.record.visit(
[&](Messages::Records::A const& a) { result.append(a.address); },
[&](Messages::Records::AAAA const& aaaa) { result.append(aaaa.address); },
[](auto&) {});
[](auto&) {
});
}
return result;
}
@ -56,7 +74,8 @@ public:
dbgln_if(DNS_DEBUG, "DNS: Removing expired record for {}", m_name.to_string());
m_cached_records.remove(i);
} else {
dbgln_if(DNS_DEBUG, "DNS: Keeping record for {} (expires in {})", m_name.to_string(), record.expiration.has_value() ? record.expiration.value().to_string() : "never"_string);
dbgln_if(DNS_DEBUG, "DNS: Keeping record for {} (expires in {})", m_name.to_string(),
record.expiration.has_value() ? record.expiration.value().to_string() : "never"_string);
++i;
}
}
@ -68,7 +87,10 @@ public:
void add_record(Messages::ResourceRecord record)
{
m_valid = true;
auto expiration = record.ttl > 0 ? Optional<Core::DateTime>(Core::DateTime::from_timestamp(Core::DateTime::now().timestamp() + record.ttl)) : OptionalNone();
auto expiration = record.ttl > 0
? Optional<Core::DateTime>(
Core::DateTime::from_timestamp(Core::DateTime::now().timestamp() + record.ttl))
: OptionalNone();
m_cached_records.append({ move(record), move(expiration) });
}
@ -80,6 +102,35 @@ public:
return result;
}
Vector<Messages::ResourceRecord> records(Messages::ResourceType type) const
{
Vector<Messages::ResourceRecord> result;
for (auto& re : m_cached_records) {
if (re.record.type == type)
result.append(re.record);
}
return result;
}
Messages::ResourceRecord const& record(Messages::ResourceType type) const
{
for (auto const& re : m_cached_records) {
if (re.record.type == type)
return re.record;
}
VERIFY_NOT_REACHED();
}
template<typename RR>
RR const& record() const
{
for (auto const& re : m_cached_records) {
if (re.record.type == RR::type)
return re.record.record.get<RR>();
}
VERIFY_NOT_REACHED();
}
bool has_record_of_type(Messages::ResourceType type, bool later = false) const
{
if (later && m_desired_types.contains(type))
@ -100,18 +151,30 @@ public:
bool can_be_removed() const { return !m_valid && m_request_done; }
bool is_done() const { return m_request_done; }
void set_dnssec_validated(bool validated) { m_dnssec_validated = validated; }
bool is_dnssec_validated() const { return m_dnssec_validated; }
void set_being_dnssec_validated(bool validated) { m_being_dnssec_validated = validated; }
bool is_being_dnssec_validated() const { return m_being_dnssec_validated; }
Messages::DomainName const& name() const { return m_name; }
Vector<Messages::Records::DNSKEY> const& used_dnskeys() const { return m_used_dnskeys; }
void add_dnskey(Messages::Records::DNSKEY key) { m_used_dnskeys.append(move(key)); }
private:
bool m_valid { false };
bool m_request_done { false };
bool m_dnssec_validated { false };
bool m_being_dnssec_validated { false };
Messages::DomainName m_name;
struct RecordWithExpiration {
Messages::ResourceRecord record;
Optional<Core::DateTime> expiration;
};
Vector<RecordWithExpiration> m_cached_records;
HashTable<Messages::ResourceType> m_desired_types;
Vector<Messages::Records::DNSKEY> m_used_dnskeys {};
u16 m_id { 0 };
};
@ -119,6 +182,7 @@ class Resolver {
struct PendingLookup {
u16 id { 0 };
ByteString name;
Messages::DomainName parsed_name;
WeakPtr<LookupResult> result;
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> promise;
NonnullRefPtr<Core::Timer> repeat_timer;
@ -131,6 +195,13 @@ public:
UDP,
};
struct LookupOptions {
bool validate_dnssec_locally { false };
PendingLookup* repeating_lookup { nullptr };
static LookupOptions default_() { return {}; }
};
struct SocketResult {
MaybeOwned<Core::Socket> socket;
ConnectionMode mode;
@ -212,25 +283,31 @@ public:
});
}
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_ = Messages::Class::IN)
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_ = Messages::Class::IN, LookupOptions options = LookupOptions::default_())
{
return lookup(move(name), class_, { Messages::ResourceType::A, Messages::ResourceType::AAAA });
return lookup(move(name), class_, { Messages::ResourceType::A, Messages::ResourceType::AAAA }, options);
}
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_, Vector<Messages::ResourceType> desired_types, PendingLookup* repeating_lookup = nullptr)
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_, Vector<Messages::ResourceType> desired_types, LookupOptions options = LookupOptions::default_())
{
flush_cache();
if (repeating_lookup && repeating_lookup->times_repeated >= 5) {
auto promise = repeating_lookup->promise;
if (options.repeating_lookup && options.repeating_lookup->times_repeated >= 5) {
dbgln_if(DNS_DEBUG, "DNS: Repeating lookup for {} timed out", name);
auto promise = options.repeating_lookup->promise;
promise->reject(Error::from_string_literal("DNS lookup timed out"));
m_pending_lookups.with_write_locked([&](auto& lookups) { lookups->remove(repeating_lookup->id); });
m_pending_lookups.with_write_locked([&](auto& lookups) {
lookups->remove(options.repeating_lookup->id);
});
return promise;
}
auto promise = repeating_lookup ? repeating_lookup->promise : Core::Promise<NonnullRefPtr<LookupResult const>>::construct();
auto promise = options.repeating_lookup
? options.repeating_lookup->promise
: Core::Promise<NonnullRefPtr<LookupResult const>>::construct();
if (auto maybe_ipv4 = IPv4Address::from_string(name); maybe_ipv4.has_value()) {
dbgln_if(DNS_DEBUG, "DNS: Resolving {} as IPv4", name);
if (desired_types.contains_slow(Messages::ResourceType::A)) {
auto result = make_ref_counted<LookupResult>(Messages::DomainName {});
result->add_record({ .name = {}, .type = Messages::ResourceType::A, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::A { maybe_ipv4.release_value() }, .raw = {} });
@ -241,6 +318,7 @@ public:
}
if (auto maybe_ipv6 = IPv6Address::from_string(name); maybe_ipv6.has_value()) {
dbgln_if(DNS_DEBUG, "DNS: Resolving {} as IPv6", name);
if (desired_types.contains_slow(Messages::ResourceType::AAAA)) {
auto result = make_ref_counted<LookupResult>(Messages::DomainName {});
result->add_record({ .name = {}, .type = Messages::ResourceType::AAAA, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::AAAA { maybe_ipv6.release_value() }, .raw = {} });
@ -251,13 +329,23 @@ public:
}
if (auto result = lookup_in_cache(name, class_, desired_types)) {
promise->resolve(result.release_nonnull());
return promise;
dbgln_if(DNS_DEBUG, "DNS: Resolving {} from cache...", name);
if (!options.validate_dnssec_locally || result->is_dnssec_validated()) {
dbgln_if(DNS_DEBUG, "DNS: Resolved {} from cache", name);
promise->resolve(result.release_nonnull());
return promise;
}
dbgln_if(DNS_DEBUG, "DNS: Cache entry for {} is not DNSSEC validated (and we expect that), re-resolving", name);
}
auto domain_name = Messages::DomainName::from_string(name);
if (!has_connection()) {
if (options.validate_dnssec_locally) {
promise->reject(Error::from_string_literal("No connection available to validate DNSSEC"));
return promise;
}
// Use system resolver
// FIXME: Use an underlying resolver instead.
dbgln_if(DNS_DEBUG, "Not ready to resolve, using system resolver and skipping cache for {}", name);
@ -285,27 +373,38 @@ public:
auto already_in_cache = false;
auto result = m_cache.with_write_locked([&](auto& cache) -> NonnullRefPtr<LookupResult> {
dbgln_if(DNS_DEBUG, "DNS: Resolving {}...", name);
auto existing = [&] -> RefPtr<LookupResult> {
if (cache.contains(name)) {
dbgln_if(DNS_DEBUG, "DNS: Resolving {} from cache...", name);
auto ptr = *cache.get(name);
already_in_cache = true;
already_in_cache = (!options.validate_dnssec_locally && !ptr->is_being_dnssec_validated()) || ptr->is_dnssec_validated();
for (auto const& type : desired_types) {
if (!ptr->has_record_of_type(type, true)) {
if (!ptr->has_record_of_type(type, !options.validate_dnssec_locally && !ptr->is_being_dnssec_validated())) {
already_in_cache = false;
break;
}
}
dbgln_if(DNS_DEBUG, "DNS: Found {} in cache, already_in_cache={}", name, already_in_cache);
dbgln_if(DNS_DEBUG, "DNS: That entry is {} DNSSEC validated", ptr->is_dnssec_validated() ? "already" : "not");
for (auto const& entry : ptr->records())
dbgln_if(DNS_DEBUG, "DNS: Found record of type {}", Messages::to_string(entry.type));
return ptr;
}
return nullptr;
}();
if (existing)
if (existing) {
dbgln_if(DNS_DEBUG, "DNS: Resolved {} from cache", name);
return *existing;
}
dbgln_if(DNS_DEBUG, "DNS: Adding {} to cache", name);
auto ptr = make_ref_counted<LookupResult>(domain_name);
if (!ptr->is_dnssec_validated())
ptr->set_dnssec_validated(options.validate_dnssec_locally);
for (auto const& type : desired_types)
ptr->will_add_record_of_type(type);
cache.set(name, ptr);
@ -316,11 +415,12 @@ public:
if (already_in_cache) {
auto id = result->id();
cached_result_id = id;
auto existing_promise = m_pending_lookups.with_write_locked([&](auto& lookups) -> RefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> {
if (auto* lookup = lookups->find(id))
return lookup->promise;
return nullptr;
});
auto existing_promise = m_pending_lookups.with_write_locked(
[&](auto& lookups) -> RefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> {
if (auto* lookup = lookups->find(id))
return lookup->promise;
return nullptr;
});
if (existing_promise)
return existing_promise.release_nonnull();
@ -333,9 +433,9 @@ public:
}
Messages::Message query;
if (repeating_lookup) {
query.header.id = repeating_lookup->id;
repeating_lookup->times_repeated++;
if (options.repeating_lookup) {
query.header.id = options.repeating_lookup->id;
options.repeating_lookup->times_repeated++;
} else {
m_pending_lookups.with_read_locked([&](auto& lookups) {
do
@ -363,23 +463,48 @@ public:
});
}
auto cached_entry = repeating_lookup ? nullptr : m_pending_lookups.with_write_locked([&](auto& pending_lookups) -> PendingLookup* {
// One more try to make sure we're not overwriting an existing lookup
if (cached_result_id.has_value()) {
if (auto* lookup = pending_lookups->find(*cached_result_id))
return lookup;
}
pending_lookups->insert(query.header.id, { query.header.id, name, result->make_weak_ptr(), promise, Core::Timer::create(), 0 });
auto p = pending_lookups->find(query.header.id);
p->repeat_timer->set_single_shot(true);
p->repeat_timer->set_interval(1000);
p->repeat_timer->on_timeout = [=, this] {
(void)lookup(name, class_, desired_types, p);
if (options.validate_dnssec_locally) {
query.header.additional_count = 1;
query.header.options.set_checking_disabled(true);
query.header.options.set_authenticated_data(true);
auto opt = Messages::Records::OPT {
.udp_payload_size = 4096,
.extended_rcode_and_flags = 0,
.options = {},
};
opt.set_dnssec_ok(true);
return nullptr;
});
query.additional_records.append(Messages::ResourceRecord {
.name = Messages::DomainName::from_string(""sv),
.type = Messages::ResourceType::OPT,
.class_ = class_,
.ttl = 0,
.record = move(opt),
.raw = {},
});
}
result->set_id(query.header.id);
auto cached_entry = options.repeating_lookup
? nullptr
: m_pending_lookups.with_write_locked([&](auto& pending_lookups) -> PendingLookup* {
// One more try to make sure we're not overwriting an existing lookup
if (cached_result_id.has_value()) {
if (auto* lookup = pending_lookups->find(*cached_result_id))
return lookup;
}
pending_lookups->insert(query.header.id, { query.header.id, name, domain_name, result->make_weak_ptr(), promise, Core::Timer::create(), 0 });
auto p = pending_lookups->find(query.header.id);
p->repeat_timer->set_single_shot(true);
p->repeat_timer->set_interval(1000);
p->repeat_timer->on_timeout = [=, this] {
(void)lookup(name, class_, desired_types, { .validate_dnssec_locally = options.validate_dnssec_locally, .repeating_lookup = p });
};
return nullptr;
});
if (cached_entry) {
dbgln_if(DNS_DEBUG, "DNS::lookup({}) -> Lookup already underway", name);
@ -446,7 +571,10 @@ private:
void process_incoming_messages()
{
while (true) {
if (auto result = m_socket.with_read_locked([](auto& socket) { return (*socket)->can_read_without_blocking(); }); result.is_error() || !result.value())
if (auto result = m_socket.with_read_locked([](auto& socket) {
return (*socket)->can_read_without_blocking();
});
result.is_error() || !result.value())
break;
auto message_or_err = parse_one_message();
if (message_or_err.is_error()) {
@ -461,12 +589,17 @@ private:
if (!lookup)
return Error::from_string_literal("No pending lookup found for this message");
if (lookup->result.is_null())
if (lookup->result.is_null()) {
dbgln_if(DNS_DEBUG, "DNS: Received a message with no pending lookup (id={})", message.header.id);
return {}; // Message is a response to a lookup that's been purged from the cache, ignore it
}
lookup->repeat_timer->stop();
auto result = lookup->result.strong_ref();
if (result->is_dnssec_validated())
return validate_dnssec(move(message), *lookup, *result);
for (auto& record : message.answers)
result->add_record(move(record));
@ -475,11 +608,468 @@ private:
lookups->remove(message.header.id);
return {};
});
if (result.is_error()) {
if (result.is_error())
dbgln_if(DNS_DEBUG, "DNS: Received a message with no pending lookup: {}", result.error());
continue;
}
}
using RRSet = Vector<Messages::ResourceRecord>;
struct CanonicalizedRRSetWithRRSIG {
RRSet rrset;
Messages::Records::RRSIG rrsig;
Vector<Messages::Records::DNSKEY> dnskeys;
};
NonnullRefPtr<Core::Promise<bool>> validate_dnssec_chain_step(Messages::DomainName const& name, bool top_level = false)
{
dbgln_if(DNS_DEBUG, "DNS: Validating DNSSEC chain for {}", name.to_string());
auto promise = Core::Promise<bool>::construct();
// - If this is the root, we're done, just return true.
if (name.labels.size() == 0) {
promise->resolve(true);
return promise;
}
// - Lookup the SOA record for the domain.
auto soa_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::SOA }, { .validate_dnssec_locally = !top_level })->await()));
// - If we have no SOA record-
if (!soa_result->has_record_of_type(Messages::ResourceType::SOA)) {
dbgln_if(DNS_DEBUG, "DNS: No SOA record found for {}", name.to_string());
// - First, check for a DS record-
auto ds_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DS }, { .validate_dnssec_locally = !top_level })->await()));
// - If there's no DS record, check for an NS record-
if (!ds_result->has_record_of_type(Messages::ResourceType::DS)) {
dbgln_if(DNS_DEBUG, "DNS: No DS record found for {}", name.to_string());
// - If there's no DS record, check for an NS record-
auto ns_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::NS }, { .validate_dnssec_locally = !top_level })->await()));
if (ns_result->has_record_of_type(Messages::ResourceType::NS)) {
// - but if there _is_ an NS record, this is a broken delegation, so reject.
dbgln_if(DNS_DEBUG, "DNS: Found NS record for {}", name.to_string());
promise->resolve(false);
return promise;
}
dbgln_if(DNS_DEBUG, "DNS: No NS record found for {}", name.to_string());
// this is just part of the parent delegation, so go up one level.
return validate_dnssec_chain_step(name.parent());
}
// - If there is a DS record, this is a separate zone...but since we don't have an SOA record, this is a misconfigured zone.
// Let's just reject.
dbgln_if(DNS_DEBUG, "DNS: Found DS record for {}", name.to_string());
promise->resolve(false);
return promise;
}
// So we have an SOA record, there's much rejoicing and we can continue.
auto& soa = soa_result->record<Messages::Records::SOA>();
dbgln_if(DNS_DEBUG, "DNS: Found SOA record for {}: {}", name.to_string(), soa.mname.to_string());
if (soa.mname == name.parent()) {
// Just go up one level, all is well.
return validate_dnssec_chain_step(name.parent());
}
// So this is a separate zone, let's look up the DS record.
auto ds_result = TRY_OR_REJECT_PROMISE(promise, (lookup(name.to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DS }, { .validate_dnssec_locally = false })->await()));
if (!ds_result->has_record_of_type(Messages::ResourceType::DS)) {
// If there's no DS record, this is a misconfigured zone.
dbgln_if(DNS_DEBUG, "DNS: No DS record found for {}", name.to_string());
promise->resolve(false);
return promise;
}
promise->resolve(true);
return promise;
}
ErrorOr<void> validate_dnssec(Messages::Message message, PendingLookup& lookup, NonnullRefPtr<LookupResult> result)
{
struct RecordAndRRSIG {
Vector<Messages::ResourceRecord> records;
Messages::Records::RRSIG rrsig;
};
HashMap<Messages::ResourceType, RecordAndRRSIG> records_with_rrsigs;
for (auto& record : message.answers) {
if (record.type == Messages::ResourceType::RRSIG) {
auto& rrsig = record.record.get<Messages::Records::RRSIG>();
auto type = rrsig.type_covered;
if (auto found = records_with_rrsigs.get(type); found.has_value())
found->rrsig = move(rrsig);
else
records_with_rrsigs.set(type, { {}, move(rrsig) });
} else {
auto type = record.type;
if (auto found = records_with_rrsigs.get(record.type); found.has_value())
found->records.append(move(record));
else
records_with_rrsigs.set(type, { { move(record) }, {} });
}
}
if (records_with_rrsigs.is_empty()) {
dbgln_if(DNS_DEBUG, "DNS: No RRSIG records found in DNSSEC response");
return {};
}
auto name = result->name();
Core::deferred_invoke([this, lookup, name, records_with_rrsigs = move(records_with_rrsigs), result = move(result)] mutable {
dbgln_if(DNS_DEBUG, "DNS: Resolving DNSKEY for {}", name.to_string());
result->set_dnssec_validated(false); // Will be set to true if we successfully validate the RRSIGs.
result->set_being_dnssec_validated(true);
Vector<Messages::Records::DNSKEY> parent_zone_keys;
auto is_root_zone = lookup.parsed_name.labels.size() == 0;
if (!is_root_zone) {
auto chain_valid_result = validate_dnssec_chain_step(name, true)->await();
if (chain_valid_result.is_error()) {
lookup.promise->reject(chain_valid_result.release_error());
return;
}
if (!chain_valid_result.value()) {
lookup.promise->reject(Error::from_string_literal("DNSSEC chain is invalid"));
return;
}
auto parent_result = this->lookup(lookup.parsed_name.parent().to_string().to_byte_string(), Messages::Class::IN, { Messages::ResourceType::DNSKEY }, { .validate_dnssec_locally = true })
->await();
if (parent_result.is_error()) {
lookup.promise->reject(parent_result.release_error());
return;
}
if (!parent_result.value()->is_dnssec_validated()) {
lookup.promise->reject(Error::from_string_literal("Parent zone is not DNSSEC validated"));
return;
}
parent_zone_keys = parent_result.value()->used_dnskeys();
for (auto& rr : parent_result.value()->records(Messages::ResourceType::DNSKEY))
parent_zone_keys.append(rr.record.get<Messages::Records::DNSKEY>());
dbgln("Found {} DNSKEYs for parent zone ({})", parent_zone_keys.size(), lookup.parsed_name.parent().to_string());
}
auto resolve_using_keys = [=, this, records_with_rrsigs = move(records_with_rrsigs)](Vector<Messages::Records::DNSKEY> keys) mutable {
dbgln_if(DNS_DEBUG, "DNS: Validating {} RRSIGs for {}; starting with {} keys", records_with_rrsigs.size(), name.to_string(), keys.size());
for (auto& key : keys)
dbgln_if(DNS_DEBUG, "- DNSKEY: {}", key.to_string());
Vector<NonnullRefPtr<Core::Promise<Empty>>> promises;
for (auto& record_and_rrsig : records_with_rrsigs) {
auto& records = record_and_rrsig.value.records;
if (record_and_rrsig.key == Messages::ResourceType::DNSKEY) {
for (auto& record : records)
keys.append(record.record.get<Messages::Records::DNSKEY>());
}
}
dbgln_if(DNS_DEBUG, "DNS: Found {} keys total", keys.size());
// (owner | type | class) -> (RRSet, RRSIG, DNSKey*)
HashMap<String, CanonicalizedRRSetWithRRSIG> rrsets_with_rrsigs;
for (auto& [type, pair] : records_with_rrsigs) {
auto& records = pair.records;
auto& rrsig = pair.rrsig;
for (auto& record : records) {
auto canonicalized_name = record.name.to_canonical_string();
auto key = MUST(String::formatted("{}|{}|{}", canonicalized_name, to_underlying(record.type), to_underlying(record.class_)));
if (!rrsets_with_rrsigs.contains(key)) {
auto dnskeys = [&] -> Vector<Messages::Records::DNSKEY> {
Vector<Messages::Records::DNSKEY> relevant_keys;
for (auto& key : keys) {
if (key.algorithm == rrsig.algorithm)
relevant_keys.append(key);
}
return relevant_keys;
}();
dbgln_if(DNS_DEBUG, "DNS: Found {} relevant DNSKEYs for key {}", dnskeys.size(), key);
rrsets_with_rrsigs.set(key, CanonicalizedRRSetWithRRSIG { {}, move(rrsig), move(dnskeys) });
}
auto& rrset_with_rrsig = *rrsets_with_rrsigs.get(key);
rrset_with_rrsig.rrset.append(move(record));
}
}
for (auto& entry : rrsets_with_rrsigs) {
auto& rrset_with_rrsig = entry.value;
if (rrset_with_rrsig.dnskeys.is_empty()) {
dbgln_if(DNS_DEBUG, "DNS: No DNSKEY found for validation of {} RRs", rrset_with_rrsig.rrset.size());
continue;
}
promises.append(validate_rrset_with_rrsig(move(rrset_with_rrsig), result));
}
auto promise = Core::Promise<Empty>::after(move(promises))
->when_resolved([result, lookup](Empty) {
result->set_dnssec_validated(true);
result->set_being_dnssec_validated(false);
result->finished_request();
lookup.promise->resolve(result);
})
.when_rejected([result, lookup](Error& error) {
result->finished_request();
result->set_being_dnssec_validated(false);
lookup.promise->reject(move(error));
})
.map<NonnullRefPtr<LookupResult const>>([result](Empty&) { return result; });
lookup.promise = move(promise);
};
if (is_root_zone) {
resolve_using_keys(Vector<Messages::Records::DNSKEY> {
{
.flags = 257,
.protocol = 3,
.algorithm = Messages::DNSSEC::Algorithm::RSASHA256,
.public_key = MUST(decode_base64("AwEAAaz/tAm8yTn4Mfeh5eyI96WSVexTBAvkMgJzkKTOiW1vkIbzxeF3+/4RgWOq7HrxRixHlFlExOLAJr5emLvN7SWXgnLh4+B5xQlNVz8Og8kvArMtNROxVQuCaSnIDdD5LKyWbRd2n9WGe2R8PzgCmr3EgVLrjyBxWezF0jLHwVN8efS3rCj/EWgvIWgb9tarpVUDK/b58Da+sqqls3eNbuv7pr+eoZG+SrDK6nWeL3c6H5Apxz7LjVc1uTIdsIXxuOLYA4/ilBmSVIzuDWfdRUfhHdY6+cn8HFRm+2hM8AnXGXws9555KrUB5qihylGa8subX2Nn6UwNR1AkUTV74bU="sv)),
.calculated_key_tag = 20326,
},
{
.flags = 256,
.protocol = 3,
.algorithm = Messages::DNSSEC::Algorithm::RSASHA256,
.public_key = MUST(decode_base64("AwEAAa96jeuknZlaeSrvyAJj6ZHv28hhOKkx3rLGXVaC6rXTsDc449/cidltpkyGwCJNnOAlFNKF2jBosZBU5eeHspaQWOmOElZsjICMQMC3aeHbGiShvZsx4wMYSjH8e7Vrhbu6irwCzVBApESjbUdpWWmEnhathWu1jo+siFUiRAAxm9qyJNg/wOZqqzL/dL/q8PkcRU5oUKEpUge71M3ej2/7CPqpdVwuMoTvoB+ZOT4YeGyxMvHmbrxlFzGOHOijtzN+u1TQNatX2XBuzZNQ1K+s2CXkPIZo7s6JgZyvaBevYtxPvYLw4z9mR7K2vaF18UYH9Z9GNUUeayffKC73PYc="sv)),
.calculated_key_tag = 38696,
},
});
return;
}
dbgln_if(DNS_DEBUG, "DNS: Starting DNSKEY lookup for {}", lookup.name);
this->lookup(lookup.name, Messages::Class::IN, { Messages::ResourceType::DNSKEY }, { .validate_dnssec_locally = false })
->when_resolved([=](NonnullRefPtr<LookupResult const>& dnskey_lookup_result) mutable {
dbgln_if(DNS_DEBUG, "DNSKEY for {}:", name.to_string());
auto key_records = dnskey_lookup_result->records(Messages::ResourceType::DNSKEY);
for (auto& record : key_records)
dbgln_if(DNS_DEBUG, "- DNSKEY: {}", record.to_string());
Vector<Messages::Records::DNSKEY> keys;
keys.ensure_capacity(parent_zone_keys.size() + dnskey_lookup_result->records().size());
for (auto& record : parent_zone_keys)
keys.append(record);
for (auto& record : key_records)
keys.append(move(record.record).get<Messages::Records::DNSKEY>());
resolve_using_keys(move(keys));
})
.when_rejected([=](auto& error) mutable {
if (parent_zone_keys.is_empty()) {
dbgln_if(DNS_DEBUG, "Failed to resolve DNSKEY for {}: {}", name.to_string(), error);
lookup.promise->reject(move(error));
}
resolve_using_keys(move(parent_zone_keys));
});
});
return {};
}
Messages::Records::DNSKEY const* find_dnskey(CanonicalizedRRSetWithRRSIG const& rrset_with_rrsig)
{
for (auto& key : rrset_with_rrsig.dnskeys) {
if (key.calculated_key_tag == rrset_with_rrsig.rrsig.key_tag)
return &key;
dbgln_if(DNS_DEBUG, "DNS: DNSKEY with tag {} does not match RRSIG with tag {}", key.calculated_key_tag, rrset_with_rrsig.rrsig.key_tag);
}
return nullptr;
}
NonnullRefPtr<Core::Promise<Empty>> validate_rrset_with_rrsig(CanonicalizedRRSetWithRRSIG rrset_with_rrsig, NonnullRefPtr<LookupResult> result)
{
auto promise = Core::Promise<Empty>::construct();
auto& rrsig = rrset_with_rrsig.rrsig;
Vector<ByteBuffer> canon_encoded_rrs;
auto total_size = 0uz;
for (auto& rr : rrset_with_rrsig.rrset) {
rr.ttl = rrsig.original_ttl;
canon_encoded_rrs.empend();
auto& canon_encoded_rr = canon_encoded_rrs.last();
TRY_OR_REJECT_PROMISE(promise, rr.to_raw(canon_encoded_rr));
total_size += canon_encoded_rr.size();
}
quick_sort(canon_encoded_rrs, [](auto const& a, auto const& b) {
return memcmp(a.data(), b.data(), min(a.size(), b.size())) < 0;
});
ByteBuffer canon_encoded;
TRY_OR_REJECT_PROMISE(promise, canon_encoded.try_ensure_capacity(total_size));
for (auto& rr : canon_encoded_rrs)
canon_encoded.append(rr);
auto& dnskey = *find_dnskey(rrset_with_rrsig);
if constexpr (DNS_DEBUG) {
dbgln("Validating RRSet with RRSIG for {}", result->name().to_string());
for (auto& rr : rrset_with_rrsig.rrset)
dbgln("- RR {}", rr.to_string());
for (auto& canon : canon_encoded_rrs) {
FixedMemoryStream stream { canon.bytes() };
CountingStream rr_counting_stream { MaybeOwned<Stream>(stream) };
DNS::Messages::ParseContext rr_ctx { rr_counting_stream, make<RedBlackTree<u16, Messages::DomainName>>() };
auto maybe_decoded = Messages::ResourceRecord::from_raw(rr_ctx);
if (maybe_decoded.is_error())
dbgln("-- Failed to decode RR: {}", maybe_decoded.error());
else
dbgln("-- Canon encoded (decoded): {}", maybe_decoded.value().to_string());
}
dbgln( "- DNSKEY {}", dnskey.to_string());
dbgln("- RRSIG {}", rrsig.to_string());
}
ByteBuffer to_be_signed;
{
// 2 bytes: type_covered
// 1 byte : algorithm
// 1 byte : labels
// 4 bytes: original_ttl
// 4 bytes: signature_expiration
// 4 bytes: signature_inception
// 2 bytes: key_tag
// (wire-format encoded signer name)
to_be_signed = TRY_OR_REJECT_PROMISE(promise, ByteBuffer::create_uninitialized(2 + 1 + 1 + 4 + 4 + 4 + 2));
auto write_u16_be = [&](size_t offset, u16 value) {
to_be_signed.bytes()[offset + 0] = (value >> 8) & 0xff;
to_be_signed.bytes()[offset + 1] = (value >> 0) & 0xff;
};
auto write_u32_be = [&](size_t offset, u32 value) {
to_be_signed.bytes()[offset + 0] = (value >> 24) & 0xff;
to_be_signed.bytes()[offset + 1] = (value >> 16) & 0xff;
to_be_signed.bytes()[offset + 2] = (value >> 8) & 0xff;
to_be_signed.bytes()[offset + 3] = (value >> 0) & 0xff;
};
size_t offset = 0;
write_u16_be(offset, to_underlying(rrsig.type_covered));
offset += 2;
to_be_signed[offset++] = static_cast<u8>(rrsig.algorithm);
to_be_signed[offset++] = rrsig.label_count;
write_u32_be(offset, rrsig.original_ttl);
offset += 4;
write_u32_be(offset, rrsig.expiration.seconds_since_epoch());
offset += 4;
write_u32_be(offset, rrsig.inception.seconds_since_epoch());
offset += 4;
write_u16_be(offset, rrsig.key_tag);
}
TRY_OR_REJECT_PROMISE(promise, rrsig.signers_name.to_raw(to_be_signed));
TRY_OR_REJECT_PROMISE(promise, to_be_signed.try_append(canon_encoded.data(), canon_encoded.size()));
dbgln_if(DNS_DEBUG, "To be signed: {:hex-dump}", to_be_signed.bytes());
// constexpr auto rsa_prefix_for = [](Crypto::Hash::HashKind kind) -> ReadonlyBytes {
// if (kind == Crypto::Hash::HashKind::SHA256)
// return "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"sv.bytes();
// if (kind == Crypto::Hash::HashKind::SHA512)
// return "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"sv.bytes();
// VERIFY_NOT_REACHED();
// };
switch (dnskey.algorithm) {
case Messages::DNSSEC::Algorithm::RSAMD5: {
auto md5 = Crypto::Hash::MD5::create();
md5->update(to_be_signed.data(), to_be_signed.size());
auto digest = md5->digest();
auto public_key = TRY_OR_REJECT_PROMISE(promise, Crypto::PK::RSA::parse_rsa_key(dnskey.public_key, false, {}));
auto const& signature_data = rrsig.signature; // ByteBuffer with raw RSA/MD5 signature
if (signature_data.is_empty()) {
promise->reject(Error::from_string_literal("RRSIG has an empty signature"));
return promise;
}
Crypto::PK::RSA_PKCS1_EME rsa { public_key };
if (auto const ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(digest.bytes(), signature_data)); !ok) {
promise->reject(Error::from_string_literal("RSA/MD5 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::ECDSAP256SHA256: {
auto sha256 = Crypto::Hash::SHA256::hash(to_be_signed);
auto keys = TRY_OR_REJECT_PROMISE(promise, Crypto::PK::EC::parse_ec_key(dnskey.public_key, false, {}));
auto signature = TRY_OR_REJECT_PROMISE(promise, Crypto::Curves::SECPxxxr1Signature::from_raw(Crypto::ASN1::secp256r1_oid, rrsig.signature));
Crypto::Curves::SECP256r1 curve;
if (auto ok = TRY_OR_REJECT_PROMISE(promise, curve.verify(sha256.bytes(), keys.public_key.to_secpxxxr1_point(), signature)); !ok) {
promise->reject(Error::from_string_literal("ECDSA/SHA256 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::ECDSAP384SHA384: {
auto sha384 = Crypto::Hash::SHA384::hash(to_be_signed);
auto keys = TRY_OR_REJECT_PROMISE(promise, Crypto::PK::EC::parse_ec_key(dnskey.public_key, false, {}));
auto signature = TRY_OR_REJECT_PROMISE(promise, Crypto::Curves::SECPxxxr1Signature::from_raw(Crypto::ASN1::secp384r1_oid, rrsig.signature));
Crypto::Curves::SECP384r1 curve;
if (auto ok = TRY_OR_REJECT_PROMISE(promise, curve.verify(sha384.bytes(), keys.public_key.to_secpxxxr1_point(), signature)); !ok) {
promise->reject(Error::from_string_literal("ECDSA/SHA384 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::RSASHA512: {
auto n = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_modulus());
auto e = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_exponent());
Crypto::PK::RSA_PKCS1_EMSA rsa { Crypto::Hash::HashKind::SHA512, Crypto::PK::RSAPublicKey { move(n), move(e) } };
if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(to_be_signed, rrsig.signature)); !ok) {
promise->reject(Error::from_string_literal("RSA/SHA512 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::RSASHA1: {
auto n = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_modulus());
auto e = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_exponent());
Crypto::PK::RSA_PKCS1_EMSA rsa { Crypto::Hash::HashKind::SHA1, Crypto::PK::RSAPublicKey { move(n), move(e) } };
if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(to_be_signed, rrsig.signature)); !ok) {
promise->reject(Error::from_string_literal("RSA/SHA1 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::RSASHA256: {
auto n = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_modulus());
auto e = Crypto::UnsignedBigInteger::import_data(dnskey.public_key_rsa_exponent());
Crypto::PK::RSA_PKCS1_EMSA rsa { Crypto::Hash::HashKind::SHA256, Crypto::PK::RSAPublicKey { move(n), move(e) } };
if (auto ok = TRY_OR_REJECT_PROMISE(promise, rsa.verify(to_be_signed, rrsig.signature)); !ok) {
promise->reject(Error::from_string_literal("RSA/SHA256 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::ED25519: {
Crypto::Curves::Ed25519 ed25519;
if (!TRY_OR_REJECT_PROMISE(promise, ed25519.verify(dnskey.public_key.bytes(), rrsig.signature.bytes(), to_be_signed.bytes()))) {
promise->reject(Error::from_string_literal("ED25519 signature validation failed"));
return promise;
}
break;
}
case Messages::DNSSEC::Algorithm::DSA:
case Messages::DNSSEC::Algorithm::RSASHA1NSEC3SHA1:
// Not implemented yet.
case Messages::DNSSEC::Algorithm::Unknown:
dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}", to_string(dnskey.algorithm));
promise->reject(Error::from_string_literal("Unsupported algorithm for DNSSEC validation"));
break;
}
// If we haven't rejected by now, we consider the RRSet valid.
if (!promise->is_rejected()) {
// Typically you'd store these validated RRs in the lookup result.
for (auto& record : rrset_with_rrsig.rrset)
result->add_record(move(record));
// Resolve with an empty success.
promise->resolve({});
}
return promise;
}
bool has_connection(bool attempt_restart = true)
@ -542,5 +1132,6 @@ private:
ConnectionMode m_mode { ConnectionMode::UDP };
Vector<NonnullRefPtr<Core::Promise<Empty>>> m_socket_ready_promises;
};
}
#undef TRY_OR_REJECT_PROMISE

View file

@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <LibCore/ElapsedTimer.h>
#include <LibCore/Promise.h>
#include <LibCrypto/OpenSSL.h>
#include <LibTLS/TLSv12.h>
@ -134,7 +135,22 @@ ErrorOr<bool> TLSv12::can_read_without_blocking(int timeout) const
if (!m_ssl)
return Error::from_string_literal("SSL connection is closed");
return m_socket->can_read_without_blocking(timeout);
if (SSL_has_pending(m_ssl))
return true;
while (timeout > 0) {
auto timer = Core::ElapsedTimer();
if (!TRY(m_socket->can_read_without_blocking(timeout)))
return SSL_has_pending(m_ssl);
if (SSL_has_pending(m_ssl))
return true;
auto elapsed = timer.elapsed_milliseconds();
if (elapsed >= timeout)
break;
timeout -= elapsed;
}
return false;
}
ErrorOr<void> TLSv12::set_blocking(bool)

View file

@ -55,7 +55,7 @@ public:
virtual void close() override;
virtual ErrorOr<size_t> pending_bytes() const override;
virtual ErrorOr<bool> can_read_without_blocking(int = 0) const override;
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override;
virtual ErrorOr<void> set_blocking(bool block) override;
virtual ErrorOr<void> set_close_on_exec(bool enabled) override;

View file

@ -380,7 +380,7 @@ void ConnectionFromClient::start_request(i32 request_id, ByteString method, URL:
if (host.starts_with("["sv) && host.ends_with("]"sv))
host = host.substring(1, host.length() - 2);
m_resolver->dns.lookup(host, DNS::Messages::Class::IN, { DNS::Messages::ResourceType::A, DNS::Messages::ResourceType::AAAA })
m_resolver->dns.lookup(host, DNS::Messages::Class::IN, { DNS::Messages::ResourceType::A, DNS::Messages::ResourceType::AAAA }, {.validate_dnssec_locally = true})
->when_rejected([this, request_id](auto const& error) {
dbgln("StartRequest: DNS lookup failed: {}", error);
// FIXME: Implement timing info for DNS lookup failure.
@ -702,7 +702,7 @@ void ConnectionFromClient::ensure_connection(URL::URL url, ::RequestServer::Cach
}
if (cache_level == CacheLevel::ResolveOnly) {
[[maybe_unused]] auto promise = m_resolver->dns.lookup(url.serialized_host().to_byte_string(), DNS::Messages::Class::IN, { DNS::Messages::ResourceType::A, DNS::Messages::ResourceType::AAAA });
[[maybe_unused]] auto promise = m_resolver->dns.lookup(url.serialized_host().to_byte_string(), DNS::Messages::Class::IN, { DNS::Messages::ResourceType::A, DNS::Messages::ResourceType::AAAA }, {.validate_dnssec_locally = true});
if constexpr (REQUESTSERVER_DEBUG) {
Core::ElapsedTimer timer;
timer.start();

View file

@ -22,11 +22,13 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
StringView server_address;
StringView cert_path;
bool use_tls = false;
bool dnssec = false;
Core::ArgsParser args_parser;
args_parser.add_option(cert_path, "Path to the CA certificate file", "ca-certs", 'C', "file");
args_parser.add_option(server_address, "The address of the DNS server to query", "server", 's', "addr");
args_parser.add_option(use_tls, "Use TLS to connect to the server", "tls", 0);
args_parser.add_option(dnssec, "Validate DNSSEC records locally", "dnssec", 0);
args_parser.add_positional_argument(Core::ArgsParser::Arg {
.help_string = "The resource types and name of the DNS record to query",
.name = "rr,rr@name",
@ -105,7 +107,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
size_t pending_requests = requests.size();
for (auto& request : requests) {
resolver.lookup(request.name, DNS::Messages::Class::IN, request.types)
resolver.lookup(request.name, DNS::Messages::Class::IN, request.types, { .validate_dnssec_locally = dnssec })
->when_resolved([&](auto& result) {
outln("Resolved {}:", request.name);
HashTable<DNS::Messages::ResourceType> types;