ソースを参照

Kernel: Implement TCP listening sockets and incoming connections

Conrad Pankoff 6 年 前
コミット
3eb659a2bb
3 ファイル変更89 行追加15 行削除
  1. 19 4
      Kernel/Net/NetworkTask.cpp
  2. 39 11
      Kernel/Net/TCPSocket.cpp
  3. 31 0
      Kernel/Net/TCPSocket.h

+ 19 - 4
Kernel/Net/NetworkTask.cpp

@@ -353,10 +353,24 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
         return;
     case TCPSocket::State::Listen:
         switch (tcp_packet.flags()) {
-        case TCPFlags::SYN:
-            kprintf("handle_tcp: incoming connections not supported\n");
-            // socket->send_tcp_packet(TCPFlags::RST);
+        case TCPFlags::SYN: {
+            kprintf("handle_tcp: incoming connection\n");
+            auto& local_address = ipv4_packet.destination();
+            auto& peer_address = ipv4_packet.source();
+            auto client = socket->create_client(local_address, tcp_packet.destination_port(), peer_address, tcp_packet.source_port());
+            if (!client) {
+                kprintf("handle_tcp: couldn't create client socket\n");
+                return;
+            }
+            kprintf("handle_tcp: created new client socket with tuple %s\n", client->tuple().to_string().characters());
+            client->set_sequence_number(1000);
+            client->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
+            client->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK);
+            client->set_sequence_number(1001);
+            client->set_state(TCPSocket::State::SynReceived);
+            kprintf("handle_tcp: Closed -> SynReceived\n");
             return;
+        }
         default:
             kprintf("handle_tcp: unexpected flags in Listen state\n");
             // socket->send_tcp_packet(TCPFlags::RST);
@@ -389,7 +403,8 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
         case TCPFlags::ACK:
             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
             socket->set_state(TCPSocket::State::Established);
-            socket->set_connected(true);
+            if (socket->direction() == TCPSocket::Direction::Outgoing)
+                socket->set_connected(true);
             kprintf("handle_tcp: SynReceived -> Established\n");
             return;
         default:

+ 39 - 11
Kernel/Net/TCPSocket.cpp

@@ -25,16 +25,18 @@ Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple()
 
 TCPSocketHandle TCPSocket::from_tuple(const IPv4SocketTuple& tuple)
 {
-    RefPtr<TCPSocket> socket;
-    {
-        LOCKER(sockets_by_tuple().lock());
-        auto it = sockets_by_tuple().resource().find(tuple);
-        if (it == sockets_by_tuple().resource().end())
-            return {};
-        socket = (*it).value;
-        ASSERT(socket);
-    }
-    return { move(socket) };
+    LOCKER(sockets_by_tuple().lock());
+
+    auto exact_match = sockets_by_tuple().resource().get(tuple);
+    if (exact_match.has_value())
+        return { move(exact_match.value()) };
+
+    auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0);
+    auto address_match = sockets_by_tuple().resource().get(address_tuple);
+    if (address_match.has_value())
+        return { move(address_match.value()) };
+
+    return {};
 }
 
 TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port)
@@ -42,6 +44,29 @@ TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16
     return from_tuple(IPv4SocketTuple(local_address, local_port, peer_address, peer_port));
 }
 
+TCPSocketHandle TCPSocket::create_client(const IPv4Address& new_local_address, u16 new_local_port, const IPv4Address& new_peer_address, u16 new_peer_port)
+{
+    auto tuple = IPv4SocketTuple(new_local_address, new_local_port, new_peer_address, new_peer_port);
+
+    LOCKER(sockets_by_tuple().lock());
+    if (sockets_by_tuple().resource().contains(tuple))
+        return {};
+
+    auto client = TCPSocket::create(protocol());
+
+    client->set_local_address(new_local_address);
+    client->set_local_port(new_local_port);
+    client->set_peer_address(new_peer_address);
+    client->set_peer_port(new_peer_port);
+    client->set_direction(Direction::Incoming);
+
+    queue_connection_from(client);
+
+    sockets_by_tuple().resource().set(tuple, client);
+
+    return from_tuple(tuple);
+}
+
 TCPSocket::TCPSocket(int protocol)
     : IPv4Socket(SOCK_STREAM, protocol)
 {
@@ -104,7 +129,7 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size
     if (flags & TCPFlags::ACK)
         tcp_packet.set_ack_number(m_ack_number);
 
-    if (flags == TCPFlags::SYN) {
+    if (flags & TCPFlags::SYN) {
         ++m_sequence_number;
     } else {
         m_sequence_number += payload_size;
@@ -196,7 +221,9 @@ KResult TCPSocket::protocol_listen()
     if (sockets_by_tuple().resource().contains(tuple()))
         return KResult(-EADDRINUSE);
     sockets_by_tuple().resource().set(tuple(), this);
+    set_direction(Direction::Passive);
     set_state(State::Listen);
+    set_connected(true);
     return KSuccess;
 }
 
@@ -217,6 +244,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
 
     send_tcp_packet(TCPFlags::SYN);
     m_state = State::SynSent;
+    m_direction = Direction::Outgoing;
 
     if (should_block == ShouldBlock::Yes) {
         if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal)

+ 31 - 0
Kernel/Net/TCPSocket.h

@@ -10,6 +10,29 @@ public:
     static NonnullRefPtr<TCPSocket> create(int protocol);
     virtual ~TCPSocket() override;
 
+    enum class Direction {
+        Unspecified,
+        Outgoing,
+        Incoming,
+        Passive,
+    };
+
+    static const char* to_string(Direction direction)
+    {
+        switch (direction) {
+        case Direction::Unspecified:
+            return "Unspecified";
+        case Direction::Outgoing:
+            return "Outgoing";
+        case Direction::Incoming:
+            return "Incoming";
+        case Direction::Passive:
+            return "Passive";
+        default:
+            return "None";
+        }
+    }
+
     enum class State {
         Closed,
         Listen,
@@ -57,6 +80,8 @@ public:
     State state() const { return m_state; }
     void set_state(State state) { m_state = state; }
 
+    Direction direction() const { return m_direction; }
+
     void set_ack_number(u32 n) { m_ack_number = n; }
     void set_sequence_number(u32 n) { m_sequence_number = n; }
     u32 ack_number() const { return m_ack_number; }
@@ -73,6 +98,11 @@ public:
     static TCPSocketHandle from_tuple(const IPv4SocketTuple& tuple);
     static TCPSocketHandle from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
 
+    TCPSocketHandle create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
+
+protected:
+    void set_direction(Direction direction) { m_direction = direction; }
+
 private:
     explicit TCPSocket(int protocol);
     virtual const char* class_name() const override { return "TCPSocket"; }
@@ -87,6 +117,7 @@ private:
     virtual KResult protocol_bind() override;
     virtual KResult protocol_listen() override;
 
+    Direction m_direction { Direction::Unspecified };
     WeakPtr<NetworkAdapter> m_adapter;
     u32 m_sequence_number { 0 };
     u32 m_ack_number { 0 };