Browse Source

Kernel: Dont try to register ephemeral TCP ports twice

stelar7 4 years ago
parent
commit
01e5af487f
4 changed files with 28 additions and 19 deletions
  1. 10 9
      Kernel/Net/IPv4Socket.cpp
  2. 7 2
      Kernel/Net/IPv4Socket.h
  3. 10 7
      Kernel/Net/TCPSocket.cpp
  4. 1 1
      Kernel/Net/TCPSocket.h

+ 10 - 9
Kernel/Net/IPv4Socket.cpp

@@ -121,8 +121,9 @@ KResult IPv4Socket::bind(Userspace<const sockaddr*> user_address, socklen_t addr
 KResult IPv4Socket::listen(size_t backlog)
 KResult IPv4Socket::listen(size_t backlog)
 {
 {
     Locker locker(lock());
     Locker locker(lock());
-    if (auto result = allocate_local_port_if_needed(); result.is_error() && result.error() != -ENOPROTOOPT)
-        return result.error();
+    auto result = allocate_local_port_if_needed();
+    if (result.error_or_port.is_error() && result.error_or_port.error() != -ENOPROTOOPT)
+        return result.error_or_port.error();
 
 
     set_backlog(backlog);
     set_backlog(backlog);
     m_role = Role::Listener;
     m_role = Role::Listener;
@@ -130,7 +131,7 @@ KResult IPv4Socket::listen(size_t backlog)
 
 
     dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) listening with backlog={}", this, backlog);
     dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) listening with backlog={}", this, backlog);
 
 
-    return protocol_listen();
+    return protocol_listen(result.did_allocate);
 }
 }
 
 
 KResult IPv4Socket::connect(FileDescription& description, Userspace<const sockaddr*> address, socklen_t address_size, ShouldBlock should_block)
 KResult IPv4Socket::connect(FileDescription& description, Userspace<const sockaddr*> address, socklen_t address_size, ShouldBlock should_block)
@@ -172,16 +173,16 @@ bool IPv4Socket::can_write(const FileDescription&, size_t) const
     return is_connected();
     return is_connected();
 }
 }
 
 
-KResultOr<u16> IPv4Socket::allocate_local_port_if_needed()
+PortAllocationResult IPv4Socket::allocate_local_port_if_needed()
 {
 {
     Locker locker(lock());
     Locker locker(lock());
     if (m_local_port)
     if (m_local_port)
-        return m_local_port;
+        return { m_local_port, false };
     auto port_or_error = protocol_allocate_local_port();
     auto port_or_error = protocol_allocate_local_port();
     if (port_or_error.is_error())
     if (port_or_error.is_error())
-        return port_or_error.error();
+        return { port_or_error.error(), false };
     m_local_port = port_or_error.value();
     m_local_port = port_or_error.value();
-    return port_or_error.value();
+    return { m_local_port, true };
 }
 }
 
 
 KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer& data, size_t data_length, [[maybe_unused]] int flags, Userspace<const sockaddr*> addr, socklen_t addr_length)
 KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer& data, size_t data_length, [[maybe_unused]] int flags, Userspace<const sockaddr*> addr, socklen_t addr_length)
@@ -212,8 +213,8 @@ KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer&
     if (m_local_address.to_u32() == 0)
     if (m_local_address.to_u32() == 0)
         m_local_address = routing_decision.adapter->ipv4_address();
         m_local_address = routing_decision.adapter->ipv4_address();
 
 
-    if (auto result = allocate_local_port_if_needed(); result.is_error() && result.error() != -ENOPROTOOPT)
-        return result.error();
+    if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error() && result.error_or_port.error() != -ENOPROTOOPT)
+        return result.error_or_port.error();
 
 
     dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port);
     dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port);
 
 

+ 7 - 2
Kernel/Net/IPv4Socket.h

@@ -21,6 +21,11 @@ class NetworkAdapter;
 class TCPPacket;
 class TCPPacket;
 class TCPSocket;
 class TCPSocket;
 
 
+struct PortAllocationResult {
+    KResultOr<u16> error_or_port;
+    bool did_allocate;
+};
+
 class IPv4Socket : public Socket {
 class IPv4Socket : public Socket {
 public:
 public:
     static KResultOr<NonnullRefPtr<Socket>> create(int type, int protocol);
     static KResultOr<NonnullRefPtr<Socket>> create(int type, int protocol);
@@ -72,10 +77,10 @@ protected:
     IPv4Socket(int type, int protocol);
     IPv4Socket(int type, int protocol);
     virtual const char* class_name() const override { return "IPv4Socket"; }
     virtual const char* class_name() const override { return "IPv4Socket"; }
 
 
-    KResultOr<u16> allocate_local_port_if_needed();
+    PortAllocationResult allocate_local_port_if_needed();
 
 
     virtual KResult protocol_bind() { return KSuccess; }
     virtual KResult protocol_bind() { return KSuccess; }
-    virtual KResult protocol_listen() { return KSuccess; }
+    virtual KResult protocol_listen([[maybe_unused]] bool did_allocate_port) { return KSuccess; }
     virtual KResultOr<size_t> protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return ENOTIMPL; }
     virtual KResultOr<size_t> protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return ENOTIMPL; }
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) { return ENOTIMPL; }
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) { return ENOTIMPL; }
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }

+ 10 - 7
Kernel/Net/TCPSocket.cpp

@@ -365,12 +365,15 @@ KResult TCPSocket::protocol_bind()
     return KSuccess;
     return KSuccess;
 }
 }
 
 
-KResult TCPSocket::protocol_listen()
+KResult TCPSocket::protocol_listen(bool did_allocate_port)
 {
 {
-    Locker locker(sockets_by_tuple().lock());
-    if (sockets_by_tuple().resource().contains(tuple()))
-        return EADDRINUSE;
-    sockets_by_tuple().resource().set(tuple(), this);
+    if (!did_allocate_port) {
+        Locker socket_locker(sockets_by_tuple().lock());
+        if (sockets_by_tuple().resource().contains(tuple()))
+            return EADDRINUSE;
+        sockets_by_tuple().resource().set(tuple(), this);
+    }
+
     set_direction(Direction::Passive);
     set_direction(Direction::Passive);
     set_state(State::Listen);
     set_state(State::Listen);
     set_setup_state(SetupState::Completed);
     set_setup_state(SetupState::Completed);
@@ -387,8 +390,8 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
     if (!has_specific_local_address())
     if (!has_specific_local_address())
         set_local_address(routing_decision.adapter->ipv4_address());
         set_local_address(routing_decision.adapter->ipv4_address());
 
 
-    if (auto result = allocate_local_port_if_needed(); result.is_error())
-        return result.error();
+    if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error())
+        return result.error_or_port.error();
 
 
     m_sequence_number = get_good_random<u32>();
     m_sequence_number = get_good_random<u32>();
     m_ack_number = 0;
     m_ack_number = 0;

+ 1 - 1
Kernel/Net/TCPSocket.h

@@ -176,7 +176,7 @@ private:
     virtual KResultOr<u16> protocol_allocate_local_port() override;
     virtual KResultOr<u16> protocol_allocate_local_port() override;
     virtual bool protocol_is_disconnected() const override;
     virtual bool protocol_is_disconnected() const override;
     virtual KResult protocol_bind() override;
     virtual KResult protocol_bind() override;
-    virtual KResult protocol_listen() override;
+    virtual KResult protocol_listen(bool did_allocate_port) override;
 
 
     void enqueue_for_retransmit();
     void enqueue_for_retransmit();
     void dequeue_for_retransmit();
     void dequeue_for_retransmit();