From 6911c45bab4ff512c63c5b8156b25c7d22295cdb Mon Sep 17 00:00:00 2001 From: Ali Mohammad Pur Date: Tue, 5 Nov 2024 18:05:24 +0100 Subject: [PATCH] LibDNS: Respect records' TTL in the resolver cache --- Libraries/LibDNS/Resolver.h | 66 +++++++++++++++++++++++++++++++++---- Utilities/dns.cpp | 2 +- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/Libraries/LibDNS/Resolver.h b/Libraries/LibDNS/Resolver.h index 422f336767e..6d2fb52e2b3 100644 --- a/Libraries/LibDNS/Resolver.h +++ b/Libraries/LibDNS/Resolver.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -32,8 +33,8 @@ public: Vector> cached_addresses() const { Vector> result; - for (auto& record : m_cached_records) { - record.record.visit( + for (auto& re : m_cached_records) { + re.record.record.visit( [&](Messages::Records::A const& a) { result.append(a.address); }, [&](Messages::Records::AAAA const& aaaa) { result.append(aaaa.address); }, [](auto&) {}); @@ -41,21 +42,49 @@ public: return result; } + void check_expiration() + { + if (!m_valid) + return; + + auto now = Core::DateTime::now(); + for (size_t i = 0; i < m_cached_records.size();) { + auto& record = m_cached_records[i]; + if (record.expiration.has_value() && record.expiration.value() < now) { + dbgln("DNS: Removing expired record for {}", m_name.to_string()); + m_cached_records.remove(i); + } else { + dbgln("DNS: Keeping record for {} (expires in {})", m_name.to_string(), record.expiration.has_value() ? record.expiration.value().to_string() : "never"_string); + ++i; + } + } + + if (m_cached_records.is_empty()) + m_valid = false; + } + void add_record(Messages::ResourceRecord record) { m_valid = true; - m_cached_records.append(move(record)); + auto expiration = record.ttl > 0 ? Optional(Core::DateTime::from_timestamp(Core::DateTime::now().timestamp() + record.ttl)) : OptionalNone(); + m_cached_records.append({ move(record), move(expiration) }); } - Vector const& records() const { return m_cached_records; } + Vector records() const + { + Vector result; + for (auto& re : m_cached_records) + result.append(re.record); + return result; + } bool has_record_of_type(Messages::ResourceType type, bool later = false) const { if (later && m_desired_types.contains(type)) return true; - for (auto const& record : m_cached_records) { - if (record.type == type) + for (auto const& re : m_cached_records) { + if (re.record.type == type) return true; } return false; @@ -72,7 +101,11 @@ public: private: bool m_valid { false }; Messages::DomainName m_name; - Vector m_cached_records; + struct RecordWithExpiration { + Messages::ResourceRecord record; + Optional expiration; + }; + Vector m_cached_records; HashTable m_desired_types; u16 m_id { 0 }; }; @@ -157,6 +190,8 @@ public: NonnullRefPtr>> lookup(ByteString name, Messages::Class class_, Span desired_types) { + flush_cache(); + auto promise = Core::Promise>::construct(); if (auto result = lookup_in_cache(name, class_, desired_types)) { @@ -339,6 +374,9 @@ private: if (!lookup) return Error::from_string_literal("No pending lookup found for this message"); + if (lookup->result.is_null()) + return {}; // Message is a response to a lookup that's been purged from the cache, ignore it + auto result = lookup->result.strong_ref(); for (auto& record : message.answers) result->add_record(move(record)); @@ -392,6 +430,20 @@ private: m_socket_ready_promises.clear(); } + void flush_cache() + { + m_cache.with_write_locked([&](auto& cache) { + HashTable to_remove; + for (auto& entry : cache) { + entry.value->check_expiration(); + if (!entry.value->is_valid()) + to_remove.set(entry.key); + } + for (auto const& key : to_remove) + cache.remove(key); + }); + } + Threading::RWLockProtected>> m_cache; Threading::RWLockProtected>> m_pending_lookups; Threading::RWLockProtected>> m_socket; diff --git a/Utilities/dns.cpp b/Utilities/dns.cpp index 419169052ab..c6b28d7b83c 100644 --- a/Utilities/dns.cpp +++ b/Utilities/dns.cpp @@ -104,7 +104,7 @@ ErrorOr serenity_main(Main::Arguments arguments) ->when_resolved([&](auto& result) { outln("Resolved {}:", request.name); HashTable types; - auto& recs = result->records(); + auto recs = result->records(); for (auto& record : recs) types.set(record.type);