This commit is contained in:
Ali Mohammad Pur 2024-12-17 14:37:30 +01:00
commit 8c3b3a4062
2 changed files with 121 additions and 19 deletions

View file

@ -24,6 +24,42 @@ public:
Function<ErrorOr<void>(Result&)> on_resolution; Function<ErrorOr<void>(Result&)> on_resolution;
Function<void(ErrorType&)> on_rejection; Function<void(ErrorType&)> on_rejection;
static NonnullRefPtr<Promise> after(Vector<NonnullRefPtr<Promise>>&& promises)
{
auto promise = Promise::construct();
struct Resolved : RefCounted<Resolved> {
explicit Resolved(size_t n)
: needed(n)
{
}
size_t count { 0 };
size_t needed { 0 };
};
auto resolved = make_ref_counted<Resolved>(promises.size());
auto weak_promise = promise->template make_weak_ptr<Promise>();
for (auto p : promises) {
p->when_resolved([weak_promise, resolved](Result&) -> ErrorOr<void> {
if (weak_promise->is_rejected())
return {};
if (++resolved->count == resolved->needed)
weak_promise->resolve({});
return {};
});
p->when_rejected([weak_promise, resolved](ErrorType& error) {
++resolved->count;
weak_promise->reject(move(error));
});
promise->add_child(*p);
}
return promise;
}
void resolve(Result&& result) void resolve(Result&& result)
{ {
m_result_or_rejection = move(result); m_result_or_rejection = move(result);
@ -82,7 +118,7 @@ public:
template<CallableAs<void, Result&> F> template<CallableAs<void, Result&> F>
Promise& when_resolved(F handler) Promise& when_resolved(F handler)
{ {
return when_resolved([handler = move(handler)](Result& result) -> ErrorOr<void> { return when_resolved([handler = move(handler)](Result& result) mutable -> ErrorOr<void> {
handler(result); handler(result);
return {}; return {};
}); });

View file

@ -17,10 +17,15 @@
#include <LibCore/Promise.h> #include <LibCore/Promise.h>
#include <LibCore/Socket.h> #include <LibCore/Socket.h>
#include <LibCore/Timer.h> #include <LibCore/Timer.h>
#include <LibCrypto/Certificate/Certificate.h>
#include <LibCrypto/PK/RSA.h>
#include <LibDNS/Message.h> #include <LibDNS/Message.h>
#include <LibThreading/MutexProtected.h> #include <LibThreading/MutexProtected.h>
#include <LibThreading/RWLockProtected.h> #include <LibThreading/RWLockProtected.h>
#undef DNS_DEBUG
#define DNS_DEBUG 1
namespace DNS { namespace DNS {
class Resolver; class Resolver;
@ -528,7 +533,7 @@ private:
auto result = lookup->result.strong_ref(); auto result = lookup->result.strong_ref();
if (result->is_dnssec_validated()) if (result->is_dnssec_validated())
return validate_dnssec(move(message), *lookup); return validate_dnssec(move(message), *lookup, *result);
for (auto& record : message.answers) for (auto& record : message.answers)
result->add_record(move(record)); result->add_record(move(record));
@ -543,25 +548,25 @@ private:
} }
} }
ErrorOr<void> validate_dnssec(Messages::Message message, PendingLookup& lookup) ErrorOr<void> validate_dnssec(Messages::Message message, PendingLookup& lookup, NonnullRefPtr<LookupResult> result)
{ {
struct RecordAndRRSIG { struct RecordAndRRSIG {
Messages::ResourceRecord* record { nullptr }; Optional<Messages::ResourceRecord> record;
Messages::Records::RRSIG* rrsig { nullptr }; Messages::Records::RRSIG rrsig;
}; };
HashMap<Messages::ResourceType, RecordAndRRSIG> records_with_rrsigs; HashMap<Messages::ResourceType, RecordAndRRSIG> records_with_rrsigs;
for (auto& record : message.answers) { for (auto& record : message.answers) {
if (record.type == Messages::ResourceType::RRSIG) { if (record.type == Messages::ResourceType::RRSIG) {
auto& rrsig = record.record.get<Messages::Records::RRSIG>(); auto& rrsig = record.record.get<Messages::Records::RRSIG>();
if (auto found = records_with_rrsigs.get(rrsig.type_covered); found.has_value()) if (auto found = records_with_rrsigs.get(rrsig.type_covered); found.has_value())
found->rrsig = &rrsig; found->rrsig = move(rrsig);
else else
records_with_rrsigs.set(rrsig.type_covered, { nullptr, &rrsig }); records_with_rrsigs.set(rrsig.type_covered, { {}, move(rrsig) });
} else { } else {
if (auto found = records_with_rrsigs.get(record.type); found.has_value()) if (auto found = records_with_rrsigs.get(record.type); found.has_value())
found->record = &record; found->record = move(record);
else else
records_with_rrsigs.set(record.type, { &record, nullptr }); records_with_rrsigs.set(record.type, { move(record), {} });
} }
} }
@ -570,35 +575,58 @@ private:
return {}; return {};
} }
auto name = lookup.result->name(); auto name = result->name();
Core::deferred_invoke([this, lookup, name, records_with_rrsigs = move(records_with_rrsigs)] { Core::deferred_invoke([this, lookup, name, records_with_rrsigs = move(records_with_rrsigs), result = move(result)] mutable {
dbgln("DNS: Resolving DNSKEY for {}", name.to_string()); dbgln("DNS: Resolving DNSKEY for {}", name.to_string());
this->lookup(lookup.name, Messages::Class::IN, Array { Messages::ResourceType::DNSKEY }.span(), { .validate_dnssec_locally = false }) this->lookup(lookup.name, Messages::Class::IN, Array { Messages::ResourceType::DNSKEY }.span(), { .validate_dnssec_locally = false })
->when_resolved([=](NonnullRefPtr<LookupResult const>& dnskey_lookup_result) { ->when_resolved([=, this, records_with_rrsigs = move(records_with_rrsigs)](NonnullRefPtr<LookupResult const>& dnskey_lookup_result) mutable {
dbgln("DNSKEY for {}:", name.to_string()); dbgln("DNSKEY for {}:", name.to_string());
for (auto& record : dnskey_lookup_result->records()) for (auto& record : dnskey_lookup_result->records())
dbgln("DNSKEY: {}", record.to_string()); dbgln("DNSKEY: {}", record.to_string());
dbgln("DNS: Validating {} RRSIGs for {}", records_with_rrsigs.size(), name.to_string());
Vector<NonnullRefPtr<Core::Promise<Empty>>> promises;
for (auto& [type, pair] : records_with_rrsigs) { for (auto& [type, pair] : records_with_rrsigs) {
if (!pair.record) if (!pair.record.has_value())
continue; continue;
auto& record = *pair.record; auto& record = *pair.record;
auto& rrsig = *pair.rrsig; auto& rrsig = pair.rrsig;
dbgln("Validating RRSIG for {} with DNSKEY", record.to_string()); dbgln("Validating RRSIG for {} with DNSKEY", record.to_string());
auto dnskey = [&] { auto dnskey = [&] -> Optional<Messages::Records::DNSKEY> {
for (auto& dnskey_record : dnskey_lookup_result->records()) { for (auto& dnskey_record : dnskey_lookup_result->records()) {
if (dnskey_record.type == Messages::ResourceType::DNSKEY) if (auto r = dnskey_record.record.get_pointer<Messages::Records::DNSKEY>())
return dnskey_record.record.get<Messages::Records::DNSKEY>(); return *r;
} }
return Messages::Records::DNSKEY {}; return {};
}(); }();
dbgln("DNSKEY: {}", dnskey.to_string()); if (!dnskey.has_value()) {
dbgln("RRSIG: {}", rrsig.to_string()); dbgln("DNS: No DNSKEY found for RRSIG validation of {}", record.to_string());
continue;
} }
dbgln("DNSKEY: {}", dnskey->to_string());
dbgln("RRSIG: {}", rrsig.to_string());
promises.append(validate_record_with_rrsig(record, rrsig, *dnskey, result));
}
auto promise = Core::Promise<Empty>::after(move(promises))
->when_resolved([result, lookup](Empty) {
result->set_dnssec_validated(true);
result->finished_request();
lookup.promise->resolve(result);
})
.when_rejected([result, lookup](Error& error) {
result->finished_request();
lookup.promise->reject(move(error));
}).map<NonnullRefPtr<LookupResult const>>([result](Empty&) { return result; });
lookup.promise = move(promise);
}) })
.when_rejected([=](auto& error) { .when_rejected([=](auto& error) {
dbgln("Failed to resolve DNSKEY for {}: {}", name.to_string(), error); dbgln("Failed to resolve DNSKEY for {}: {}", name.to_string(), error);
@ -609,6 +637,44 @@ private:
return {}; return {};
} }
NonnullRefPtr<Core::Promise<Empty>> validate_record_with_rrsig(Messages::ResourceRecord const& record, Messages::Records::RRSIG const& rrsig, Messages::Records::DNSKEY const& dnskey, NonnullRefPtr<LookupResult> result)
{
dbgln("Validating RRSIG {} for RR: {} with DNSKEY: {}", rrsig.to_string(), record.to_string(), dnskey.to_string());
auto promise = Core::Promise<Empty>::construct();
switch (dnskey.algorithm) {
case Messages::DNSSEC::Algorithm::RSAMD5: {
Crypto::PK::RSA rsa { dnskey.public_key };
break;
}
case Messages::DNSSEC::Algorithm::DSA:
break;
case Messages::DNSSEC::Algorithm::RSASHA1:
break;
case Messages::DNSSEC::Algorithm::RSASHA256:
break;
case Messages::DNSSEC::Algorithm::RSASHA512:
break;
case Messages::DNSSEC::Algorithm::ECDSAP256SHA256: {
auto key = MUST(Crypto::PK::EC::parse_ec_key(dnskey.public_key, false, {}));
Crypto::PK::EC ec { key };
break;
}
case Messages::DNSSEC::Algorithm::ECDSAP384SHA384:
break;
case Messages::DNSSEC::Algorithm::RSASHA1NSEC3SHA1:
case Messages::DNSSEC::Algorithm::ED25519:
case Messages::DNSSEC::Algorithm::Unknown:
dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}", to_string(dnskey.algorithm));
promise->reject(Error::from_string_literal("Unsupported algorithm for DNSSEC validation"));
return promise;
}
result->add_record(record);
promise->resolve({});
return promise;
}
bool has_connection(bool attempt_restart = true) bool has_connection(bool attempt_restart = true)
{ {
auto result = m_socket.with_read_locked( auto result = m_socket.with_read_locked(