Преглед изворни кода

LibDNS: Use `AllocatingMemoryStream` in DNS package construction

Tim Schumacher пре 2 година
родитељ
комит
87c64834ca

+ 1 - 1
Userland/Libraries/LibDNS/CMakeLists.txt

@@ -5,4 +5,4 @@ set(SOURCES
 )
 
 serenity_lib(LibDNS dns)
-target_link_libraries(LibDNS PRIVATE LibIPC)
+target_link_libraries(LibDNS PRIVATE LibCore LibIPC)

+ 6 - 6
Userland/Libraries/LibDNS/Name.cpp

@@ -75,15 +75,15 @@ void Name::randomize_case()
     m_name = builder.to_deprecated_string();
 }
 
-OutputStream& operator<<(OutputStream& stream, Name const& name)
+ErrorOr<void> Name::write_to_stream(Core::Stream::Stream& stream) const
 {
-    auto parts = name.as_string().split_view('.');
+    auto parts = as_string().split_view('.');
     for (auto& part : parts) {
-        stream << (u8)part.length();
-        stream << part.bytes();
+        TRY(stream.write_trivial_value<u8>(part.length()));
+        TRY(stream.write_entire_buffer(part.bytes()));
     }
-    stream << '\0';
-    return stream;
+    TRY(stream.write_trivial_value('\0'));
+    return {};
 }
 
 unsigned Name::Traits::hash(Name const& name)

+ 2 - 2
Userland/Libraries/LibDNS/Name.h

@@ -9,6 +9,7 @@
 
 #include <AK/DeprecatedString.h>
 #include <AK/Forward.h>
+#include <LibCore/Stream.h>
 
 namespace DNS {
 
@@ -21,6 +22,7 @@ public:
 
     size_t serialized_size() const;
     DeprecatedString const& as_string() const { return m_name; }
+    ErrorOr<void> write_to_stream(Core::Stream::Stream&) const;
 
     void randomize_case();
 
@@ -36,8 +38,6 @@ private:
     DeprecatedString m_name;
 };
 
-OutputStream& operator<<(OutputStream& stream, Name const&);
-
 }
 
 template<>

+ 18 - 15
Userland/Libraries/LibDNS/Packet.cpp

@@ -11,6 +11,7 @@
 #include <AK/Debug.h>
 #include <AK/MemoryStream.h>
 #include <AK/StringBuilder.h>
+#include <LibCore/MemoryStream.h>
 #include <arpa/inet.h>
 
 namespace DNS {
@@ -29,7 +30,7 @@ void Packet::add_answer(Answer const& answer)
     VERIFY(m_answers.size() <= UINT16_MAX);
 }
 
-ByteBuffer Packet::to_byte_buffer() const
+ErrorOr<ByteBuffer> Packet::to_byte_buffer() const
 {
     PacketHeader header;
     header.set_id(m_id);
@@ -48,30 +49,32 @@ ByteBuffer Packet::to_byte_buffer() const
     header.set_question_count(m_questions.size());
     header.set_answer_count(m_answers.size());
 
-    DuplexMemoryStream stream;
+    Core::Stream::AllocatingMemoryStream stream;
 
-    stream << ReadonlyBytes { &header, sizeof(header) };
+    TRY(stream.write_trivial_value(header));
     for (auto& question : m_questions) {
-        stream << question.name();
-        stream << htons((u16)question.record_type());
-        stream << htons(question.raw_class_code());
+        TRY(question.name().write_to_stream(stream));
+        TRY(stream.write_trivial_value(htons((u16)question.record_type())));
+        TRY(stream.write_trivial_value(htons(question.raw_class_code())));
     }
     for (auto& answer : m_answers) {
-        stream << answer.name();
-        stream << htons((u16)answer.type());
-        stream << htons(answer.raw_class_code());
-        stream << htonl(answer.ttl());
+        TRY(answer.name().write_to_stream(stream));
+        TRY(stream.write_trivial_value(htons((u16)answer.type())));
+        TRY(stream.write_trivial_value(htons(answer.raw_class_code())));
+        TRY(stream.write_trivial_value(htonl(answer.ttl())));
         if (answer.type() == RecordType::PTR) {
             Name name { answer.record_data() };
-            stream << htons(name.serialized_size());
-            stream << name;
+            TRY(stream.write_trivial_value(htons(name.serialized_size())));
+            TRY(name.write_to_stream(stream));
         } else {
-            stream << htons(answer.record_data().length());
-            stream << answer.record_data().bytes();
+            TRY(stream.write_trivial_value(htons(answer.record_data().length())));
+            TRY(stream.write_entire_buffer(answer.record_data().bytes()));
         }
     }
 
-    return stream.copy_into_contiguous_buffer();
+    auto buffer = TRY(ByteBuffer::create_uninitialized(stream.used_buffer_size()));
+    TRY(stream.read_entire_buffer(buffer));
+    return buffer;
 }
 
 class [[gnu::packed]] DNSRecordWithoutName {

+ 1 - 1
Userland/Libraries/LibDNS/Packet.h

@@ -25,7 +25,7 @@ public:
     Packet() = default;
 
     static Optional<Packet> from_raw_packet(u8 const*, size_t);
-    ByteBuffer to_byte_buffer() const;
+    ErrorOr<ByteBuffer> to_byte_buffer() const;
 
     bool is_query() const { return !m_query_or_response; }
     bool is_response() const { return m_query_or_response; }

+ 1 - 1
Userland/Services/LookupServer/DNSServer.cpp

@@ -62,7 +62,7 @@ ErrorOr<void> DNSServer::handle_client()
     else
         response.set_code(Packet::Code::NOERROR);
 
-    buffer = response.to_byte_buffer();
+    buffer = TRY(response.to_byte_buffer());
 
     TRY(send(buffer, client_address));
     return {};

+ 1 - 1
Userland/Services/LookupServer/LookupServer.cpp

@@ -234,7 +234,7 @@ ErrorOr<Vector<Answer>> LookupServer::lookup(Name const& name, DeprecatedString
         name_in_question.randomize_case();
     request.add_question({ name_in_question, record_type, RecordClass::IN, false });
 
-    auto buffer = request.to_byte_buffer();
+    auto buffer = TRY(request.to_byte_buffer());
 
     auto udp_socket = TRY(Core::Stream::UDPSocket::connect(nameserver, 53, Time::from_seconds(1)));
     TRY(udp_socket->set_blocking(true));

+ 1 - 1
Userland/Services/LookupServer/MulticastDNS.cpp

@@ -110,7 +110,7 @@ void MulticastDNS::announce()
 
 ErrorOr<size_t> MulticastDNS::emit_packet(Packet const& packet, sockaddr_in const* destination)
 {
-    auto buffer = packet.to_byte_buffer();
+    auto buffer = TRY(packet.to_byte_buffer());
     if (!destination)
         destination = &mdns_addr;