LibDNS: Respect records' TTL in the resolver cache

This commit is contained in:
Ali Mohammad Pur 2024-11-05 18:05:24 +01:00 committed by Ali Mohammad Pur
parent 879ae94183
commit 6911c45bab
Notes: github-actions[bot] 2024-11-20 20:44:10 +00:00
2 changed files with 60 additions and 8 deletions

View file

@ -12,6 +12,7 @@
#include <AK/Random.h>
#include <AK/StringView.h>
#include <AK/TemporaryChange.h>
#include <LibCore/DateTime.h>
#include <LibCore/Promise.h>
#include <LibCore/SocketAddress.h>
#include <LibDNS/Message.h>
@ -32,8 +33,8 @@ public:
Vector<Variant<IPv4Address, IPv6Address>> cached_addresses() const
{
Vector<Variant<IPv4Address, IPv6Address>> result;
for (auto& record : m_cached_records) {
record.record.visit(
for (auto& re : m_cached_records) {
re.record.record.visit(
[&](Messages::Records::A const& a) { result.append(a.address); },
[&](Messages::Records::AAAA const& aaaa) { result.append(aaaa.address); },
[](auto&) {});
@ -41,21 +42,49 @@ public:
return result;
}
void check_expiration()
{
if (!m_valid)
return;
auto now = Core::DateTime::now();
for (size_t i = 0; i < m_cached_records.size();) {
auto& record = m_cached_records[i];
if (record.expiration.has_value() && record.expiration.value() < now) {
dbgln("DNS: Removing expired record for {}", m_name.to_string());
m_cached_records.remove(i);
} else {
dbgln("DNS: Keeping record for {} (expires in {})", m_name.to_string(), record.expiration.has_value() ? record.expiration.value().to_string() : "never"_string);
++i;
}
}
if (m_cached_records.is_empty())
m_valid = false;
}
void add_record(Messages::ResourceRecord record)
{
m_valid = true;
m_cached_records.append(move(record));
auto expiration = record.ttl > 0 ? Optional<Core::DateTime>(Core::DateTime::from_timestamp(Core::DateTime::now().timestamp() + record.ttl)) : OptionalNone();
m_cached_records.append({ move(record), move(expiration) });
}
Vector<Messages::ResourceRecord> const& records() const { return m_cached_records; }
Vector<Messages::ResourceRecord> records() const
{
Vector<Messages::ResourceRecord> result;
for (auto& re : m_cached_records)
result.append(re.record);
return result;
}
bool has_record_of_type(Messages::ResourceType type, bool later = false) const
{
if (later && m_desired_types.contains(type))
return true;
for (auto const& record : m_cached_records) {
if (record.type == type)
for (auto const& re : m_cached_records) {
if (re.record.type == type)
return true;
}
return false;
@ -72,7 +101,11 @@ public:
private:
bool m_valid { false };
Messages::DomainName m_name;
Vector<Messages::ResourceRecord> m_cached_records;
struct RecordWithExpiration {
Messages::ResourceRecord record;
Optional<Core::DateTime> expiration;
};
Vector<RecordWithExpiration> m_cached_records;
HashTable<Messages::ResourceType> m_desired_types;
u16 m_id { 0 };
};
@ -157,6 +190,8 @@ public:
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_, Span<Messages::ResourceType const> desired_types)
{
flush_cache();
auto promise = Core::Promise<NonnullRefPtr<LookupResult const>>::construct();
if (auto result = lookup_in_cache(name, class_, desired_types)) {
@ -339,6 +374,9 @@ private:
if (!lookup)
return Error::from_string_literal("No pending lookup found for this message");
if (lookup->result.is_null())
return {}; // Message is a response to a lookup that's been purged from the cache, ignore it
auto result = lookup->result.strong_ref();
for (auto& record : message.answers)
result->add_record(move(record));
@ -392,6 +430,20 @@ private:
m_socket_ready_promises.clear();
}
void flush_cache()
{
m_cache.with_write_locked([&](auto& cache) {
HashTable<ByteString> to_remove;
for (auto& entry : cache) {
entry.value->check_expiration();
if (!entry.value->is_valid())
to_remove.set(entry.key);
}
for (auto const& key : to_remove)
cache.remove(key);
});
}
Threading::RWLockProtected<HashMap<ByteString, NonnullRefPtr<LookupResult>>> m_cache;
Threading::RWLockProtected<NonnullOwnPtr<RedBlackTree<u16, PendingLookup>>> m_pending_lookups;
Threading::RWLockProtected<Optional<MaybeOwned<Core::Socket>>> m_socket;

View file

@ -104,7 +104,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
->when_resolved([&](auto& result) {
outln("Resolved {}:", request.name);
HashTable<DNS::Messages::ResourceType> types;
auto& recs = result->records();
auto recs = result->records();
for (auto& record : recs)
types.set(record.type);