mirror of
https://github.com/LadybirdBrowser/ladybird.git
synced 2024-11-22 07:30:19 +00:00
c5afe58540
In the future, we may want to parse /etc/hosts (or equivalent) into the cache; this commit only adds localhost to make the normal workflow work.
470 lines
17 KiB
C++
470 lines
17 KiB
C++
/*
|
|
* Copyright (c) 2024, Ali Mohammad Pur <mpfard@serenityos.org>
|
|
*
|
|
* SPDX-License-Identifier: BSD-2-Clause
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <AK/AtomicRefCounted.h>
|
|
#include <AK/HashTable.h>
|
|
#include <AK/MemoryStream.h>
|
|
#include <AK/Random.h>
|
|
#include <AK/StringView.h>
|
|
#include <AK/TemporaryChange.h>
|
|
#include <LibCore/DateTime.h>
|
|
#include <LibCore/Promise.h>
|
|
#include <LibCore/SocketAddress.h>
|
|
#include <LibDNS/Message.h>
|
|
#include <LibThreading/MutexProtected.h>
|
|
#include <LibThreading/RWLockProtected.h>
|
|
|
|
namespace DNS {
|
|
class Resolver;
|
|
|
|
// Cached DNS data for a single domain name: the resource records received so far, each with an
// optional expiration time, plus the record types a lookup has announced it expects.
class LookupResult : public AtomicRefCounted<LookupResult>
    , public Weakable<LookupResult> {
public:
    explicit LookupResult(Messages::DomainName name)
        : m_name(move(name))
    {
    }

    // Returns the addresses carried by the cached A and AAAA records, in insertion order.
    // Records of any other type are skipped.
    Vector<Variant<IPv4Address, IPv6Address>> cached_addresses() const
    {
        Vector<Variant<IPv4Address, IPv6Address>> result;
        for (auto& re : m_cached_records) {
            re.record.record.visit(
                [&](Messages::Records::A const& a) { result.append(a.address); },
                [&](Messages::Records::AAAA const& aaaa) { result.append(aaaa.address); },
                [](auto&) {});
        }
        return result;
    }

    // Removes every record whose expiration time lies in the past. Once no records remain, the
    // result is marked invalid. Does nothing for a result that is already invalid.
    void check_expiration()
    {
        if (!m_valid)
            return;

        auto now = Core::DateTime::now();
        for (size_t i = 0; i < m_cached_records.size();) {
            auto& record = m_cached_records[i];
            if (record.expiration.has_value() && record.expiration.value() < now) {
                dbgln_if(DNS_DEBUG, "DNS: Removing expired record for {}", m_name.to_string());
                m_cached_records.remove(i);
            } else {
                dbgln_if(DNS_DEBUG, "DNS: Keeping record for {} (expires in {})", m_name.to_string(), record.expiration.has_value() ? record.expiration.value().to_string() : "never"_string);
                ++i;
            }
        }

        if (m_cached_records.is_empty())
            m_valid = false;
    }

    // Stores `record` and marks the result valid. The record expires `ttl` seconds from now;
    // a TTL of 0 means it is kept until explicitly removed (it never expires).
    void add_record(Messages::ResourceRecord record)
    {
        m_valid = true;
        auto expiration = record.ttl > 0 ? Optional<Core::DateTime>(Core::DateTime::from_timestamp(Core::DateTime::now().timestamp() + record.ttl)) : OptionalNone();
        m_cached_records.append({ move(record), move(expiration) });
    }

    // Returns copies of all cached records, without their expiration data.
    Vector<Messages::ResourceRecord> records() const
    {
        Vector<Messages::ResourceRecord> result;
        for (auto& re : m_cached_records)
            result.append(re.record);
        return result;
    }

    // True if a record of `type` is cached. When `later` is set, also true if `type` was
    // announced through will_add_record_of_type(), even if no such record has arrived yet.
    bool has_record_of_type(Messages::ResourceType type, bool later = false) const
    {
        if (later && m_desired_types.contains(type))
            return true;

        for (auto const& re : m_cached_records) {
            if (re.record.type == type)
                return true;
        }
        return false;
    }

    // Announces that a record of `type` is expected to be added to this result.
    void will_add_record_of_type(Messages::ResourceType type) { m_desired_types.set(type); }

    // ID of the DNS query associated with this result (0 until set_id() is called).
    void set_id(u16 id) { m_id = id; }
    u16 id() const { return m_id; }

    // True once a record has been added; cleared by check_expiration() when all records expired.
    bool is_valid() const { return m_valid; }
    Messages::DomainName const& name() const { return m_name; }

private:
    bool m_valid { false };
    Messages::DomainName m_name;
    struct RecordWithExpiration {
        Messages::ResourceRecord record;
        Optional<Core::DateTime> expiration; // Empty means the record never expires.
    };
    Vector<RecordWithExpiration> m_cached_records;
    HashTable<Messages::ResourceType> m_desired_types;
    u16 m_id { 0 };
};
|
|
|
|
class Resolver {
|
|
public:
|
|
enum class ConnectionMode {
|
|
TCP,
|
|
UDP,
|
|
};
|
|
|
|
struct SocketResult {
|
|
MaybeOwned<Core::Socket> socket;
|
|
ConnectionMode mode;
|
|
};
|
|
|
|
Resolver(Function<ErrorOr<SocketResult>()> create_socket)
|
|
: m_pending_lookups(make<RedBlackTree<u16, PendingLookup>>())
|
|
, m_create_socket(move(create_socket))
|
|
{
|
|
m_cache.with_write_locked([&](auto& cache) {
|
|
auto add_v4v6_entry = [&cache](StringView name_string, IPv4Address v4, IPv6Address v6) {
|
|
auto name = Messages::DomainName::from_string(name_string);
|
|
auto ptr = make_ref_counted<LookupResult>(name);
|
|
ptr->will_add_record_of_type(Messages::ResourceType::A);
|
|
ptr->will_add_record_of_type(Messages::ResourceType::AAAA);
|
|
cache.set(name_string, ptr);
|
|
|
|
ptr->add_record({ .name = {}, .type = Messages::ResourceType::A, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::A { v4 }, .raw = {} });
|
|
ptr->add_record({ .name = {}, .type = Messages::ResourceType::AAAA, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::AAAA { v6 }, .raw = {} });
|
|
};
|
|
|
|
add_v4v6_entry("localhost"sv, { 127, 0, 0, 1 }, IPv6Address::loopback());
|
|
});
|
|
}
|
|
|
|
NonnullRefPtr<Core::Promise<Empty>> when_socket_ready()
|
|
{
|
|
auto promise = Core::Promise<Empty>::construct();
|
|
m_socket_ready_promises.append(promise);
|
|
if (has_connection(false)) {
|
|
promise->resolve({});
|
|
return promise;
|
|
}
|
|
|
|
if (!has_connection())
|
|
promise->reject(Error::from_string_literal("Failed to create socket"));
|
|
|
|
return promise;
|
|
}
|
|
|
|
void reset_connection()
|
|
{
|
|
m_socket.with_write_locked([&](auto& socket) { socket = {}; });
|
|
}
|
|
|
|
NonnullRefPtr<LookupResult const> expect_cached(StringView name, Messages::Class class_ = Messages::Class::IN)
|
|
{
|
|
return expect_cached(name, class_, Array { Messages::ResourceType::A, Messages::ResourceType::AAAA });
|
|
}
|
|
|
|
NonnullRefPtr<LookupResult const> expect_cached(StringView name, Messages::Class class_, Span<Messages::ResourceType const> desired_types)
|
|
{
|
|
auto result = lookup_in_cache(name, class_, desired_types);
|
|
VERIFY(!result.is_null());
|
|
dbgln_if(DNS_DEBUG, "DNS::expect({}) -> OK", name);
|
|
return *result;
|
|
}
|
|
|
|
RefPtr<LookupResult const> lookup_in_cache(StringView name, Messages::Class class_ = Messages::Class::IN)
|
|
{
|
|
return lookup_in_cache(name, class_, Array { Messages::ResourceType::A, Messages::ResourceType::AAAA });
|
|
}
|
|
|
|
RefPtr<LookupResult const> lookup_in_cache(StringView name, Messages::Class, Span<Messages::ResourceType const> desired_types)
|
|
{
|
|
return m_cache.with_read_locked([&](auto& cache) -> RefPtr<LookupResult const> {
|
|
auto it = cache.find(name);
|
|
if (it == cache.end())
|
|
return {};
|
|
|
|
auto& result = *it->value;
|
|
for (auto const& type : desired_types) {
|
|
if (!result.has_record_of_type(type))
|
|
return {};
|
|
}
|
|
|
|
return result;
|
|
});
|
|
}
|
|
|
|
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_ = Messages::Class::IN)
|
|
{
|
|
return lookup(move(name), class_, Array { Messages::ResourceType::A, Messages::ResourceType::AAAA });
|
|
}
|
|
|
|
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> lookup(ByteString name, Messages::Class class_, Span<Messages::ResourceType const> desired_types)
|
|
{
|
|
flush_cache();
|
|
|
|
auto promise = Core::Promise<NonnullRefPtr<LookupResult const>>::construct();
|
|
|
|
if (auto result = lookup_in_cache(name, class_, desired_types)) {
|
|
promise->resolve(result.release_nonnull());
|
|
return promise;
|
|
}
|
|
|
|
auto domain_name = Messages::DomainName::from_string(name);
|
|
|
|
if (!has_connection()) {
|
|
// Use system resolver
|
|
// FIXME: Use an underlying resolver instead.
|
|
dbgln_if(DNS_DEBUG, "Not ready to resolve, using system resolver and skipping cache for {}", name);
|
|
auto record_or_error = Core::Socket::resolve_host(name, Core::Socket::SocketType::Stream);
|
|
if (record_or_error.is_error()) {
|
|
promise->reject(record_or_error.release_error());
|
|
return promise;
|
|
}
|
|
auto result = make_ref_counted<LookupResult>(domain_name);
|
|
auto record = record_or_error.release_value();
|
|
record.visit(
|
|
[&](IPv4Address const& address) {
|
|
result->add_record({ .name = {}, .type = Messages::ResourceType::A, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::A { address }, .raw = {} });
|
|
},
|
|
[&](IPv6Address const& address) {
|
|
result->add_record({ .name = {}, .type = Messages::ResourceType::AAAA, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::AAAA { address }, .raw = {} });
|
|
});
|
|
promise->resolve(result);
|
|
return promise;
|
|
}
|
|
|
|
auto already_in_cache = false;
|
|
auto result = m_cache.with_write_locked([&](auto& cache) -> NonnullRefPtr<LookupResult> {
|
|
auto existing = [&] -> RefPtr<LookupResult> {
|
|
if (cache.contains(name)) {
|
|
auto ptr = *cache.get(name);
|
|
|
|
already_in_cache = true;
|
|
for (auto const& type : desired_types) {
|
|
if (!ptr->has_record_of_type(type, true)) {
|
|
already_in_cache = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
return ptr;
|
|
}
|
|
return nullptr;
|
|
}();
|
|
|
|
if (existing)
|
|
return *existing;
|
|
|
|
auto ptr = make_ref_counted<LookupResult>(domain_name);
|
|
for (auto const& type : desired_types)
|
|
ptr->will_add_record_of_type(type);
|
|
cache.set(name, ptr);
|
|
return ptr;
|
|
});
|
|
|
|
Optional<u16> cached_result_id;
|
|
if (already_in_cache) {
|
|
auto id = result->id();
|
|
cached_result_id = id;
|
|
auto existing_promise = m_pending_lookups.with_write_locked([&](auto& lookups) -> RefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> {
|
|
if (auto* lookup = lookups->find(id))
|
|
return lookup->promise;
|
|
return nullptr;
|
|
});
|
|
if (existing_promise)
|
|
return existing_promise.release_nonnull();
|
|
|
|
promise->resolve(*result);
|
|
return promise;
|
|
}
|
|
|
|
Messages::Message query;
|
|
m_pending_lookups.with_read_locked([&](auto& lookups) {
|
|
do
|
|
fill_with_random({ &query.header.id, sizeof(query.header.id) });
|
|
while (lookups->find(query.header.id) != nullptr);
|
|
});
|
|
query.header.question_count = max(1u, desired_types.size());
|
|
query.header.options.set_response_code(Messages::Options::ResponseCode::NoError);
|
|
query.header.options.set_recursion_desired(true);
|
|
query.header.options.set_op_code(Messages::OpCode::Query);
|
|
for (auto const& type : desired_types) {
|
|
query.questions.append(Messages::Question {
|
|
.name = domain_name,
|
|
.type = type,
|
|
.class_ = class_,
|
|
});
|
|
}
|
|
|
|
if (query.questions.is_empty()) {
|
|
query.questions.append(Messages::Question {
|
|
.name = Messages::DomainName::from_string(name),
|
|
.type = Messages::ResourceType::A,
|
|
.class_ = class_,
|
|
});
|
|
}
|
|
|
|
auto cached_entry = m_pending_lookups.with_write_locked([&](auto& pending_lookups) -> RefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> {
|
|
// One more try to make sure we're not overwriting an existing lookup
|
|
if (cached_result_id.has_value()) {
|
|
if (auto* lookup = pending_lookups->find(*cached_result_id))
|
|
return lookup->promise;
|
|
}
|
|
|
|
pending_lookups->insert(query.header.id, { query.header.id, name, result->make_weak_ptr(), promise });
|
|
return nullptr;
|
|
});
|
|
if (cached_entry) {
|
|
dbgln_if(DNS_DEBUG, "DNS::lookup({}) -> Already in cache", name);
|
|
return cached_entry.release_nonnull();
|
|
}
|
|
|
|
ByteBuffer query_bytes;
|
|
MUST(query.to_raw(query_bytes));
|
|
|
|
if (m_mode == ConnectionMode::TCP) {
|
|
auto original_query_bytes = query_bytes;
|
|
query_bytes = MUST(ByteBuffer::create_uninitialized(query_bytes.size() + sizeof(u16)));
|
|
NetworkOrdered<u16> size = original_query_bytes.size();
|
|
query_bytes.overwrite(0, &size, sizeof(size));
|
|
query_bytes.overwrite(sizeof(size), original_query_bytes.data(), original_query_bytes.size());
|
|
}
|
|
|
|
auto write_result = m_socket.with_write_locked([&](auto& socket) {
|
|
return (*socket)->write_until_depleted(query_bytes.bytes());
|
|
});
|
|
if (write_result.is_error()) {
|
|
promise->reject(write_result.release_error());
|
|
return promise;
|
|
}
|
|
|
|
return promise;
|
|
}
|
|
|
|
private:
|
|
struct PendingLookup {
|
|
u16 id { 0 };
|
|
ByteString name;
|
|
WeakPtr<LookupResult> result;
|
|
NonnullRefPtr<Core::Promise<NonnullRefPtr<LookupResult const>>> promise;
|
|
};
|
|
|
|
ErrorOr<Messages::Message> parse_one_message()
|
|
{
|
|
if (m_mode == ConnectionMode::UDP)
|
|
return m_socket.with_write_locked([&](auto& socket) { return Messages::Message::from_raw(**socket); });
|
|
|
|
return m_socket.with_write_locked([&](auto& socket) -> ErrorOr<Messages::Message> {
|
|
if (!TRY((*socket)->can_read_without_blocking()))
|
|
return Error::from_errno(EAGAIN);
|
|
|
|
auto size = TRY((*socket)->template read_value<NetworkOrdered<u16>>());
|
|
auto buffer = TRY(ByteBuffer::create_uninitialized(size));
|
|
TRY((*socket)->read_until_filled(buffer));
|
|
FixedMemoryStream stream { static_cast<ReadonlyBytes>(buffer) };
|
|
return Messages::Message::from_raw(stream);
|
|
});
|
|
}
|
|
|
|
void process_incoming_messages()
|
|
{
|
|
while (true) {
|
|
if (auto result = m_socket.with_read_locked([](auto& socket) { return (*socket)->can_read_without_blocking(); }); result.is_error() || !result.value())
|
|
break;
|
|
auto message_or_err = parse_one_message();
|
|
if (message_or_err.is_error()) {
|
|
if (!message_or_err.error().is_errno() || message_or_err.error().code() != EAGAIN)
|
|
dbgln("DNS: Failed to receive message: {}", message_or_err.error());
|
|
break;
|
|
}
|
|
|
|
auto message = message_or_err.release_value();
|
|
auto result = m_pending_lookups.with_write_locked([&](auto& lookups) -> ErrorOr<void> {
|
|
auto* lookup = lookups->find(message.header.id);
|
|
if (!lookup)
|
|
return Error::from_string_literal("No pending lookup found for this message");
|
|
|
|
if (lookup->result.is_null())
|
|
return {}; // Message is a response to a lookup that's been purged from the cache, ignore it
|
|
|
|
auto result = lookup->result.strong_ref();
|
|
for (auto& record : message.answers)
|
|
result->add_record(move(record));
|
|
|
|
lookup->promise->resolve(*result);
|
|
lookups->remove(message.header.id);
|
|
return {};
|
|
});
|
|
if (result.is_error()) {
|
|
dbgln_if(DNS_DEBUG, "DNS: Received a message with no pending lookup: {}", result.error());
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
bool has_connection(bool attempt_restart = true)
|
|
{
|
|
auto result = m_socket.with_read_locked(
|
|
[&](auto& socket) { return socket.has_value() && (*socket)->is_open(); });
|
|
|
|
if (attempt_restart && !result && !m_attempting_restart) {
|
|
TemporaryChange change(m_attempting_restart, true);
|
|
auto create_result = m_create_socket();
|
|
if (create_result.is_error()) {
|
|
dbgln_if(DNS_DEBUG, "DNS: Failed to create socket: {}", create_result.error());
|
|
return false;
|
|
}
|
|
|
|
auto [socket, mode] = MUST(move(create_result));
|
|
set_socket(move(socket), mode);
|
|
result = true;
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
void set_socket(MaybeOwned<Core::Socket> socket, ConnectionMode mode = ConnectionMode::UDP)
|
|
{
|
|
m_mode = mode;
|
|
m_socket.with_write_locked([&](auto& s) {
|
|
s = move(socket);
|
|
(*s)->on_ready_to_read = [this] {
|
|
process_incoming_messages();
|
|
};
|
|
(*s)->set_notifications_enabled(true);
|
|
});
|
|
|
|
for (auto& promise : m_socket_ready_promises)
|
|
promise->resolve({});
|
|
|
|
m_socket_ready_promises.clear();
|
|
}
|
|
|
|
void flush_cache()
|
|
{
|
|
m_cache.with_write_locked([&](auto& cache) {
|
|
HashTable<ByteString> to_remove;
|
|
for (auto& entry : cache) {
|
|
entry.value->check_expiration();
|
|
if (!entry.value->is_valid())
|
|
to_remove.set(entry.key);
|
|
}
|
|
for (auto const& key : to_remove)
|
|
cache.remove(key);
|
|
});
|
|
}
|
|
|
|
Threading::RWLockProtected<HashMap<ByteString, NonnullRefPtr<LookupResult>>> m_cache;
|
|
Threading::RWLockProtected<NonnullOwnPtr<RedBlackTree<u16, PendingLookup>>> m_pending_lookups;
|
|
Threading::RWLockProtected<Optional<MaybeOwned<Core::Socket>>> m_socket;
|
|
Function<ErrorOr<SocketResult>()> m_create_socket;
|
|
bool m_attempting_restart { false };
|
|
ConnectionMode m_mode { ConnectionMode::UDP };
|
|
Vector<NonnullRefPtr<Core::Promise<Empty>>> m_socket_ready_promises;
|
|
};
|
|
|
|
}
|