Jelajahi Sumber

IPv4: Factor out UDP parts of IPv4Socket into a UDPSocket class.

Andreas Kling 6 tahun lalu
induk
melakukan
edb986c276
6 mengubah file dengan 157 tambahan dan 90 penghapusan
  1. 8 86
      Kernel/IPv4Socket.cpp
  2. 0 3
      Kernel/IPv4Socket.h
  3. 1 0
      Kernel/Makefile
  4. 2 1
      Kernel/NetworkTask.cpp
  5. 99 0
      Kernel/UDPSocket.cpp
  6. 47 0
      Kernel/UDPSocket.h

+ 8 - 86
Kernel/IPv4Socket.cpp

@@ -1,5 +1,6 @@
 #include <Kernel/IPv4Socket.h>
 #include <Kernel/TCPSocket.h>
+#include <Kernel/UDPSocket.h>
 #include <Kernel/UnixTypes.h>
 #include <Kernel/Process.h>
 #include <Kernel/NetworkAdapter.h>
@@ -12,28 +13,6 @@
 
 #define IPV4_SOCKET_DEBUG
 
-Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_udp_port()
-{
-    static Lockable<HashMap<word, IPv4Socket*>>* s_map;
-    if (!s_map)
-        s_map = new Lockable<HashMap<word, IPv4Socket*>>;
-    return *s_map;
-}
-
-IPv4SocketHandle IPv4Socket::from_udp_port(word port)
-{
-    RetainPtr<IPv4Socket> socket;
-    {
-        LOCKER(sockets_by_udp_port().lock());
-        auto it = sockets_by_udp_port().resource().find(port);
-        if (it == sockets_by_udp_port().resource().end())
-            return { };
-        socket = (*it).value;
-        ASSERT(socket);
-    }
-    return { move(socket) };
-}
-
 Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
 {
     static Lockable<HashTable<IPv4Socket*>>* s_table;
@@ -46,6 +25,8 @@ Retained<IPv4Socket> IPv4Socket::create(int type, int protocol)
 {
     if (type == SOCK_STREAM)
         return TCPSocket::create(protocol);
+    if (type == SOCK_DGRAM)
+        return UDPSocket::create(protocol);
     return adopt(*new IPv4Socket(type, protocol));
 }
 
@@ -59,14 +40,8 @@ IPv4Socket::IPv4Socket(int type, int protocol)
 
 IPv4Socket::~IPv4Socket()
 {
-    {
-        LOCKER(all_sockets().lock());
-        all_sockets().resource().remove(this);
-    }
-    if (type() == SOCK_DGRAM) {
-        LOCKER(sockets_by_udp_port().lock());
-        sockets_by_udp_port().resource().remove(m_source_port);
-    }
+    LOCKER(all_sockets().lock());
+    all_sockets().resource().remove(this);
 }
 
 bool IPv4Socket::get_address(sockaddr* address, socklen_t* address_size)
@@ -139,26 +114,7 @@ void IPv4Socket::allocate_source_port_if_needed()
 {
     if (m_source_port)
         return;
-    if (type() == SOCK_DGRAM) {
-        // This is not a very efficient allocation algorithm.
-        // FIXME: Replace it with a bitmap or some other fast-paced looker-upper.
-        LOCKER(sockets_by_udp_port().lock());
-        for (word port = 2000; port < 60000; ++port) {
-            auto it = sockets_by_udp_port().resource().find(port);
-            if (it == sockets_by_udp_port().resource().end()) {
-                m_source_port = port;
-                sockets_by_udp_port().resource().set(port, this);
-                return;
-            }
-        }
-        ASSERT_NOT_REACHED();
-    }
-    if (type() == SOCK_STREAM) {
-        protocol_allocate_source_port();
-        return;
-    }
-
-    ASSERT_NOT_REACHED();
+    protocol_allocate_source_port();
 }
 
 ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length)
@@ -193,26 +149,7 @@ ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, cons
         return data_length;
     }
 
