Ver Fonte

Kernel/IPv4: Propagate errors from local port allocation

Remove hacks and assumptions and make the EADDRINUSE propagate all
the way from the point of failure to the syscall layer.
Andreas Kling há 4 anos atrás
pai
commit
71a10eb8e7

+ 11 - 12
Kernel/Net/IPv4Socket.cpp

@@ -109,9 +109,8 @@ KResult IPv4Socket::bind(Userspace<const sockaddr*> user_address, socklen_t addr
 KResult IPv4Socket::listen(size_t backlog)
 {
     Locker locker(lock());
-    int rc = allocate_local_port_if_needed();
-    if (rc < 0)
-        return EADDRINUSE;
+    if (auto result = allocate_local_port_if_needed(); result.is_error())
+        return result.error();
 
     set_backlog(backlog);
     m_role = Role::Listener;
@@ -159,15 +158,16 @@ bool IPv4Socket::can_write(const FileDescription&, size_t) const
     return is_connected();
 }
 
-int IPv4Socket::allocate_local_port_if_needed()
+KResultOr<u16> IPv4Socket::allocate_local_port_if_needed()
 {
+    Locker locker(lock());
     if (m_local_port)
         return m_local_port;
-    int port = protocol_allocate_local_port();
-    if (port < 0)
-        return port;
-    m_local_port = (u16)port;
-    return port;
+    auto port_or_error = protocol_allocate_local_port();
+    if (port_or_error.is_error())
+        return port_or_error.error();
+    m_local_port = port_or_error.value();
+    return port_or_error.value();
 }
 
 KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer& data, size_t data_length, [[maybe_unused]] int flags, Userspace<const sockaddr*> addr, socklen_t addr_length)
@@ -198,9 +198,8 @@ KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer&
     if (m_local_address.to_u32() == 0)
         m_local_address = routing_decision.adapter->ipv4_address();
 
-    int rc = allocate_local_port_if_needed();
-    if (rc < 0)
-        return rc;
+    if (auto result = allocate_local_port_if_needed(); result.is_error())
+        return result.error();
 
     dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port);
 

+ 2 - 2
Kernel/Net/IPv4Socket.h

@@ -70,14 +70,14 @@ protected:
     IPv4Socket(int type, int protocol);
     virtual const char* class_name() const override { return "IPv4Socket"; }
 
-    int allocate_local_port_if_needed();
+    KResultOr<u16> allocate_local_port_if_needed();
 
     virtual KResult protocol_bind() { return KSuccess; }
     virtual KResult protocol_listen() { return KSuccess; }
     virtual KResultOr<size_t> protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return -ENOTIMPL; }
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) { return -ENOTIMPL; }
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
-    virtual int protocol_allocate_local_port() { return 0; }
+    virtual KResultOr<u16> protocol_allocate_local_port() { return ENOPROTOOPT; }
     virtual bool protocol_is_disconnected() const { return false; }
 
     virtual void shut_down_for_reading() override;

+ 4 - 3
Kernel/Net/TCPSocket.cpp

@@ -369,7 +369,8 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
     if (!has_specific_local_address())
         set_local_address(routing_decision.adapter->ipv4_address());
 
-    allocate_local_port_if_needed();
+    if (auto result = allocate_local_port_if_needed(); result.is_error())
+        return result.error();
 
     m_sequence_number = get_good_random<u32>();
     m_ack_number = 0;
@@ -401,7 +402,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
     return EINPROGRESS;
 }
 
-int TCPSocket::protocol_allocate_local_port()
+KResultOr<u16> TCPSocket::protocol_allocate_local_port()
 {
     static const u16 first_ephemeral_port = 32768;
     static const u16 last_ephemeral_port = 60999;
@@ -424,7 +425,7 @@ int TCPSocket::protocol_allocate_local_port()
         if (port == first_scan_port)
             break;
     }
-    return -EADDRINUSE;
+    return EADDRINUSE;
 }
 
 bool TCPSocket::protocol_is_disconnected() const

+ 1 - 1
Kernel/Net/TCPSocket.h

@@ -160,7 +160,7 @@ private:
     virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
-    virtual int protocol_allocate_local_port() override;
+    virtual KResultOr<u16> protocol_allocate_local_port() override;
     virtual bool protocol_is_disconnected() const override;
     virtual KResult protocol_bind() override;
     virtual KResult protocol_listen() override;

+ 2 - 2
Kernel/Net/UDPSocket.cpp

@@ -97,7 +97,7 @@ KResult UDPSocket::protocol_connect(FileDescription&, ShouldBlock)
     return KSuccess;
 }
 
-int UDPSocket::protocol_allocate_local_port()
+KResultOr<u16> UDPSocket::protocol_allocate_local_port()
 {
     static const u16 first_ephemeral_port = 32768;
     static const u16 last_ephemeral_port = 60999;
@@ -118,7 +118,7 @@ int UDPSocket::protocol_allocate_local_port()
         if (port == first_scan_port)
             break;
     }
-    return -EADDRINUSE;
+    return EADDRINUSE;
 }
 
 KResult UDPSocket::protocol_bind()

+ 1 - 1
Kernel/Net/UDPSocket.h

@@ -26,7 +26,7 @@ private:
     virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
-    virtual int protocol_allocate_local_port() override;
+    virtual KResultOr<u16> protocol_allocate_local_port() override;
     virtual KResult protocol_bind() override;
 };