Przeglądaj źródła

Kernel/Net: Make IPv4Socket::protocol_receive() take a ReadonlyBytes

The overrides of this function don't need to know how the original
packet was stored, so let's just give them a ReadonlyBytes view of
the raw packet data.
Andreas Kling 4 lat temu
rodzic
commit
8cc81c2953

+ 2 - 2
Kernel/Net/IPv4Socket.cpp

@@ -363,7 +363,7 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti
         return bytes_written;
         return bytes_written;
     }
     }
 
 
-    return protocol_receive(packet.data.value(), buffer, buffer_length, flags);
+    return protocol_receive(ReadonlyBytes { packet.data.value().data(), packet.data.value().size() }, buffer, buffer_length, flags);
 }
 }
 
 
 KResultOr<size_t> IPv4Socket::recvfrom(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> user_addr, Userspace<socklen_t*> user_addr_length, timeval& packet_timestamp)
 KResultOr<size_t> IPv4Socket::recvfrom(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> user_addr, Userspace<socklen_t*> user_addr_length, timeval& packet_timestamp)
@@ -408,7 +408,7 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port,
             return false;
             return false;
         }
         }
         auto scratch_buffer = UserOrKernelBuffer::for_kernel_buffer(m_scratch_buffer.value().data());
         auto scratch_buffer = UserOrKernelBuffer::for_kernel_buffer(m_scratch_buffer.value().data());
-        auto nreceived_or_error = protocol_receive(packet, scratch_buffer, m_scratch_buffer.value().size(), 0);
+        auto nreceived_or_error = protocol_receive(ReadonlyBytes { packet.data(), packet.size() }, scratch_buffer, m_scratch_buffer.value().size(), 0);
         if (nreceived_or_error.is_error())
         if (nreceived_or_error.is_error())
             return false;
             return false;
         ssize_t nwritten = m_receive_buffer.write(scratch_buffer, nreceived_or_error.value());
         ssize_t nwritten = m_receive_buffer.write(scratch_buffer, nreceived_or_error.value());

+ 1 - 1
Kernel/Net/IPv4Socket.h

@@ -96,7 +96,7 @@ protected:
 
 
     virtual KResult protocol_bind() { return KSuccess; }
     virtual KResult protocol_bind() { return KSuccess; }
     virtual KResult protocol_listen() { return KSuccess; }
     virtual KResult protocol_listen() { return KSuccess; }
-    virtual KResultOr<size_t> protocol_receive(const KBuffer&, UserOrKernelBuffer&, size_t, int) { return -ENOTIMPL; }
+    virtual KResultOr<size_t> protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return -ENOTIMPL; }
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) { return -ENOTIMPL; }
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) { return -ENOTIMPL; }
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
     virtual int protocol_allocate_local_port() { return 0; }
     virtual int protocol_allocate_local_port() { return 0; }

+ 3 - 3
Kernel/Net/TCPSocket.cpp

@@ -167,12 +167,12 @@ NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol)
     return adopt(*new TCPSocket(protocol));
     return adopt(*new TCPSocket(protocol));
 }
 }
 
 
-KResultOr<size_t> TCPSocket::protocol_receive(const KBuffer& packet_buffer, UserOrKernelBuffer& buffer, size_t buffer_size, int flags)
+KResultOr<size_t> TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags)
 {
 {
     (void)flags;
     (void)flags;
-    auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.data());
+    auto& ipv4_packet = *reinterpret_cast<const IPv4Packet*>(raw_ipv4_packet.data());
     auto& tcp_packet = *static_cast<const TCPPacket*>(ipv4_packet.payload());
     auto& tcp_packet = *static_cast<const TCPPacket*>(ipv4_packet.payload());
-    size_t payload_size = packet_buffer.size() - sizeof(IPv4Packet) - tcp_packet.header_size();
+    size_t payload_size = raw_ipv4_packet.size() - sizeof(IPv4Packet) - tcp_packet.header_size();
 #ifdef TCP_SOCKET_DEBUG
 #ifdef TCP_SOCKET_DEBUG
     klog() << "payload_size " << payload_size << ", will it fit in " << buffer_size << "?";
     klog() << "payload_size " << payload_size << ", will it fit in " << buffer_size << "?";
 #endif
 #endif

+ 1 - 1
Kernel/Net/TCPSocket.h

@@ -177,7 +177,7 @@ private:
 
 
     virtual void shut_down_for_writing() override;
     virtual void shut_down_for_writing() override;
 
 
-    virtual KResultOr<size_t> protocol_receive(const KBuffer&, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
+    virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
     virtual int protocol_allocate_local_port() override;
     virtual int protocol_allocate_local_port() override;

+ 2 - 2
Kernel/Net/UDPSocket.cpp

@@ -79,10 +79,10 @@ NonnullRefPtr<UDPSocket> UDPSocket::create(int protocol)
     return adopt(*new UDPSocket(protocol));
     return adopt(*new UDPSocket(protocol));
 }
 }
 
 
-KResultOr<size_t> UDPSocket::protocol_receive(const KBuffer& packet_buffer, UserOrKernelBuffer& buffer, size_t buffer_size, int flags)
+KResultOr<size_t> UDPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags)
 {
 {
     (void)flags;
     (void)flags;
-    auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.data());
+    auto& ipv4_packet = *(const IPv4Packet*)(raw_ipv4_packet.data());
     auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
     auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
     ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier.
     ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier.
     ASSERT(buffer_size >= (udp_packet.length() - sizeof(UDPPacket)));
     ASSERT(buffer_size >= (udp_packet.length() - sizeof(UDPPacket)));

+ 1 - 1
Kernel/Net/UDPSocket.h

@@ -43,7 +43,7 @@ private:
     virtual const char* class_name() const override { return "UDPSocket"; }
     virtual const char* class_name() const override { return "UDPSocket"; }
     static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
     static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
 
 
-    virtual KResultOr<size_t> protocol_receive(const KBuffer&, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
+    virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
     virtual int protocol_allocate_local_port() override;
     virtual int protocol_allocate_local_port() override;