-    if (type() == SOCK_DGRAM) {
-        auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length);
-        auto& udp_packet = *(UDPPacket*)(buffer.pointer());
-        udp_packet.set_source_port(m_source_port);
-        udp_packet.set_destination_port(m_destination_port);
-        udp_packet.set_length(sizeof(UDPPacket) + data_length);
-        memcpy(udp_packet.payload(), data, data_length);
-        kprintf("sending as udp packet from %s:%u to %s:%u!\n",
-            adapter->ipv4_address().to_string().characters(),
-            source_port(),
-            m_destination_address.to_string().characters(),
-            m_destination_port);
-        adapter->send_ipv4(MACAddress(), m_destination_address, IPv4Protocol::UDP, move(buffer));
-        return data_length;
-    }
-
-    if (type() == SOCK_STREAM)
-        return protocol_send(data, data_length);
-
-    ASSERT_NOT_REACHED();
+    return protocol_send(data, data_length);
 }
 
 ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sockaddr* addr, socklen_t* addr_length)
@@ -266,22 +203,7 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock
         return ipv4_packet.payload_size();
     }
 
-    if (type() == SOCK_DGRAM) {
-        auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
-        ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier.
-        ASSERT(buffer_length >= (udp_packet.length() - sizeof(UDPPacket)));
-        if (addr) {
-            auto& ia = *(sockaddr_in*)addr;
-            ia.sin_port = htons(udp_packet.destination_port());
-        }
-        memcpy(buffer, udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket));
-        return udp_packet.length() - sizeof(UDPPacket);
-    }
-
-    if (type() == SOCK_STREAM)
-        return protocol_receive(packet_buffer, buffer, buffer_length, flags, addr, addr_length);
-
-    ASSERT_NOT_REACHED();
+    return protocol_receive(packet_buffer, buffer, buffer_length, flags, addr, addr_length);
 }
 
 void IPv4Socket::did_receive(ByteBuffer&& packet)

+ 0 - 3
Kernel/IPv4Socket.h

@@ -20,9 +20,6 @@ public:
 
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
 
-    static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_udp_port();
-    static IPv4SocketHandle from_udp_port(word);
-
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult connect(const sockaddr*, socklen_t) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;

+ 1 - 0
Kernel/Makefile

@@ -36,6 +36,7 @@ KERNEL_OBJS = \
        LocalSocket.o \
        IPv4Socket.o \
        TCPSocket.o \
+       UDPSocket.o \
        NetworkAdapter.o \
        E1000NetworkAdapter.o \
        NetworkTask.o

+ 2 - 1
Kernel/NetworkTask.cpp

@@ -7,6 +7,7 @@
 #include <Kernel/IPv4.h>
 #include <Kernel/IPv4Socket.h>
 #include <Kernel/TCPSocket.h>
+#include <Kernel/UDPSocket.h>
 #include <Kernel/Process.h>
 #include <Kernel/EtherType.h>
 #include <AK/Lock.h>
@@ -235,7 +236,7 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size)
     );
 #endif
 
-    auto socket = IPv4Socket::from_udp_port(udp_packet.destination_port());
+    auto socket = UDPSocket::from_port(udp_packet.destination_port());
     if (!socket) {
         kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port());
         return;

+ 99 - 0
Kernel/UDPSocket.cpp

