浏览代码

Kernel: Make TCPSocket::create API OOM safe

Note that the changes to IPv4Socket::create are unfortunately needed as
the return type of TCPSocket::create and IPv4Socket::create don't match.

 - KResultOr<NonnullRefPtr<TcpSocket>>>
   vs
 - KResultOr<NonnullRefPtr<Socket>>>

To handle this we are forced to manually decompose the KResultOr<T> and
return the value() and error() separately.
Brian Gianforcaro 4 年之前
父节点
当前提交
46ce7adf7b
共有 3 个文件被更改,包括 17 次插入6 次删除
  1. 6 2
      Kernel/Net/IPv4Socket.cpp
  2. 9 3
      Kernel/Net/TCPSocket.cpp
  3. 2 1
      Kernel/Net/TCPSocket.h

+ 6 - 2
Kernel/Net/IPv4Socket.cpp

@@ -36,8 +36,12 @@ Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
 
 KResultOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol)
 {
-    if (type == SOCK_STREAM)
-        return TCPSocket::create(protocol);
+    if (type == SOCK_STREAM) {
+        auto tcp_socket = TCPSocket::create(protocol);
+        if (tcp_socket.is_error())
+            return tcp_socket.error();
+        return tcp_socket.release_value();
+    }
     if (type == SOCK_DGRAM)
         return UDPSocket::create(protocol);
     if (type == SOCK_RAW)

+ 9 - 3
Kernel/Net/TCPSocket.cpp

@@ -98,8 +98,11 @@ RefPtr<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_address,
     if (sockets_by_tuple().resource().contains(tuple))
         return {};
 
-    auto client = TCPSocket::create(protocol());
+    auto result = TCPSocket::create(protocol());
+    if (result.is_error())
+        return {};
 
+    auto client = result.release_value();
     client->set_setup_state(SetupState::InProgress);
     client->set_local_address(new_local_address);
     client->set_local_port(new_local_port);
@@ -142,9 +145,12 @@ TCPSocket::~TCPSocket()
     dbgln_if(TCP_SOCKET_DEBUG, "~TCPSocket in state {}", to_string(state()));
 }
 
-NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol)
+KResultOr<NonnullRefPtr<TCPSocket>> TCPSocket::create(int protocol)
 {
-    return adopt_ref(*new TCPSocket(protocol));
+    auto socket = adopt_ref_if_nonnull(new TCPSocket(protocol));
+    if (socket)
+        return socket.release_nonnull();
+    return ENOMEM;
 }
 
 KResultOr<size_t> TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, [[maybe_unused]] int flags)

+ 2 - 1
Kernel/Net/TCPSocket.h

@@ -10,6 +10,7 @@
 #include <AK/HashMap.h>
 #include <AK/SinglyLinkedList.h>
 #include <AK/WeakPtr.h>
+#include <Kernel/KResult.h>
 #include <Kernel/Net/IPv4Socket.h>
 
 namespace Kernel {
@@ -17,7 +18,7 @@ namespace Kernel {
 class TCPSocket final : public IPv4Socket {
 public:
     static void for_each(Function<void(const TCPSocket&)>);
-    static NonnullRefPtr<TCPSocket> create(int protocol);
+    static KResultOr<NonnullRefPtr<TCPSocket>> create(int protocol);
     virtual ~TCPSocket() override;
 
     enum class Direction {