Просмотр исходного кода

Kernel: Add SocketHandle helper class that wraps locked sockets.

This allows us to have a comfy IPv4Socket::from_tcp_port() API that returns
a socket that's locked and safe to access. No need to worry about locking
at the client site.
Andreas Kling 6 лет назад
Родитель
Сommit
54e7df0586
4 измененных файлов с 106 добавлено и 22 удалено
  1. 34 2
      Kernel/IPv4Socket.cpp
  2. 27 0
      Kernel/IPv4Socket.h
  3. 8 20
      Kernel/NetworkTask.cpp
  4. 37 0
      Kernel/Socket.h

+ 34 - 2
Kernel/IPv4Socket.cpp

@@ -27,6 +27,34 @@ Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_tcp_port()
     return *s_map;
 }
 
+IPv4SocketHandle IPv4Socket::from_tcp_port(word port)
+{
+    RetainPtr<IPv4Socket> socket;
+    {
+        LOCKER(sockets_by_tcp_port().lock());
+        auto it = sockets_by_tcp_port().resource().find(port);
+        if (it == sockets_by_tcp_port().resource().end())
+            return { };
+        socket = (*it).value;
+        ASSERT(socket);
+    }
+    return { move(socket) };
+}
+
+IPv4SocketHandle IPv4Socket::from_udp_port(word port)
+{
+    RetainPtr<IPv4Socket> socket;
+    {
+        LOCKER(sockets_by_udp_port().lock());
+        auto it = sockets_by_udp_port().resource().find(port);
+        if (it == sockets_by_udp_port().resource().end())
+            return { };
+        socket = (*it).value;
+        ASSERT(socket);
+    }
+    return { move(socket) };
+}
+
 Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
 {
     static Lockable<HashTable<IPv4Socket*>>* s_table;
@@ -217,8 +245,12 @@ NetworkOrdered<word> IPv4Socket::compute_tcp_checksum(const IPv4Address& source,
         if (checksum > 0xffff)
             checksum = (checksum >> 16) + (checksum & 0xffff);
     }
-    if (payload_size & 1)
-        ASSERT_NOT_REACHED();
+    if (payload_size & 1) {
+        word expanded_byte = ((const byte*)packet.payload())[payload_size - 1];
+        checksum += expanded_byte;
+        if (checksum > 0xffff)
+            checksum = (checksum >> 16) + (checksum & 0xffff);
+    }
     return ~(checksum & 0xffff);
 }
 

+ 27 - 0
Kernel/IPv4Socket.h

@@ -7,6 +7,7 @@
 #include <AK/Lock.h>
 #include <AK/SinglyLinkedList.h>
 
+class IPv4SocketHandle;
 class NetworkAdapter;
 class TCPPacket;
 
@@ -28,6 +29,9 @@ public:
     static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_udp_port();
     static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_tcp_port();
 
+    static IPv4SocketHandle from_tcp_port(word);
+    static IPv4SocketHandle from_udp_port(word);
+
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult connect(const sockaddr*, socklen_t) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
@@ -79,3 +83,26 @@ private:
     bool m_can_read { false };
 };
 
+class IPv4SocketHandle : public SocketHandle {
+public:
+    IPv4SocketHandle() { }
+
+    IPv4SocketHandle(RetainPtr<IPv4Socket>&& socket)
+        : SocketHandle(move(socket))
+    {
+    }
+
+    IPv4SocketHandle(IPv4SocketHandle&& other)
+        : SocketHandle(move(other))
+    {
+    }
+
+    IPv4SocketHandle(const IPv4SocketHandle&) = delete;
+    IPv4SocketHandle& operator=(const IPv4SocketHandle&) = delete;
+
+    IPv4Socket* operator->() { return &socket(); }
+    const IPv4Socket* operator->() const { return &socket(); }
+
+    IPv4Socket& socket() { return static_cast<IPv4Socket&>(SocketHandle::socket()); }
+    const IPv4Socket& socket() const { return static_cast<const IPv4Socket&>(SocketHandle::socket()); }
+};

+ 8 - 20
Kernel/NetworkTask.cpp

@@ -234,17 +234,12 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size)
     );
 #endif
 
-    RetainPtr<IPv4Socket> socket;
-    {
-        LOCKER(IPv4Socket::sockets_by_udp_port().lock());
-        auto it = IPv4Socket::sockets_by_udp_port().resource().find(udp_packet.destination_port());
-        if (it == IPv4Socket::sockets_by_udp_port().resource().end())
-            return;
-        ASSERT((*it).value);
-        socket = *(*it).value;
+    auto socket = IPv4Socket::from_udp_port(udp_packet.destination_port());
+    if (!socket) {
+        kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port());
+        return;
     }
 
-    LOCKER(socket->lock());
     ASSERT(socket->type() == SOCK_DGRAM);
     ASSERT(socket->source_port() == udp_packet.destination_port());
     socket->did_receive(ByteBuffer::copy((const byte*)&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
@@ -280,19 +275,12 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
     );
 #endif
 
-    RetainPtr<IPv4Socket> socket;
-    {
-        LOCKER(IPv4Socket::sockets_by_tcp_port().lock());
-        auto it = IPv4Socket::sockets_by_tcp_port().resource().find(tcp_packet.destination_port());
-        if (it == IPv4Socket::sockets_by_tcp_port().resource().end()) {
-            kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
-            return;
-        }
-        ASSERT((*it).value);
-        socket = *(*it).value;
+    auto socket = IPv4Socket::from_tcp_port(tcp_packet.destination_port());
+    if (!socket) {
+        kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
+        return;
     }
 
-    LOCKER(socket->lock());
     ASSERT(socket->type() == SOCK_STREAM);
     ASSERT(socket->source_port() == tcp_packet.destination_port());
 

+ 37 - 0
Kernel/Socket.h

@@ -76,3 +76,40 @@ private:
     Vector<RetainPtr<Socket>> m_pending;
     Vector<RetainPtr<Socket>> m_clients;
 };
+
+class SocketHandle {
+public:
+    SocketHandle() { }
+
+    SocketHandle(RetainPtr<Socket>&& socket)
+        : m_socket(move(socket))
+    {
+        if (m_socket)
+            m_socket->lock().lock();
+    }
+
+    SocketHandle(SocketHandle&& other)
+        : m_socket(move(other.m_socket))
+    {
+    }
+
+    ~SocketHandle()
+    {
+        if (m_socket)
+            m_socket->lock().unlock();
+    }
+
+    SocketHandle(const SocketHandle&) = delete;
+    SocketHandle& operator=(const SocketHandle&) = delete;
+
+    operator bool() const { return m_socket; }
+
+    Socket* operator->() { return &socket(); }
+    const Socket* operator->() const { return &socket(); }
+
+    Socket& socket() { return *m_socket; }
+    const Socket& socket() const { return *m_socket; }
+
+private:
+    RetainPtr<Socket> m_socket;
+};