From 95bcffd713d49cb52bb74bd9d1e90408fb80fa73 Mon Sep 17 00:00:00 2001 From: Sergey Bugaev Date: Sun, 23 Jul 2023 15:43:45 +0300 Subject: [PATCH] Kernel/Net: Rework ephemeral port allocation Currently, ephemeral port allocation is handled by the allocate_local_port_if_needed() and protocol_allocate_local_port() methods. Actually binding the socket to an address (which means inserting the socket/address pair into a global map) is performed either in protocol_allocate_local_port() (for ephemeral ports) or in protocol_listen() (for non-ephemeral ports); the latter will fail with EADDRINUSE if the address is already used by an existing pair present in the map. There used to be a bug where for listen() without an explicit bind(), the port allocation would conflict with itself: first an ephemeral port would get allocated and inserted into the map, and then protocol_listen() would check again for the port being free, find the just-created map entry, and error out. This was fixed in commit 01e5af487f9513696dbcacab15d3e0036446f586 by passing an additional flag did_allocate_port into protocol_listen() which specifies whether the port was just allocated, and skipping the check in protocol_listen() if the flag is set. However, this only helps if the socket is bound to an ephemeral port inside of this very listen() call. But calling bind(sin_port = 0) from userspace should succeed and bind to an allocated ephemeral port, in the same way as using an unbound socket for connect() does. The port number can then be retrieved from userspace by calling getsockname(), and it should be possible to either connect() or listen() on this socket, keeping the allocated port number. Also, calling bind() when already bound (either explicitly or implicitly) should always result in EINVAL. To untangle this, introduce an explicit m_bound state in IPv4Socket, just like LocalSocket has already. Once a socket is bound, further attempts to bind it fail. 
Some operations cause the socket to implicitly get bound to an (ephemeral) address; this is implemented by the new ensure_bound() method. The protocol_allocate_local_port() method is gone; it is now up to a protocol to assign a port to the socket inside protocol_bind() if it finds that the socket has local_port() == 0. protocol_bind() is now called in more cases, such as inside listen() if the socket wasn't bound before that. --- Kernel/Net/IPv4Socket.cpp | 39 ++++++++++----------- Kernel/Net/IPv4Socket.h | 12 +++---- Kernel/Net/TCPSocket.cpp | 74 +++++++++++++++++++++------------------ Kernel/Net/TCPSocket.h | 3 +- Kernel/Net/UDPSocket.cpp | 63 +++++++++++++++++---------------- Kernel/Net/UDPSocket.h | 1 - 6 files changed, 96 insertions(+), 96 deletions(-) diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 019417f8133..9c80ac6ff05 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -95,8 +95,23 @@ void IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size) *address_size = sizeof(sockaddr_in); } +ErrorOr IPv4Socket::ensure_bound() +{ + dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket::ensure_bound() m_bound {}", m_bound); + if (m_bound) + return {}; + + auto result = protocol_bind(); + if (!result.is_error()) + m_bound = true; + return result; +} + ErrorOr IPv4Socket::bind(Credentials const& credentials, Userspace user_address, socklen_t address_size) { + if (m_bound) + return set_so_error(EINVAL); + VERIFY(setup_state() == SetupState::Unstarted); if (address_size != sizeof(sockaddr_in)) return set_so_error(EINVAL); @@ -120,23 +135,20 @@ ErrorOr IPv4Socket::bind(Credentials const& credentials, Userspace IPv4Socket::listen(size_t backlog) { MutexLocker locker(mutex()); - auto result = allocate_local_port_if_needed(); - if (result.error_or_port.is_error() && result.error_or_port.error().code() != ENOPROTOOPT) - return result.error_or_port.release_error(); - + TRY(ensure_bound()); set_backlog(backlog); 
set_role(Role::Listener); evaluate_block_conditions(); dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) listening with backlog={}", this, backlog); - return protocol_listen(result.did_allocate); + return protocol_listen(); } ErrorOr IPv4Socket::connect(Credentials const&, OpenFileDescription& description, Userspace address, socklen_t address_size) @@ -176,18 +188,6 @@ bool IPv4Socket::can_write(OpenFileDescription const&, u64) const return true; } -PortAllocationResult IPv4Socket::allocate_local_port_if_needed() -{ - MutexLocker locker(mutex()); - if (m_local_port) - return { m_local_port, false }; - auto port_or_error = protocol_allocate_local_port(); - if (port_or_error.is_error()) - return { port_or_error.release_error(), false }; - m_local_port = port_or_error.release_value(); - return { m_local_port, true }; -} - ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer const& data, size_t data_length, [[maybe_unused]] int flags, Userspace addr, socklen_t addr_length) { MutexLocker locker(mutex()); @@ -220,8 +220,7 @@ ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons if (m_local_address.to_u32() == 0) m_local_address = routing_decision.adapter->ipv4_address(); - if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error() && result.error_or_port.error().code() != ENOPROTOOPT) - return result.error_or_port.release_error(); + TRY(ensure_bound()); dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port); diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index aa0b5878a2c..e9ef4eef592 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -21,11 +21,6 @@ class NetworkAdapter; class TCPPacket; class TCPSocket; -struct PortAllocationResult { - ErrorOr error_or_port; - bool did_allocate; -}; - class IPv4Socket : public Socket { public: static ErrorOr> create(int type, int protocol); @@ -76,14 +71,14 @@ protected: IPv4Socket(int type, int protocol, NonnullOwnPtr 
receive_buffer, OwnPtr optional_scratch_buffer); virtual StringView class_name() const override { return "IPv4Socket"sv; } - PortAllocationResult allocate_local_port_if_needed(); + void set_bound(bool bound) { m_bound = bound; } + ErrorOr ensure_bound(); virtual ErrorOr protocol_bind() { return {}; } - virtual ErrorOr protocol_listen([[maybe_unused]] bool did_allocate_port) { return {}; } + virtual ErrorOr protocol_listen() { return {}; } virtual ErrorOr protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return ENOTIMPL; } virtual ErrorOr protocol_send(UserOrKernelBuffer const&, size_t) { return ENOTIMPL; } virtual ErrorOr protocol_connect(OpenFileDescription&) { return {}; } - virtual ErrorOr protocol_allocate_local_port() { return ENOPROTOOPT; } virtual ErrorOr protocol_size(ReadonlyBytes /* raw_ipv4_packet */) { return ENOTIMPL; } virtual bool protocol_is_disconnected() const { return false; } @@ -108,6 +103,7 @@ private: Vector m_multicast_memberships; bool m_multicast_loop { true }; + bool m_bound { false }; struct ReceivedPacket { IPv4Address peer_address; diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 2027f969adc..b0dc7838073 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -137,6 +137,7 @@ ErrorOr> TCPSocket::try_create_client(IPv4Address const client->set_local_port(new_local_port); client->set_peer_address(new_peer_address); client->set_peer_port(new_peer_port); + client->set_bound(true); client->set_direction(Direction::Incoming); client->set_originator(*this); @@ -414,19 +415,46 @@ NetworkOrdered TCPSocket::compute_tcp_checksum(IPv4Address const& source, I ErrorOr TCPSocket::protocol_bind() { - return m_adapter.with([this](auto& adapter) -> ErrorOr { + dbgln_if(TCP_SOCKET_DEBUG, "TCPSocket::protocol_bind(), local_port() is {}", local_port()); + // Check that we do have the address we're trying to bind to. 
+ TRY(m_adapter.with([this](auto& adapter) -> ErrorOr { if (has_specific_local_address() && !adapter) { adapter = NetworkingManagement::the().from_ipv4_address(local_address()); if (!adapter) return set_so_error(EADDRNOTAVAIL); } return {}; - }); -} + })); -ErrorOr TCPSocket::protocol_listen(bool did_allocate_port) -{ - if (!did_allocate_port) { + if (local_port() == 0) { + // Allocate an unused ephemeral port. + constexpr u16 first_ephemeral_port = 32768; + constexpr u16 last_ephemeral_port = 60999; + constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; + u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; + + return sockets_by_tuple().with_exclusive([&](auto& table) -> ErrorOr { + u16 port = first_scan_port; + while (true) { + IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); + + auto it = table.find(proposed_tuple); + if (it == table.end()) { + set_local_port(port); + table.set(proposed_tuple, this); + dbgln_if(TCP_SOCKET_DEBUG, "...allocated port {}, tuple {}", port, proposed_tuple.to_string()); + return {}; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; + } + return set_so_error(EADDRINUSE); + }); + } else { + // Verify that the user-supplied port is not already used by someone else. 
bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool { if (table.contains(tuple())) return false; @@ -435,8 +463,12 @@ ErrorOr TCPSocket::protocol_listen(bool did_allocate_port) }); if (!ok) return set_so_error(EADDRINUSE); + return {}; } +} +ErrorOr TCPSocket::protocol_listen() +{ set_direction(Direction::Passive); set_state(State::Listen); set_setup_state(SetupState::Completed); @@ -453,8 +485,7 @@ ErrorOr TCPSocket::protocol_connect(OpenFileDescription& description) if (!has_specific_local_address()) set_local_address(routing_decision.adapter->ipv4_address()); - if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error()) - return result.error_or_port.release_error(); + TRY(ensure_bound()); m_sequence_number = get_good_random(); m_ack_number = 0; @@ -487,33 +518,6 @@ ErrorOr TCPSocket::protocol_connect(OpenFileDescription& description) return set_so_error(EINPROGRESS); } -ErrorOr TCPSocket::protocol_allocate_local_port() -{ - constexpr u16 first_ephemeral_port = 32768; - constexpr u16 last_ephemeral_port = 60999; - constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; - u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - - return sockets_by_tuple().with_exclusive([&](auto& table) -> ErrorOr { - for (u16 port = first_scan_port;;) { - IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); - - auto it = table.find(proposed_tuple); - if (it == table.end()) { - set_local_port(port); - table.set(proposed_tuple, this); - return port; - } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return set_so_error(EADDRINUSE); - }); -} - bool TCPSocket::protocol_is_disconnected() const { switch (m_state) { diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index a2d58df05e6..d018db3d96b 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ 
-176,11 +176,10 @@ private: virtual ErrorOr protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override; virtual ErrorOr protocol_send(UserOrKernelBuffer const&, size_t) override; virtual ErrorOr protocol_connect(OpenFileDescription&) override; - virtual ErrorOr protocol_allocate_local_port() override; virtual ErrorOr protocol_size(ReadonlyBytes raw_ipv4_packet) override; virtual bool protocol_is_disconnected() const override; virtual ErrorOr protocol_bind() override; - virtual ErrorOr protocol_listen(bool did_allocate_port) override; + virtual ErrorOr protocol_listen() override; void enqueue_for_retransmit(); void dequeue_for_retransmit(); diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index a9c1c46eefc..5cb5814cce1 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -108,44 +108,47 @@ ErrorOr UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t ErrorOr UDPSocket::protocol_connect(OpenFileDescription&) { + TRY(ensure_bound()); set_role(Role::Connected); set_connected(true); return {}; } -ErrorOr UDPSocket::protocol_allocate_local_port() -{ - constexpr u16 first_ephemeral_port = 32768; - constexpr u16 last_ephemeral_port = 60999; - constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; - u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - - return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { - for (u16 port = first_scan_port;;) { - auto it = table.find(port); - if (it == table.end()) { - set_local_port(port); - table.set(port, this); - return port; - } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return set_so_error(EADDRINUSE); - }); -} - ErrorOr UDPSocket::protocol_bind() { - return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { - if (table.contains(local_port())) + if 
(local_port() == 0) { + // Allocate an unused ephemeral port. + constexpr u16 first_ephemeral_port = 32768; + constexpr u16 last_ephemeral_port = 60999; + constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; + u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; + + return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { + u16 port = first_scan_port; + while (true) { + auto it = table.find(port); + if (it == table.end()) { + set_local_port(port); + table.set(port, this); + return {}; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; + } return set_so_error(EADDRINUSE); - table.set(local_port(), this); - return {}; - }); + }); + } else { + // Verify that the user-supplied port is not already used by someone else. + return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { + if (table.contains(local_port())) + return set_so_error(EADDRINUSE); + table.set(local_port(), this); + return {}; + }); + } } } diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index d904f97e247..32b5b05334c 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -30,7 +30,6 @@ private: virtual ErrorOr protocol_send(UserOrKernelBuffer const&, size_t) override; virtual ErrorOr protocol_size(ReadonlyBytes raw_ipv4_packet) override; virtual ErrorOr protocol_connect(OpenFileDescription&) override; - virtual ErrorOr protocol_allocate_local_port() override; virtual ErrorOr protocol_bind() override; };