LibDNS: Prefer spans over raw pointers when parsing DNS packets

This means we don't have to keep track of the pointer and size
separately.
This commit is contained in:
Tim Ledbetter 2023-11-02 20:04:11 +00:00 committed by Andreas Kling
parent c1d7a51391
commit 4e3b59a4bb
Notes: sideshowbarker 2024-07-17 02:05:41 +09:00
7 changed files with 23 additions and 25 deletions

View file

@ -21,14 +21,14 @@ Name::Name(DeprecatedString const& name)
m_name = name;
}
Name Name::parse(u8 const* data, size_t& offset, size_t max_offset, size_t recursion_level)
Name Name::parse(ReadonlyBytes data, size_t& offset, size_t recursion_level)
{
if (recursion_level > 4)
return {};
StringBuilder builder;
while (true) {
if (offset >= max_offset)
if (offset >= data.size())
return {};
u8 b = data[offset++];
if (b == '\0') {
@ -36,17 +36,17 @@ Name Name::parse(u8 const* data, size_t& offset, size_t max_offset, size_t recur
return builder.to_deprecated_string();
} else if ((b & 0xc0) == 0xc0) {
// The two bytes tell us the offset when to continue from.
if (offset >= max_offset)
if (offset >= data.size())
return {};
size_t dummy = (b & 0x3f) << 8 | data[offset++];
auto rest_of_name = parse(data, dummy, max_offset, recursion_level + 1);
auto rest_of_name = parse(data, dummy, recursion_level + 1);
builder.append(rest_of_name.as_string());
return builder.to_deprecated_string();
} else {
// This is the length of a part.
if (offset + b >= max_offset)
if (offset + b >= data.size())
return {};
builder.append((char const*)&data[offset], (size_t)b);
builder.append({ data.offset_pointer(offset), b });
builder.append('.');
offset += b;
}

View file

@ -17,7 +17,7 @@ public:
Name() = default;
Name(DeprecatedString const&);
static Name parse(u8 const* data, size_t& offset, size_t max_offset, size_t recursion_level = 0);
static Name parse(ReadonlyBytes data, size_t& offset, size_t recursion_level = 0);
size_t serialized_size() const;
DeprecatedString const& as_string() const { return m_name; }

View file

@ -97,14 +97,14 @@ private:
static_assert(sizeof(DNSRecordWithoutName) == 10);
Optional<Packet> Packet::from_raw_packet(u8 const* raw_data, size_t raw_size)
Optional<Packet> Packet::from_raw_packet(ReadonlyBytes bytes)
{
if (raw_size < sizeof(PacketHeader)) {
dbgln("DNS response not large enough ({} out of {}) to be a DNS packet.", raw_size, sizeof(PacketHeader));
if (bytes.size() < sizeof(PacketHeader)) {
dbgln("DNS response not large enough ({} out of {}) to be a DNS packet.", bytes.size(), sizeof(PacketHeader));
return {};
}
auto& header = *(PacketHeader const*)(raw_data);
auto const& header = *bit_cast<PacketHeader const*>(bytes.data());
dbgln_if(LOOKUPSERVER_DEBUG, "Got packet (ID: {})", header.id());
dbgln_if(LOOKUPSERVER_DEBUG, " Question count: {}", header.question_count());
dbgln_if(LOOKUPSERVER_DEBUG, " Answer count: {}", header.answer_count());
@ -123,12 +123,12 @@ Optional<Packet> Packet::from_raw_packet(u8 const* raw_data, size_t raw_size)
size_t offset = sizeof(PacketHeader);
for (u16 i = 0; i < header.question_count(); i++) {
auto name = Name::parse(raw_data, offset, raw_size);
auto name = Name::parse(bytes, offset);
struct RawDNSAnswerQuestion {
NetworkOrdered<u16> record_type;
NetworkOrdered<u16> class_code;
};
auto& record_and_class = *(RawDNSAnswerQuestion const*)&raw_data[offset];
auto const& record_and_class = *bit_cast<RawDNSAnswerQuestion const*>(bytes.offset_pointer(offset));
u16 class_code = record_and_class.class_code & ~MDNS_WANTS_UNICAST_RESPONSE;
bool mdns_wants_unicast_response = record_and_class.class_code & MDNS_WANTS_UNICAST_RESPONSE;
packet.m_questions.empend(name, (RecordType)(u16)record_and_class.record_type, (RecordClass)class_code, mdns_wants_unicast_response);
@ -138,18 +138,16 @@ Optional<Packet> Packet::from_raw_packet(u8 const* raw_data, size_t raw_size)
}
for (u16 i = 0; i < header.answer_count(); ++i) {
auto name = Name::parse(raw_data, offset, raw_size);
auto& record = *(DNSRecordWithoutName const*)(&raw_data[offset]);
auto name = Name::parse(bytes, offset);
auto const& record = *bit_cast<DNSRecordWithoutName const*>(bytes.offset_pointer(offset));
offset += sizeof(DNSRecordWithoutName);
DeprecatedString data;
offset += sizeof(DNSRecordWithoutName);
switch ((RecordType)record.type()) {
case RecordType::PTR: {
size_t dummy_offset = offset;
data = Name::parse(raw_data, dummy_offset, raw_size).as_string();
data = Name::parse(bytes, dummy_offset).as_string();
break;
}
case RecordType::CNAME:

View file

@ -24,7 +24,7 @@ class Packet {
public:
Packet() = default;
static Optional<Packet> from_raw_packet(u8 const*, size_t);
static Optional<Packet> from_raw_packet(ReadonlyBytes bytes);
ErrorOr<ByteBuffer> to_byte_buffer() const;
bool is_query() const { return !m_query_or_response; }

View file

@ -29,7 +29,7 @@ ErrorOr<void> DNSServer::handle_client()
{
sockaddr_in client_address;
auto buffer = TRY(receive(1024, client_address));
auto optional_request = Packet::from_raw_packet(buffer.data(), buffer.size());
auto optional_request = Packet::from_raw_packet(buffer);
if (!optional_request.has_value()) {
dbgln("Got an invalid DNS packet");
return {};

View file

@ -263,13 +263,13 @@ ErrorOr<Vector<Answer>> LookupServer::lookup(Name const& name, DeprecatedString
TRY(udp_socket->write_until_depleted(buffer));
u8 response_buffer[4096];
int nrecv = TRY(udp_socket->read_some({ response_buffer, sizeof(response_buffer) })).size();
auto nrecv = TRY(udp_socket->read_some({ response_buffer, sizeof(response_buffer) })).size();
if (udp_socket->is_eof())
return Vector<Answer> {};
did_get_response = true;
auto o_response = Packet::from_raw_packet(response_buffer, nrecv);
auto o_response = Packet::from_raw_packet({ response_buffer, nrecv });
if (!o_response.has_value())
return Vector<Answer> {};

View file

@ -56,7 +56,7 @@ MulticastDNS::MulticastDNS(Core::EventReceiver* parent)
ErrorOr<void> MulticastDNS::handle_packet()
{
auto buffer = TRY(receive(1024));
auto optional_packet = Packet::from_raw_packet(buffer.data(), buffer.size());
auto optional_packet = Packet::from_raw_packet(buffer);
if (!optional_packet.has_value()) {
dbgln("Got an invalid mDNS packet");
return {};
@ -175,7 +175,7 @@ ErrorOr<Vector<Answer>> MulticastDNS::lookup(Name const& name, RecordType record
auto buffer = TRY(receive(1024));
if (buffer.is_empty())
return Vector<Answer> {};
auto optional_packet = Packet::from_raw_packet(buffer.data(), buffer.size());
auto optional_packet = Packet::from_raw_packet(buffer);
if (!optional_packet.has_value()) {
dbgln("Got an invalid mDNS packet");
continue;