This commit is contained in:
Ali Mohammad Pur 2024-12-17 14:37:30 +01:00
commit 8c3b3a4062
2 changed files with 121 additions and 19 deletions

View file

@ -24,6 +24,42 @@ public:
Function<ErrorOr<void>(Result&)> on_resolution;
Function<void(ErrorType&)> on_rejection;
static NonnullRefPtr<Promise> after(Vector<NonnullRefPtr<Promise>>&& promises)
{
auto promise = Promise::construct();
struct Resolved : RefCounted<Resolved> {
explicit Resolved(size_t n)
: needed(n)
{
}
size_t count { 0 };
size_t needed { 0 };
};
auto resolved = make_ref_counted<Resolved>(promises.size());
auto weak_promise = promise->template make_weak_ptr<Promise>();
for (auto p : promises) {
p->when_resolved([weak_promise, resolved](Result&) -> ErrorOr<void> {
if (weak_promise->is_rejected())
return {};
if (++resolved->count == resolved->needed)
weak_promise->resolve({});
return {};
});
p->when_rejected([weak_promise, resolved](ErrorType& error) {
++resolved->count;
weak_promise->reject(move(error));
});
promise->add_child(*p);
}
return promise;
}
void resolve(Result&& result)
{
m_result_or_rejection = move(result);
@ -82,7 +118,7 @@ public:
template<CallableAs<void, Result&> F>
Promise& when_resolved(F handler)
{
return when_resolved([handler = move(handler)](Result& result) -> ErrorOr<void> {
return when_resolved([handler = move(handler)](Result& result) mutable -> ErrorOr<void> {
handler(result);
return {};
});

View file

@ -17,10 +17,15 @@
#include <LibCore/Promise.h>
#include <LibCore/Socket.h>
#include <LibCore/Timer.h>
#include <LibCrypto/Certificate/Certificate.h>
#include <LibCrypto/PK/RSA.h>
#include <LibDNS/Message.h>
#include <LibThreading/MutexProtected.h>
#include <LibThreading/RWLockProtected.h>
#undef DNS_DEBUG
#define DNS_DEBUG 1
namespace DNS {
class Resolver;
@ -528,7 +533,7 @@ private:
auto result = lookup->result.strong_ref();
if (result->is_dnssec_validated())
return validate_dnssec(move(message), *lookup);
return validate_dnssec(move(message), *lookup, *result);
for (auto& record : message.answers)
result->add_record(move(record));
@ -543,25 +548,25 @@ private:
}
}
ErrorOr<void> validate_dnssec(Messages::Message message, PendingLookup& lookup)
ErrorOr<void> validate_dnssec(Messages::Message message, PendingLookup& lookup, NonnullRefPtr<LookupResult> result)
{
struct RecordAndRRSIG {
Messages::ResourceRecord* record { nullptr };
Messages::Records::RRSIG* rrsig { nullptr };
Optional<Messages::ResourceRecord> record;
Messages::Records::RRSIG rrsig;
};
HashMap<Messages::ResourceType, RecordAndRRSIG> records_with_rrsigs;
for (auto& record : message.answers) {
if (record.type == Messages::ResourceType::RRSIG) {
auto& rrsig = record.record.get<Messages::Records::RRSIG>();
if (auto found = records_with_rrsigs.get(rrsig.type_covered); found.has_value())
found->rrsig = &rrsig;
found->rrsig = move(rrsig);
else
records_with_rrsigs.set(rrsig.type_covered, { nullptr, &rrsig });
records_with_rrsigs.set(rrsig.type_covered, { {}, move(rrsig) });
} else {
if (auto found = records_with_rrsigs.get(record.type); found.has_value())
found->record = &record;
found->record = move(record);
else
records_with_rrsigs.set(record.type, { &record, nullptr });
records_with_rrsigs.set(record.type, { move(record), {} });
}
}
@ -570,35 +575,58 @@ private:
return {};
}
auto name = lookup.result->name();
auto name = result->name();
Core::deferred_invoke([this, lookup, name, records_with_rrsigs = move(records_with_rrsigs)] {
Core::deferred_invoke([this, lookup, name, records_with_rrsigs = move(records_with_rrsigs), result = move(result)] mutable {
dbgln("DNS: Resolving DNSKEY for {}", name.to_string());
this->lookup(lookup.name, Messages::Class::IN, Array { Messages::ResourceType::DNSKEY }.span(), { .validate_dnssec_locally = false })
->when_resolved([=](NonnullRefPtr<LookupResult const>& dnskey_lookup_result) {
->when_resolved([=, this, records_with_rrsigs = move(records_with_rrsigs)](NonnullRefPtr<LookupResult const>& dnskey_lookup_result) mutable {
dbgln("DNSKEY for {}:", name.to_string());
for (auto& record : dnskey_lookup_result->records())
dbgln("DNSKEY: {}", record.to_string());
dbgln("DNS: Validating {} RRSIGs for {}", records_with_rrsigs.size(), name.to_string());
Vector<NonnullRefPtr<Core::Promise<Empty>>> promises;
for (auto& [type, pair] : records_with_rrsigs) {
if (!pair.record)
if (!pair.record.has_value())
continue;
auto& record = *pair.record;
auto& rrsig = *pair.rrsig;
auto& rrsig = pair.rrsig;
dbgln("Validating RRSIG for {} with DNSKEY", record.to_string());
auto dnskey = [&] {
auto dnskey = [&] -> Optional<Messages::Records::DNSKEY> {
for (auto& dnskey_record : dnskey_lookup_result->records()) {
if (dnskey_record.type == Messages::ResourceType::DNSKEY)
return dnskey_record.record.get<Messages::Records::DNSKEY>();
if (auto r = dnskey_record.record.get_pointer<Messages::Records::DNSKEY>())
return *r;
}
return Messages::Records::DNSKEY {};
return {};
}();
dbgln("DNSKEY: {}", dnskey.to_string());
dbgln("RRSIG: {}", rrsig.to_string());
if (!dnskey.has_value()) {
dbgln("DNS: No DNSKEY found for RRSIG validation of {}", record.to_string());
continue;
}
dbgln("DNSKEY: {}", dnskey->to_string());
dbgln("RRSIG: {}", rrsig.to_string());
promises.append(validate_record_with_rrsig(record, rrsig, *dnskey, result));
}
auto promise = Core::Promise<Empty>::after(move(promises))
->when_resolved([result, lookup](Empty) {
result->set_dnssec_validated(true);
result->finished_request();
lookup.promise->resolve(result);
})
.when_rejected([result, lookup](Error& error) {
result->finished_request();
lookup.promise->reject(move(error));
}).map<NonnullRefPtr<LookupResult const>>([result](Empty&) { return result; });
lookup.promise = move(promise);
})
.when_rejected([=](auto& error) {
dbgln("Failed to resolve DNSKEY for {}: {}", name.to_string(), error);
@ -609,6 +637,44 @@ private:
return {};
}
NonnullRefPtr<Core::Promise<Empty>> validate_record_with_rrsig(Messages::ResourceRecord const& record, Messages::Records::RRSIG const& rrsig, Messages::Records::DNSKEY const& dnskey, NonnullRefPtr<LookupResult> result)
{
dbgln("Validating RRSIG {} for RR: {} with DNSKEY: {}", rrsig.to_string(), record.to_string(), dnskey.to_string());
auto promise = Core::Promise<Empty>::construct();
switch (dnskey.algorithm) {
case Messages::DNSSEC::Algorithm::RSAMD5: {
Crypto::PK::RSA rsa { dnskey.public_key };
break;
}
case Messages::DNSSEC::Algorithm::DSA:
break;
case Messages::DNSSEC::Algorithm::RSASHA1:
break;
case Messages::DNSSEC::Algorithm::RSASHA256:
break;
case Messages::DNSSEC::Algorithm::RSASHA512:
break;
case Messages::DNSSEC::Algorithm::ECDSAP256SHA256: {
auto key = MUST(Crypto::PK::EC::parse_ec_key(dnskey.public_key, false, {}));
Crypto::PK::EC ec { key };
break;
}
case Messages::DNSSEC::Algorithm::ECDSAP384SHA384:
break;
case Messages::DNSSEC::Algorithm::RSASHA1NSEC3SHA1:
case Messages::DNSSEC::Algorithm::ED25519:
case Messages::DNSSEC::Algorithm::Unknown:
dbgln("DNS: Unsupported algorithm for DNSSEC validation: {}", to_string(dnskey.algorithm));
promise->reject(Error::from_string_literal("Unsupported algorithm for DNSSEC validation"));
return promise;
}
result->add_record(record);
promise->resolve({});
return promise;
}
bool has_connection(bool attempt_restart = true)
{
auto result = m_socket.with_read_locked(