/* * Copyright (c) 2024, Ali Mohammad Pur * * SPDX-License-Identifier: BSD-2-Clause */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DNS { class Resolver; class LookupResult : public AtomicRefCounted , public Weakable { public: explicit LookupResult(Messages::DomainName name) : m_name(move(name)) { } Vector> cached_addresses() const { Vector> result; for (auto& re : m_cached_records) { re.record.record.visit( [&](Messages::Records::A const& a) { result.append(a.address); }, [&](Messages::Records::AAAA const& aaaa) { result.append(aaaa.address); }, [](auto&) {}); } return result; } void check_expiration() { if (!m_valid) return; auto now = Core::DateTime::now(); for (size_t i = 0; i < m_cached_records.size();) { auto& record = m_cached_records[i]; if (record.expiration.has_value() && record.expiration.value() < now) { dbgln_if(DNS_DEBUG, "DNS: Removing expired record for {}", m_name.to_string()); m_cached_records.remove(i); } else { dbgln_if(DNS_DEBUG, "DNS: Keeping record for {} (expires in {})", m_name.to_string(), record.expiration.has_value() ? record.expiration.value().to_string() : "never"_string); ++i; } } if (m_cached_records.is_empty()) m_valid = false; } void add_record(Messages::ResourceRecord record) { m_valid = true; auto expiration = record.ttl > 0 ? Optional(Core::DateTime::from_timestamp(Core::DateTime::now().timestamp() + record.ttl)) : OptionalNone(); m_cached_records.append({ move(record), move(expiration) }); } Vector records() const { Vector result; for (auto& re : m_cached_records) result.append(re.record); return result; } bool has_record_of_type(Messages::ResourceType type, bool later = false) const { if (later && m_desired_types.contains(type)) return true; for (auto const& re : m_cached_records) { if (re.record.type == type) return true; } return false; } void will_add_record_of_type(Messages::ResourceType type) { m_desired_types.set(type); } void set_id(u16 id) { m_id = id; } u16 id() { return m_id; } bool is_valid() const { return m_valid; } Messages::DomainName const& name() const { return m_name; } private: bool m_valid { false }; Messages::DomainName m_name; struct RecordWithExpiration { Messages::ResourceRecord record; Optional expiration; }; Vector m_cached_records; HashTable m_desired_types; u16 m_id { 0 }; }; class Resolver { public: enum class ConnectionMode { TCP, UDP, }; struct SocketResult { MaybeOwned socket; ConnectionMode mode; }; Resolver(Function()> create_socket) : m_pending_lookups(make>()) , m_create_socket(move(create_socket)) { m_cache.with_write_locked([&](auto& cache) { auto add_v4v6_entry = [&cache](StringView name_string, IPv4Address v4, IPv6Address v6) { auto name = Messages::DomainName::from_string(name_string); auto ptr = make_ref_counted(name); ptr->will_add_record_of_type(Messages::ResourceType::A); ptr->will_add_record_of_type(Messages::ResourceType::AAAA); cache.set(name_string, ptr); ptr->add_record({ .name = {}, .type = Messages::ResourceType::A, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::A { v4 }, .raw = {} }); ptr->add_record({ .name = {}, .type = Messages::ResourceType::AAAA, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::AAAA { v6 }, .raw = {} }); }; add_v4v6_entry("localhost"sv, { 127, 0, 0, 1 }, IPv6Address::loopback()); }); } NonnullRefPtr> when_socket_ready() { auto promise = Core::Promise::construct(); m_socket_ready_promises.append(promise); if (has_connection(false)) { promise->resolve({}); return promise; } if (!has_connection()) promise->reject(Error::from_string_literal("Failed to create socket")); return promise; } void reset_connection() { m_socket.with_write_locked([&](auto& socket) { socket = {}; }); } NonnullRefPtr expect_cached(StringView name, Messages::Class class_ = Messages::Class::IN) { return expect_cached(name, class_, Array { Messages::ResourceType::A, Messages::ResourceType::AAAA }); } NonnullRefPtr expect_cached(StringView name, Messages::Class class_, Span desired_types) { auto result = lookup_in_cache(name, class_, desired_types); VERIFY(!result.is_null()); dbgln_if(DNS_DEBUG, "DNS::expect({}) -> OK", name); return *result; } RefPtr lookup_in_cache(StringView name, Messages::Class class_ = Messages::Class::IN) { return lookup_in_cache(name, class_, Array { Messages::ResourceType::A, Messages::ResourceType::AAAA }); } RefPtr lookup_in_cache(StringView name, Messages::Class, Span desired_types) { return m_cache.with_read_locked([&](auto& cache) -> RefPtr { auto it = cache.find(name); if (it == cache.end()) return {}; auto& result = *it->value; for (auto const& type : desired_types) { if (!result.has_record_of_type(type)) return {}; } return result; }); } NonnullRefPtr>> lookup(ByteString name, Messages::Class class_ = Messages::Class::IN) { return lookup(move(name), class_, Array { Messages::ResourceType::A, Messages::ResourceType::AAAA }); } NonnullRefPtr>> lookup(ByteString name, Messages::Class class_, Span desired_types) { flush_cache(); auto promise = Core::Promise>::construct(); if (auto maybe_ipv4 = IPv4Address::from_string(name); maybe_ipv4.has_value()) { if (desired_types.contains_slow(Messages::ResourceType::A)) { auto result = make_ref_counted(Messages::DomainName {}); result->add_record({ .name = {}, .type = Messages::ResourceType::A, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::A { maybe_ipv4.release_value() }, .raw = {} }); promise->resolve(move(result)); return promise; } } if (auto maybe_ipv6 = IPv6Address::from_string(name); maybe_ipv6.has_value()) { if (desired_types.contains_slow(Messages::ResourceType::AAAA)) { auto result = make_ref_counted(Messages::DomainName {}); result->add_record({ .name = {}, .type = Messages::ResourceType::AAAA, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::AAAA { maybe_ipv6.release_value() }, .raw = {} }); promise->resolve(move(result)); return promise; } } if (auto result = lookup_in_cache(name, class_, desired_types)) { promise->resolve(result.release_nonnull()); return promise; } auto domain_name = Messages::DomainName::from_string(name); if (!has_connection()) { // Use system resolver // FIXME: Use an underlying resolver instead. dbgln_if(DNS_DEBUG, "Not ready to resolve, using system resolver and skipping cache for {}", name); auto record_or_error = Core::Socket::resolve_host(name, Core::Socket::SocketType::Stream); if (record_or_error.is_error()) { promise->reject(record_or_error.release_error()); return promise; } auto result = make_ref_counted(domain_name); auto record = record_or_error.release_value(); record.visit( [&](IPv4Address const& address) { result->add_record({ .name = {}, .type = Messages::ResourceType::A, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::A { address }, .raw = {} }); }, [&](IPv6Address const& address) { result->add_record({ .name = {}, .type = Messages::ResourceType::AAAA, .class_ = Messages::Class::IN, .ttl = 0, .record = Messages::Records::AAAA { address }, .raw = {} }); }); promise->resolve(result); return promise; } auto already_in_cache = false; auto result = m_cache.with_write_locked([&](auto& cache) -> NonnullRefPtr { auto existing = [&] -> RefPtr { if (cache.contains(name)) { auto ptr = *cache.get(name); already_in_cache = true; for (auto const& type : desired_types) { if (!ptr->has_record_of_type(type, true)) { already_in_cache = false; break; } } return ptr; } return nullptr; }(); if (existing) return *existing; auto ptr = make_ref_counted(domain_name); for (auto const& type : desired_types) ptr->will_add_record_of_type(type); cache.set(name, ptr); return ptr; }); Optional cached_result_id; if (already_in_cache) { auto id = result->id(); cached_result_id = id; auto existing_promise = m_pending_lookups.with_write_locked([&](auto& lookups) -> RefPtr>> { if (auto* lookup = lookups->find(id)) return lookup->promise; return nullptr; }); if (existing_promise) return existing_promise.release_nonnull(); promise->resolve(*result); return promise; } Messages::Message query; m_pending_lookups.with_read_locked([&](auto& lookups) { do fill_with_random({ &query.header.id, sizeof(query.header.id) }); while (lookups->find(query.header.id) != nullptr); }); query.header.question_count = max(1u, desired_types.size()); query.header.options.set_response_code(Messages::Options::ResponseCode::NoError); query.header.options.set_recursion_desired(true); query.header.options.set_op_code(Messages::OpCode::Query); for (auto const& type : desired_types) { query.questions.append(Messages::Question { .name = domain_name, .type = type, .class_ = class_, }); } if (query.questions.is_empty()) { query.questions.append(Messages::Question { .name = Messages::DomainName::from_string(name), .type = Messages::ResourceType::A, .class_ = class_, }); } auto cached_entry = m_pending_lookups.with_write_locked([&](auto& pending_lookups) -> RefPtr>> { // One more try to make sure we're not overwriting an existing lookup if (cached_result_id.has_value()) { if (auto* lookup = pending_lookups->find(*cached_result_id)) return lookup->promise; } pending_lookups->insert(query.header.id, { query.header.id, name, result->make_weak_ptr(), promise }); return nullptr; }); if (cached_entry) { dbgln_if(DNS_DEBUG, "DNS::lookup({}) -> Already in cache", name); return cached_entry.release_nonnull(); } ByteBuffer query_bytes; MUST(query.to_raw(query_bytes)); if (m_mode == ConnectionMode::TCP) { auto original_query_bytes = query_bytes; query_bytes = MUST(ByteBuffer::create_uninitialized(query_bytes.size() + sizeof(u16))); NetworkOrdered size = original_query_bytes.size(); query_bytes.overwrite(0, &size, sizeof(size)); query_bytes.overwrite(sizeof(size), original_query_bytes.data(), original_query_bytes.size()); } auto write_result = m_socket.with_write_locked([&](auto& socket) { return (*socket)->write_until_depleted(query_bytes.bytes()); }); if (write_result.is_error()) { promise->reject(write_result.release_error()); return promise; } return promise; } private: struct PendingLookup { u16 id { 0 }; ByteString name; WeakPtr result; NonnullRefPtr>> promise; }; ErrorOr parse_one_message() { if (m_mode == ConnectionMode::UDP) return m_socket.with_write_locked([&](auto& socket) { return Messages::Message::from_raw(**socket); }); return m_socket.with_write_locked([&](auto& socket) -> ErrorOr { if (!TRY((*socket)->can_read_without_blocking())) return Error::from_errno(EAGAIN); auto size = TRY((*socket)->template read_value>()); auto buffer = TRY(ByteBuffer::create_uninitialized(size)); TRY((*socket)->read_until_filled(buffer)); FixedMemoryStream stream { static_cast(buffer) }; return Messages::Message::from_raw(stream); }); } void process_incoming_messages() { while (true) { if (auto result = m_socket.with_read_locked([](auto& socket) { return (*socket)->can_read_without_blocking(); }); result.is_error() || !result.value()) break; auto message_or_err = parse_one_message(); if (message_or_err.is_error()) { if (!message_or_err.error().is_errno() || message_or_err.error().code() != EAGAIN) dbgln("DNS: Failed to receive message: {}", message_or_err.error()); break; } auto message = message_or_err.release_value(); auto result = m_pending_lookups.with_write_locked([&](auto& lookups) -> ErrorOr { auto* lookup = lookups->find(message.header.id); if (!lookup) return Error::from_string_literal("No pending lookup found for this message"); if (lookup->result.is_null()) return {}; // Message is a response to a lookup that's been purged from the cache, ignore it auto result = lookup->result.strong_ref(); for (auto& record : message.answers) result->add_record(move(record)); lookup->promise->resolve(*result); lookups->remove(message.header.id); return {}; }); if (result.is_error()) { dbgln_if(DNS_DEBUG, "DNS: Received a message with no pending lookup: {}", result.error()); continue; } } } bool has_connection(bool attempt_restart = true) { auto result = m_socket.with_read_locked( [&](auto& socket) { return socket.has_value() && (*socket)->is_open(); }); if (attempt_restart && !result && !m_attempting_restart) { TemporaryChange change(m_attempting_restart, true); auto create_result = m_create_socket(); if (create_result.is_error()) { dbgln_if(DNS_DEBUG, "DNS: Failed to create socket: {}", create_result.error()); return false; } auto [socket, mode] = MUST(move(create_result)); set_socket(move(socket), mode); result = true; } return result; } void set_socket(MaybeOwned socket, ConnectionMode mode = ConnectionMode::UDP) { m_mode = mode; m_socket.with_write_locked([&](auto& s) { s = move(socket); (*s)->on_ready_to_read = [this] { process_incoming_messages(); }; (*s)->set_notifications_enabled(true); }); for (auto& promise : m_socket_ready_promises) promise->resolve({}); m_socket_ready_promises.clear(); } void flush_cache() { m_cache.with_write_locked([&](auto& cache) { HashTable to_remove; for (auto& entry : cache) { entry.value->check_expiration(); if (!entry.value->is_valid()) to_remove.set(entry.key); } for (auto const& key : to_remove) cache.remove(key); }); } Threading::RWLockProtected>> m_cache; Threading::RWLockProtected>> m_pending_lookups; Threading::RWLockProtected>> m_socket; Function()> m_create_socket; bool m_attempting_restart { false }; ConnectionMode m_mode { ConnectionMode::UDP }; Vector>> m_socket_ready_promises; }; }