Browse Source

LibDNS: Prefer spans over raw pointers when parsing DNS packets

This means we don't have to keep track of the pointer and size
separately.
Tim Ledbetter 1 year ago
parent
commit
4e3b59a4bb

+ 6 - 6
Userland/Libraries/LibDNS/Name.cpp

@@ -21,14 +21,14 @@ Name::Name(DeprecatedString const& name)
         m_name = name;
 }
 
-Name Name::parse(u8 const* data, size_t& offset, size_t max_offset, size_t recursion_level)
+Name Name::parse(ReadonlyBytes data, size_t& offset, size_t recursion_level)
 {
     if (recursion_level > 4)
         return {};
 
     StringBuilder builder;
     while (true) {
-        if (offset >= max_offset)
+        if (offset >= data.size())
             return {};
         u8 b = data[offset++];
         if (b == '\0') {
@@ -36,17 +36,17 @@ Name Name::parse(u8 const* data, size_t& offset, size_t max_offset, size_t recur
             return builder.to_deprecated_string();
         } else if ((b & 0xc0) == 0xc0) {
             // The two bytes tell us the offset when to continue from.
-            if (offset >= max_offset)
+            if (offset >= data.size())
                 return {};
             size_t dummy = (b & 0x3f) << 8 | data[offset++];
-            auto rest_of_name = parse(data, dummy, max_offset, recursion_level + 1);
+            auto rest_of_name = parse(data, dummy, recursion_level + 1);
             builder.append(rest_of_name.as_string());
             return builder.to_deprecated_string();
         } else {
             // This is the length of a part.
-            if (offset + b >= max_offset)
+            if (offset + b >= data.size())
                 return {};
-            builder.append((char const*)&data[offset], (size_t)b);
+            builder.append({ data.offset_pointer(offset), b });
             builder.append('.');
             offset += b;
         }

+ 1 - 1
Userland/Libraries/LibDNS/Name.h

@@ -17,7 +17,7 @@ public:
     Name() = default;
     Name(DeprecatedString const&);
 
-    static Name parse(u8 const* data, size_t& offset, size_t max_offset, size_t recursion_level = 0);
+    static Name parse(ReadonlyBytes data, size_t& offset, size_t recursion_level = 0);
 
     size_t serialized_size() const;
     DeprecatedString const& as_string() const { return m_name; }

+ 10 - 12
Userland/Libraries/LibDNS/Packet.cpp

@@ -97,14 +97,14 @@ private:
 
 static_assert(sizeof(DNSRecordWithoutName) == 10);
 
-Optional<Packet> Packet::from_raw_packet(u8 const* raw_data, size_t raw_size)
+Optional<Packet> Packet::from_raw_packet(ReadonlyBytes bytes)
 {
-    if (raw_size < sizeof(PacketHeader)) {
-        dbgln("DNS response not large enough ({} out of {}) to be a DNS packet.", raw_size, sizeof(PacketHeader));
+    if (bytes.size() < sizeof(PacketHeader)) {
+        dbgln("DNS response not large enough ({} out of {}) to be a DNS packet.", bytes.size(), sizeof(PacketHeader));
         return {};
     }
 
-    auto& header = *(PacketHeader const*)(raw_data);
+    auto const& header = *bit_cast<PacketHeader const*>(bytes.data());
     dbgln_if(LOOKUPSERVER_DEBUG, "Got packet (ID: {})", header.id());
     dbgln_if(LOOKUPSERVER_DEBUG, "  Question count: {}", header.question_count());
     dbgln_if(LOOKUPSERVER_DEBUG, "    Answer count: {}", header.answer_count());
@@ -123,12 +123,12 @@ Optional<Packet> Packet::from_raw_packet(u8 const* raw_data, size_t raw_size)
     size_t offset = sizeof(PacketHeader);
 
     for (u16 i = 0; i < header.question_count(); i++) {
-        auto name = Name::parse(raw_data, offset, raw_size);
+        auto name = Name::parse(bytes, offset);
         struct RawDNSAnswerQuestion {
             NetworkOrdered<u16> record_type;
             NetworkOrdered<u16> class_code;
         };
-        auto& record_and_class = *(RawDNSAnswerQuestion const*)&raw_data[offset];
+        auto const& record_and_class = *bit_cast<RawDNSAnswerQuestion const*>(bytes.offset_pointer(offset));
         u16 class_code = record_and_class.class_code & ~MDNS_WANTS_UNICAST_RESPONSE;
         bool mdns_wants_unicast_response = record_and_class.class_code & MDNS_WANTS_UNICAST_RESPONSE;
         packet.m_questions.empend(name, (RecordType)(u16)record_and_class.record_type, (RecordClass)class_code, mdns_wants_unicast_response);
@@ -138,18 +138,16 @@ Optional<Packet> Packet::from_raw_packet(u8 const* raw_data, size_t raw_size)
     }
 
     for (u16 i = 0; i < header.answer_count(); ++i) {
-        auto name = Name::parse(raw_data, offset, raw_size);
-
-        auto& record = *(DNSRecordWithoutName const*)(&raw_data[offset]);
+        auto name = Name::parse(bytes, offset);
+        auto const& record = *bit_cast<DNSRecordWithoutName const*>(bytes.offset_pointer(offset));
+        offset += sizeof(DNSRecordWithoutName);
 
         DeprecatedString data;
 
-        offset += sizeof(DNSRecordWithoutName);
-
         switch ((RecordType)record.type()) {
         case RecordType::PTR: {
             size_t dummy_offset = offset;
-            data = Name::parse(raw_data, dummy_offset, raw_size).as_string();
+            data = Name::parse(bytes, dummy_offset).as_string();
             break;
         }
         case RecordType::CNAME:

+ 1 - 1
Userland/Libraries/LibDNS/Packet.h

@@ -24,7 +24,7 @@ class Packet {
 public:
     Packet() = default;
 
-    static Optional<Packet> from_raw_packet(u8 const*, size_t);
+    static Optional<Packet> from_raw_packet(ReadonlyBytes bytes);
     ErrorOr<ByteBuffer> to_byte_buffer() const;
 
     bool is_query() const { return !m_query_or_response; }

+ 1 - 1
Userland/Services/LookupServer/DNSServer.cpp

@@ -29,7 +29,7 @@ ErrorOr<void> DNSServer::handle_client()
 {
     sockaddr_in client_address;
     auto buffer = TRY(receive(1024, client_address));
-    auto optional_request = Packet::from_raw_packet(buffer.data(), buffer.size());
+    auto optional_request = Packet::from_raw_packet(buffer);
     if (!optional_request.has_value()) {
         dbgln("Got an invalid DNS packet");
         return {};

+ 2 - 2
Userland/Services/LookupServer/LookupServer.cpp

@@ -263,13 +263,13 @@ ErrorOr<Vector<Answer>> LookupServer::lookup(Name const& name, DeprecatedString
     TRY(udp_socket->write_until_depleted(buffer));
 
     u8 response_buffer[4096];
-    int nrecv = TRY(udp_socket->read_some({ response_buffer, sizeof(response_buffer) })).size();
+    auto nrecv = TRY(udp_socket->read_some({ response_buffer, sizeof(response_buffer) })).size();
     if (udp_socket->is_eof())
         return Vector<Answer> {};
 
     did_get_response = true;
 
-    auto o_response = Packet::from_raw_packet(response_buffer, nrecv);
+    auto o_response = Packet::from_raw_packet({ response_buffer, nrecv });
     if (!o_response.has_value())
         return Vector<Answer> {};
 

+ 2 - 2
Userland/Services/LookupServer/MulticastDNS.cpp

@@ -56,7 +56,7 @@ MulticastDNS::MulticastDNS(Core::EventReceiver* parent)
 ErrorOr<void> MulticastDNS::handle_packet()
 {
     auto buffer = TRY(receive(1024));
-    auto optional_packet = Packet::from_raw_packet(buffer.data(), buffer.size());
+    auto optional_packet = Packet::from_raw_packet(buffer);
     if (!optional_packet.has_value()) {
         dbgln("Got an invalid mDNS packet");
         return {};
@@ -175,7 +175,7 @@ ErrorOr<Vector<Answer>> MulticastDNS::lookup(Name const& name, RecordType record
         auto buffer = TRY(receive(1024));
         if (buffer.is_empty())
             return Vector<Answer> {};
-        auto optional_packet = Packet::from_raw_packet(buffer.data(), buffer.size());
+        auto optional_packet = Packet::from_raw_packet(buffer);
         if (!optional_packet.has_value()) {
             dbgln("Got an invalid mDNS packet");
             continue;