Explorar el Código

Kernel: Handle OOM when allocating IPv4Socket optional scratch buffer

Brian Gianforcaro hace 4 años
padre
commit
c1a0e379e6

+ 6 - 5
Kernel/Net/IPv4Socket.cpp

@@ -59,7 +59,7 @@ KResultOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol)
         return udp_socket.release_value();
     }
     if (type == SOCK_RAW) {
-        auto raw_socket = adopt_ref_if_nonnull(new (nothrow) IPv4Socket(type, protocol, receive_buffer.release_nonnull()));
+        auto raw_socket = adopt_ref_if_nonnull(new (nothrow) IPv4Socket(type, protocol, receive_buffer.release_nonnull(), {}));
         if (raw_socket)
             return raw_socket.release_nonnull();
         return ENOMEM;
@@ -67,14 +67,15 @@ KResultOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol)
     return EINVAL;
 }
 
-IPv4Socket::IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
+IPv4Socket::IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> optional_scratch_buffer)
     : Socket(AF_INET, type, protocol)
     , m_receive_buffer(move(receive_buffer))
+    , m_scratch_buffer(move(optional_scratch_buffer))
 {
     dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) created with type={}, protocol={}", this, type, protocol);
     m_buffer_mode = type == SOCK_STREAM ? BufferMode::Bytes : BufferMode::Packets;
     if (m_buffer_mode == BufferMode::Bytes) {
-        m_scratch_buffer = KBuffer::create_with_size(65536);
+        VERIFY(m_scratch_buffer);
     }
     MutexLocker locker(all_sockets().lock());
     all_sockets().resource().set(this);
@@ -422,8 +423,8 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port,
             VERIFY(m_can_read);
             return false;
         }
-        auto scratch_buffer = UserOrKernelBuffer::for_kernel_buffer(m_scratch_buffer.value().data());
-        auto nreceived_or_error = protocol_receive(ReadonlyBytes { packet.data(), packet.size() }, scratch_buffer, m_scratch_buffer.value().size(), 0);
+        auto scratch_buffer = UserOrKernelBuffer::for_kernel_buffer(m_scratch_buffer->data());
+        auto nreceived_or_error = protocol_receive(packet, scratch_buffer, m_scratch_buffer->size(), 0);
         if (nreceived_or_error.is_error())
             return false;
         auto nwritten_or_error = m_receive_buffer->write(scratch_buffer, nreceived_or_error.value());

+ 2 - 2
Kernel/Net/IPv4Socket.h

@@ -74,7 +74,7 @@ public:
     BufferMode buffer_mode() const { return m_buffer_mode; }
 
 protected:
-    IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
+    IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> optional_scratch_buffer);
     virtual StringView class_name() const override { return "IPv4Socket"; }
 
     PortAllocationResult allocate_local_port_if_needed();
@@ -130,7 +130,7 @@ private:
 
     BufferMode m_buffer_mode { BufferMode::Packets };
 
-    Optional<KBuffer> m_scratch_buffer;
+    OwnPtr<KBuffer> m_scratch_buffer;
 };
 
 }

+ 8 - 3
Kernel/Net/TCPSocket.cpp

@@ -134,8 +134,8 @@ void TCPSocket::release_for_accept(RefPtr<TCPSocket> socket)
     [[maybe_unused]] auto rc = queue_connection_from(*socket);
 }
 
-TCPSocket::TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
-    : IPv4Socket(SOCK_STREAM, protocol, move(receive_buffer))
+TCPSocket::TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> scratch_buffer)
+    : IPv4Socket(SOCK_STREAM, protocol, move(receive_buffer), move(scratch_buffer))
 {
     m_last_retransmit_time = kgettimeofday();
 }
@@ -152,7 +152,12 @@ TCPSocket::~TCPSocket()
 
 KResultOr<NonnullRefPtr<TCPSocket>> TCPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
 {
-    auto socket = adopt_ref_if_nonnull(new (nothrow) TCPSocket(protocol, move(receive_buffer)));
+    // Note: Scratch buffer is only used for SOCK_STREAM sockets.
+    auto scratch_buffer = KBuffer::try_create_with_size(65536);
+    if (!scratch_buffer)
+        return ENOMEM;
+
+    auto socket = adopt_ref_if_nonnull(new (nothrow) TCPSocket(protocol, move(receive_buffer), move(scratch_buffer)));
     if (socket)
         return socket.release_nonnull();
     return ENOMEM;

+ 1 - 1
Kernel/Net/TCPSocket.h

@@ -165,7 +165,7 @@ protected:
     void set_direction(Direction direction) { m_direction = direction; }
 
 private:
-    explicit TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
+    explicit TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> scratch_buffer);
     virtual StringView class_name() const override { return "TCPSocket"; }
 
     virtual void shut_down_for_writing() override;

+ 1 - 1
Kernel/Net/UDPSocket.cpp

@@ -44,7 +44,7 @@ SocketHandle<UDPSocket> UDPSocket::from_port(u16 port)
 }
 
 UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
-    : IPv4Socket(SOCK_DGRAM, protocol, move(receive_buffer))
+    : IPv4Socket(SOCK_DGRAM, protocol, move(receive_buffer), {})
 {
 }