Kernel: Add support for TCP window size scaling

This should allow us to eventually properly saturate high-bandwidth
network links when using TCP, once other nonoptimal parts of our
network stack are improved.
This commit is contained in:
Idan Horowitz 2023-12-26 21:05:15 +02:00 committed by Andreas Kling
parent 2c51ff763b
commit 785c9d5c2b
Notes: sideshowbarker 2024-07-17 03:14:39 +09:00
6 changed files with 124 additions and 14 deletions

View file

@ -38,7 +38,7 @@ MutexProtected<IPv4Socket::List>& IPv4Socket::all_sockets()
ErrorOr<NonnullOwnPtr<DoubleBuffer>> IPv4Socket::try_create_receive_buffer() ErrorOr<NonnullOwnPtr<DoubleBuffer>> IPv4Socket::try_create_receive_buffer()
{ {
return DoubleBuffer::try_create("IPv4Socket: Receive buffer"sv, 256 * KiB); return DoubleBuffer::try_create("IPv4Socket: Receive buffer"sv, receive_buffer_size);
} }
ErrorOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol) ErrorOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol)

View file

@ -68,6 +68,8 @@ public:
BufferMode buffer_mode() const { return m_buffer_mode; } BufferMode buffer_mode() const { return m_buffer_mode; }
protected: protected:
static constexpr size_t receive_buffer_size = 256 * KiB;
IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> optional_scratch_buffer); IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> optional_scratch_buffer);
virtual StringView class_name() const override { return "IPv4Socket"sv; } virtual StringView class_name() const override { return "IPv4Socket"sv; }

View file

@ -430,6 +430,19 @@ void handle_tcp(IPv4Packet const& ipv4_packet, UnixDateTime const& packet_timest
dbgln_if(TCP_DEBUG, "handle_tcp: got socket {}; state={}", socket->tuple().to_string(), TCPSocket::to_string(socket->state())); dbgln_if(TCP_DEBUG, "handle_tcp: got socket {}; state={}", socket->tuple().to_string(), TCPSocket::to_string(socket->state()));
socket->receive_tcp_packet(tcp_packet, ipv4_packet.payload_size()); socket->receive_tcp_packet(tcp_packet, ipv4_packet.payload_size());
Optional<u8> send_window_scale;
if (tcp_packet.has_syn()) {
tcp_packet.for_each_option([&send_window_scale](auto const& option) {
if (option.kind() != TCPOptionKind::WindowScale)
return;
if (option.length() != sizeof(TCPOptionWindowScale))
return;
auto scale = static_cast<TCPOptionWindowScale const&>(option).value();
if (scale > 14)
return; // Maximum allowed as per RFC7323
send_window_scale = scale;
});
}
switch (socket->state()) { switch (socket->state()) {
case TCPSocket::State::Closed: case TCPSocket::State::Closed:
@ -459,6 +472,8 @@ void handle_tcp(IPv4Packet const& ipv4_packet, UnixDateTime const& packet_timest
client->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); client->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
[[maybe_unused]] auto rc2 = client->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK); [[maybe_unused]] auto rc2 = client->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK);
client->set_state(TCPSocket::State::SynReceived); client->set_state(TCPSocket::State::SynReceived);
if (send_window_scale.has_value())
client->set_send_window_scale(*send_window_scale);
return; return;
} }
default: default:
@ -472,6 +487,8 @@ void handle_tcp(IPv4Packet const& ipv4_packet, UnixDateTime const& packet_timest
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
(void)socket->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK); (void)socket->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK);
socket->set_state(TCPSocket::State::SynReceived); socket->set_state(TCPSocket::State::SynReceived);
if (send_window_scale.has_value())
socket->set_send_window_scale(*send_window_scale);
return; return;
case TCPFlags::ACK | TCPFlags::SYN: case TCPFlags::ACK | TCPFlags::SYN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
@ -479,6 +496,8 @@ void handle_tcp(IPv4Packet const& ipv4_packet, UnixDateTime const& packet_timest
socket->set_state(TCPSocket::State::Established); socket->set_state(TCPSocket::State::Established);
socket->set_setup_state(Socket::SetupState::Completed); socket->set_setup_state(Socket::SetupState::Completed);
socket->set_connected(true); socket->set_connected(true);
if (send_window_scale.has_value())
socket->set_send_window_scale(*send_window_scale);
return; return;
case TCPFlags::ACK | TCPFlags::FIN: case TCPFlags::ACK | TCPFlags::FIN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);

View file

