浏览代码

TCP: Start working on auto-closing connections when we get FIN.

Andreas Kling 6 年之前
父节点
当前提交
25e521f510
共有 6 个文件被更改,包括 35 次插入6 次删除
  1. 18 5
      Kernel/IPv4Socket.cpp
  2. 1 0
      Kernel/IPv4Socket.h
  3. 9 1
      Kernel/NetworkTask.cpp
  4. 1 0
      Kernel/TCP.h
  5. 5 0
      Kernel/TCPSocket.cpp
  6. 1 0
      Kernel/TCPSocket.h

+ 18 - 5
Kernel/IPv4Socket.cpp

@@ -92,22 +92,24 @@ void IPv4Socket::detach_fd(SocketRole)
 
 bool IPv4Socket::can_read(SocketRole) const
 {
+    if (protocol_is_disconnected())
+        return true;
     return m_can_read;
 }
 
-ssize_t IPv4Socket::read(SocketRole, byte*, ssize_t)
+ssize_t IPv4Socket::read(SocketRole, byte* buffer, ssize_t size)
 {
-    ASSERT_NOT_REACHED();
+    return recvfrom(buffer, size, 0, nullptr, 0);
 }
 
-ssize_t IPv4Socket::write(SocketRole, const byte*, ssize_t)
+ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size)
 {
-    ASSERT_NOT_REACHED();
+    return sendto(data, size, 0, nullptr, 0);
 }
 
 bool IPv4Socket::can_write(SocketRole) const
 {
-    ASSERT_NOT_REACHED();
+    return true;
 }
 
 void IPv4Socket::allocate_source_port_if_needed()
@@ -168,9 +170,17 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock
         if (!m_receive_queue.is_empty()) {
             packet_buffer = m_receive_queue.take_first();
             m_can_read = !m_receive_queue.is_empty();
+#ifdef IPV4_SOCKET_DEBUG
+            kprintf("IPv4Socket(%p): recvfrom without blocking %d bytes, packets in queue: %d\n", this, packet_buffer.size(), m_receive_queue.size_slow());
+#endif
         }
     }
     if (packet_buffer.is_null()) {
+        if (protocol_is_disconnected()) {
+            kprintf("IPv4Socket{%p} is protocol-disconnected, returning 0 in recvfrom!\n", this);
+            return 0;
+        }
+
         current->set_blocked_socket(this);
         load_receive_deadline();
         block(Process::BlockedReceive);
@@ -185,6 +195,9 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock
         ASSERT(!m_receive_queue.is_empty());
         packet_buffer = m_receive_queue.take_first();
         m_can_read = !m_receive_queue.is_empty();
+#ifdef IPV4_SOCKET_DEBUG
+        kprintf("IPv4Socket(%p): recvfrom with blocking %d bytes, packets in queue: %d\n", this, packet_buffer.size(), m_receive_queue.size_slow());
+#endif
     }
     ASSERT(!packet_buffer.is_null());
     auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer());

+ 1 - 0
Kernel/IPv4Socket.h

@@ -51,6 +51,7 @@ protected:
     virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
     virtual KResult protocol_connect() { return KSuccess; }
     virtual void protocol_allocate_source_port() { }
+    virtual bool protocol_is_disconnected() const { return false; }
 
 private:
     virtual bool is_ipv4() const override { return true; }

+ 9 - 1
Kernel/NetworkTask.cpp

@@ -287,7 +287,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
     ASSERT(socket->source_port() == tcp_packet.destination_port());
 
     if (tcp_packet.ack_number() != socket->sequence_number()) {
-        kprintf("handle_tcp: ack/seq mismatch: got %u, wanted %u\n",tcp_packet.ack_number(), socket->sequence_number());
+        kprintf("handle_tcp: ack/seq mismatch: got %u, wanted %u\n", tcp_packet.ack_number(), socket->sequence_number());
         return;
     }
 
@@ -300,6 +300,14 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
         return;
     }
 
+    if (tcp_packet.has_fin()) {
+        kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size);
+        socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
+        socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
+        socket->set_state(TCPSocket::State::Disconnecting);
+        return;
+    }
+
     socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
     kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
             tcp_packet.ack_number(),

+ 1 - 0
Kernel/TCP.h

@@ -37,6 +37,7 @@ public:
 
     bool has_syn() const { return flags() & TCPFlags::SYN; }
     bool has_ack() const { return flags() & TCPFlags::ACK; }
+    bool has_fin() const { return flags() & TCPFlags::FIN; }
 
     byte data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; }
     void set_data_offset(word data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; }

+ 5 - 0
Kernel/TCPSocket.cpp

@@ -186,3 +186,8 @@ void TCPSocket::protocol_allocate_source_port()
         }
     }
 }
+
+bool TCPSocket::protocol_is_disconnected() const
+{
+    return m_state == State::Disconnecting || m_state == State::Disconnected;
+}

+ 1 - 0
Kernel/TCPSocket.h

@@ -36,6 +36,7 @@ private:
     virtual int protocol_send(const void*, int) override;
     virtual KResult protocol_connect() override;
     virtual void protocol_allocate_source_port() override;
+    virtual bool protocol_is_disconnected() const override;
 
     dword m_sequence_number { 0 };
     dword m_ack_number { 0 };