@@ -0,0 +1,99 @@
+#include <Kernel/UDPSocket.h>
+#include <Kernel/UDP.h>
+#include <Kernel/NetworkAdapter.h>
+#include <Kernel/Process.h>
+
+Lockable<HashMap<word, UDPSocket*>>& UDPSocket::sockets_by_port()
+{
+    static Lockable<HashMap<word, UDPSocket*>>* s_map;
+    if (!s_map)
+        s_map = new Lockable<HashMap<word, UDPSocket*>>;
+    return *s_map;
+}
+
+UDPSocketHandle UDPSocket::from_port(word port)
+{
+    RetainPtr<UDPSocket> socket;
+    {
+        LOCKER(sockets_by_port().lock());
+        auto it = sockets_by_port().resource().find(port);
+        if (it == sockets_by_port().resource().end())
+            return { };
+        socket = (*it).value;
+        ASSERT(socket);
+    }
+    return { move(socket) };
+}
+
+
+UDPSocket::UDPSocket(int protocol)
+    : IPv4Socket(SOCK_DGRAM, protocol)
+{
+}
+
+UDPSocket::~UDPSocket()
+{
+    LOCKER(sockets_by_port().lock());
+    sockets_by_port().resource().remove(source_port());
+}
+
+Retained<UDPSocket> UDPSocket::create(int protocol)
+{
+    return adopt(*new UDPSocket(protocol));
+}
+
+int UDPSocket::protocol_receive(const ByteBuffer& packet_buffer, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length)
+{
+    (void)flags;
+    (void)addr_length;
+    ASSERT(!packet_buffer.is_null());
+    auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer());
+    auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
+    ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier.
+    ASSERT(buffer_size >= (udp_packet.length() - sizeof(UDPPacket)));
+    if (addr) {
+        auto& ia = *(sockaddr_in*)addr;
+        ia.sin_port = htons(udp_packet.destination_port());
+    }
+    memcpy(buffer, udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket));
+    return udp_packet.length() - sizeof(UDPPacket);
+}
+
+int UDPSocket::protocol_send(const void* data, int data_length)
+{
+    // FIXME: Figure out the adapter somehow differently.
+    auto& adapter = *NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2));
+    auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length);
+    auto& udp_packet = *(UDPPacket*)(buffer.pointer());
+    udp_packet.set_source_port(source_port());
+    udp_packet.set_destination_port(destination_port());
+    udp_packet.set_length(sizeof(UDPPacket) + data_length);
+    memcpy(udp_packet.payload(), data, data_length);
+    kprintf("sending as udp packet from %s:%u to %s:%u!\n",
+        adapter.ipv4_address().to_string().characters(),
+        source_port(),
+        destination_address().to_string().characters(),
+        destination_port());
+    adapter.send_ipv4(MACAddress(), destination_address(), IPv4Protocol::UDP, move(buffer));
+    return data_length;
+}
+
+KResult UDPSocket::protocol_connect()
+{
+    return KSuccess;
+}
+
+void UDPSocket::protocol_allocate_source_port()
+{
+    // This is not a very efficient allocation algorithm.
+    // FIXME: Replace it with a bitmap or some other fast-paced looker-upper.
+    LOCKER(sockets_by_port().lock());
+    for (word port = 2000; port < 60000; ++port) {
+        auto it = sockets_by_port().resource().find(port);
+        if (it == sockets_by_port().resource().end()) {
+            set_source_port(port);
+            sockets_by_port().resource().set(port, this);
+            return;
+        }
+    }
+}

+ 47 - 0
Kernel/UDPSocket.h

@@ -0,0 +1,47 @@
+#pragma once
+
+#include <Kernel/IPv4Socket.h>
+
+class UDPSocketHandle;
+
+class UDPSocket final : public IPv4Socket {
+public:
+    static Retained<UDPSocket> create(int protocol);
+    virtual ~UDPSocket() override;
+
+    static Lockable<HashMap<word, UDPSocket*>>& sockets_by_port();
+    static UDPSocketHandle from_port(word);
+
+private:
+    explicit UDPSocket(int protocol);
+
+    virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
+    virtual int protocol_send(const void*, int) override;
+    virtual KResult protocol_connect() override;
+    virtual void protocol_allocate_source_port() override;
+};
+
+class UDPSocketHandle : public SocketHandle {
+public:
+    UDPSocketHandle() { }
+
+    UDPSocketHandle(RetainPtr<UDPSocket>&& socket)
+        : SocketHandle(move(socket))
+    {
+    }
+
+    UDPSocketHandle(UDPSocketHandle&& other)
+        : SocketHandle(move(other))
+    {
+    }
+
+    UDPSocketHandle(const UDPSocketHandle&) = delete;
+    UDPSocketHandle& operator=(const UDPSocketHandle&) = delete;
+
+    UDPSocket* operator->() { return &socket(); }
+    const UDPSocket* operator->() const { return &socket(); }
+
+    UDPSocket& socket() { return static_cast<UDPSocket&>(SocketHandle::socket()); }
+    const UDPSocket& socket() const { return static_cast<const UDPSocket&>(SocketHandle::socket()); }
+};
+