@ -21,21 +21,59 @@ struct TCPFlags {
}; };
}; };
class [[gnu::packed]] TCPOptionMSS { enum class TCPOptionKind : u8 {
End = 0,
Nop = 1,
MSS = 2,
WindowScale = 3,
SACKPermitted = 4,
SACK = 5,
Timestamp = 6,
};
class [[gnu::packed]] TCPOption {
public:
TCPOptionKind kind() const { return m_kind; }
u8 length() const { return m_length; }
protected:
TCPOption(TCPOptionKind kind, u8 length)
: m_kind(kind)
, m_length(length) {};
private:
TCPOptionKind m_kind { TCPOptionKind::End };
u8 m_length { sizeof(TCPOption) };
};
class [[gnu::packed]] TCPOptionMSS : public TCPOption {
public: public:
TCPOptionMSS(u16 value) TCPOptionMSS(u16 value)
: m_value(value) : TCPOption(TCPOptionKind::MSS, sizeof(TCPOptionMSS))
, m_value(value)
{ {
} }
u16 value() const { return m_value; } u16 value() const { return m_value; }
private: private:
u8 m_option_kind { 0x02 };
u8 m_option_length { sizeof(TCPOptionMSS) };
NetworkOrdered<u16> m_value; NetworkOrdered<u16> m_value;
}; };
class [[gnu::packed]] TCPOptionWindowScale : public TCPOption {
public:
TCPOptionWindowScale(u8 value)
: TCPOption(TCPOptionKind::WindowScale, sizeof(TCPOptionWindowScale))
, m_value(value)
{
}
u8 value() const { return m_value; }
private:
NetworkOrdered<u8> m_value;
};
static_assert(AssertSize<TCPOptionMSS, 4>()); static_assert(AssertSize<TCPOptionMSS, 4>());
class [[gnu::packed]] TCPPacket { class [[gnu::packed]] TCPPacket {
@ -80,6 +118,28 @@ public:
void const* payload() const { return ((u8 const*)this) + header_size(); } void const* payload() const { return ((u8 const*)this) + header_size(); }
void* payload() { return ((u8*)this) + header_size(); } void* payload() { return ((u8*)this) + header_size(); }
template<typename Callback>
void for_each_option(Callback callback) const
{
auto const* next_option = (u8 const*)this + sizeof(TCPPacket);
auto const* options_end = payload();
while (next_option < options_end) {
if ((size_t)options_end - (size_t)next_option < sizeof(TCPOption))
return; // Not enough space left for another option
auto const* option = (TCPOption const*)next_option;
if (option->kind() == TCPOptionKind::End)
return;
if (option->kind() == TCPOptionKind::Nop) {
next_option += 1;
continue;
}
if (option->length() < sizeof(TCPOption))
return; // minimal option length
callback(*option);
next_option += option->length();
}
}
private: private:
NetworkOrdered<u16> m_source_port; NetworkOrdered<u16> m_source_port;
NetworkOrdered<u16> m_destination_port; NetworkOrdered<u16> m_destination_port;

View file

@ -245,10 +245,11 @@ ErrorOr<void> TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* pa
auto ipv4_payload_offset = routing_decision.adapter->ipv4_payload_offset(); auto ipv4_payload_offset = routing_decision.adapter->ipv4_payload_offset();
bool const has_mss_option = flags == TCPFlags::SYN; bool const has_mss_option = flags & TCPFlags::SYN;
const size_t options_size = has_mss_option ? sizeof(TCPOptionMSS) : 0; bool const has_window_scale_option = flags & TCPFlags::SYN;
const size_t tcp_header_size = sizeof(TCPPacket) + options_size; size_t const options_size = (has_mss_option ? sizeof(TCPOptionMSS) : 0) + (has_window_scale_option ? sizeof(TCPOptionWindowScale) : 0);
const size_t buffer_size = ipv4_payload_offset + tcp_header_size + payload_size; size_t const tcp_header_size = sizeof(TCPPacket) + align_up_to(options_size, 4);
size_t const buffer_size = ipv4_payload_offset + tcp_header_size + payload_size;
auto packet = routing_decision.adapter->acquire_packet_buffer(buffer_size); auto packet = routing_decision.adapter->acquire_packet_buffer(buffer_size);
if (!packet) if (!packet)
return set_so_error(ENOMEM); return set_so_error(ENOMEM);
@ -260,7 +261,10 @@ ErrorOr<void> TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* pa
VERIFY(local_port()); VERIFY(local_port());
tcp_packet.set_source_port(local_port()); tcp_packet.set_source_port(local_port());
tcp_packet.set_destination_port(peer_port()); tcp_packet.set_destination_port(peer_port());
tcp_packet.set_window_size(min(available_space_in_receive_buffer(), NumericLimits<u16>::max())); auto window_size = available_space_in_receive_buffer();
if ((flags & TCPFlags::SYN) == 0 && m_window_scaling_supported)
window_size >>= receive_window_scale();
tcp_packet.set_window_size(min(window_size, NumericLimits<u16>::max()));
tcp_packet.set_sequence_number(m_sequence_number); tcp_packet.set_sequence_number(m_sequence_number);
tcp_packet.set_data_offset(tcp_header_size / sizeof(u32)); tcp_packet.set_data_offset(tcp_header_size / sizeof(u32));
tcp_packet.set_flags(flags); tcp_packet.set_flags(flags);
@ -284,12 +288,20 @@ ErrorOr<void> TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* pa
m_sequence_number += payload_size; m_sequence_number += payload_size;
} }
u8* next_option = packet->buffer->data() + ipv4_payload_offset + sizeof(TCPPacket);
if (has_mss_option) { if (has_mss_option) {
u16 mss = routing_decision.adapter->mtu() - sizeof(IPv4Packet) - sizeof(TCPPacket); u16 mss = routing_decision.adapter->mtu() - sizeof(IPv4Packet) - sizeof(TCPPacket);
TCPOptionMSS mss_option { mss }; TCPOptionMSS mss_option { mss };
VERIFY(packet->buffer->size() >= ipv4_payload_offset + sizeof(TCPPacket) + sizeof(mss_option)); memcpy(next_option, &mss_option, sizeof(mss_option));
memcpy(packet->buffer->data() + ipv4_payload_offset + sizeof(TCPPacket), &mss_option, sizeof(mss_option)); next_option += sizeof(mss_option);
} }
if (has_window_scale_option) {
TCPOptionWindowScale window_scale_option { receive_window_scale() };
memcpy(next_option, &window_scale_option, sizeof(window_scale_option));
next_option += sizeof(window_scale_option);
}
if ((options_size % 4) != 0)
*next_option = to_underlying(TCPOptionKind::End);
tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size)); tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size));
@ -339,7 +351,7 @@ void TCPSocket::receive_tcp_packet(TCPPacket const& packet, u16 size)
old_adapter->release_packet_buffer(*packet.buffer); old_adapter->release_packet_buffer(*packet.buffer);
TCPPacket& tcp_packet = *(TCPPacket*)(packet.buffer->buffer->data() + packet.ipv4_payload_offset); TCPPacket& tcp_packet = *(TCPPacket*)(packet.buffer->buffer->data() + packet.ipv4_payload_offset);
if (m_send_window_size != tcp_packet.window_size()) { if (m_send_window_size != tcp_packet.window_size()) {
m_send_window_size = tcp_packet.window_size(); m_send_window_size = tcp_packet.window_size() << m_send_window_scale;
} }
auto payload_size = packet.buffer->buffer->data() + packet.buffer->buffer->size() - (u8*)tcp_packet.payload(); auto payload_size = packet.buffer->buffer->data() + packet.buffer->buffer->size() - (u8*)tcp_packet.payload();
unacked_packets.size -= payload_size; unacked_packets.size -= payload_size;
@ -367,7 +379,7 @@ void TCPSocket::receive_tcp_packet(TCPPacket const& packet, u16 size)
bool TCPSocket::should_delay_next_ack() const bool TCPSocket::should_delay_next_ack() const
{ {
// FIXME: We don't know the MSS here so make a reasonable guess. // FIXME: We don't know the MSS here so make a reasonable guess.
const size_t mss = 1500; size_t const mss = 1500;
// RFC 1122 says we should send an ACK for every two full-sized segments. // RFC 1122 says we should send an ACK for every two full-sized segments.
if (m_ack_number >= m_last_ack_number_sent + 2 * mss) if (m_ack_number >= m_last_ack_number_sent + 2 * mss)

View file

@ -9,6 +9,7 @@
#include <AK/Error.h> #include <AK/Error.h>
#include <AK/Function.h> #include <AK/Function.h>
#include <AK/HashMap.h> #include <AK/HashMap.h>
#include <AK/IntegralMath.h>
#include <AK/SinglyLinkedList.h> #include <AK/SinglyLinkedList.h>
#include <AK/Time.h> #include <AK/Time.h>
#include <Kernel/Library/LockWeakPtr.h> #include <Kernel/Library/LockWeakPtr.h>
@ -135,6 +136,12 @@ public:
u32 packets_out() const { return m_packets_out; } u32 packets_out() const { return m_packets_out; }
u32 bytes_out() const { return m_bytes_out; } u32 bytes_out() const { return m_bytes_out; }
void set_send_window_scale(size_t scale)
{
m_window_scaling_supported = true;
m_send_window_scale = scale;
}
// FIXME: Make this configurable? // FIXME: Make this configurable?
static constexpr u32 maximum_duplicate_acks = 5; static constexpr u32 maximum_duplicate_acks = 5;
void set_duplicate_acks(u32 acks) { m_duplicate_acks = acks; } void set_duplicate_acks(u32 acks) { m_duplicate_acks = acks; }
@ -188,6 +195,14 @@ private:
void enqueue_for_retransmit(); void enqueue_for_retransmit();
void dequeue_for_retransmit(); void dequeue_for_retransmit();
static constexpr size_t receive_window_scale()
{
auto buffer_size_bit_length = AK::log2(receive_buffer_size) + 1;
if (buffer_size_bit_length < 16)
return 0;
return buffer_size_bit_length - 16;
}
LockWeakPtr<TCPSocket> m_originator; LockWeakPtr<TCPSocket> m_originator;
HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept; HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept;
Direction m_direction { Direction::Unspecified }; Direction m_direction { Direction::Unspecified };
@ -229,6 +244,8 @@ private:
// Default to maximum window size. receive_tcp_packet() will update from the // Default to maximum window size. receive_tcp_packet() will update from the
// peer's advertised window size. // peer's advertised window size.
u32 m_send_window_size { 64 * KiB }; u32 m_send_window_size { 64 * KiB };
bool m_window_scaling_supported { false };
size_t m_send_window_scale { 0 };
bool m_no_delay { false }; bool m_no_delay { false };