Quellcode durchsuchen

IPv4: Dynamically allocate the UDP source port if needed.

Andreas Kling vor 6 Jahren
Ursprung
Commit
209a16bb7f
3 geänderte Dateien mit 47 neuen und 1 gelöschten Zeilen
  1. 6 0
      AK/Traits.h
  2. 36 1
      Kernel/IPv4Socket.cpp
  3. 5 0
      Kernel/IPv4Socket.h

+ 6 - 0
AK/Traits.h

@@ -22,6 +22,12 @@ struct Traits<unsigned> {
     static void dump(unsigned u) { kprintf("%u", u); }
     static void dump(unsigned u) { kprintf("%u", u); }
 };
 };
 
 
+template<>
+struct Traits<word> {
+    static unsigned hash(unsigned u) { return int_hash(u); }
+    static void dump(unsigned u) { kprintf("%u", u); }
+};
+
 template<typename T>
 template<typename T>
 struct Traits<T*> {
 struct Traits<T*> {
     static unsigned hash(const T* p)
     static unsigned hash(const T* p)

+ 36 - 1
Kernel/IPv4Socket.cpp

@@ -10,6 +10,22 @@
 
 
 #define IPV4_SOCKET_DEBUG
 #define IPV4_SOCKET_DEBUG
 
 
+Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_udp_port()
+{
+    static Lockable<HashMap<word, IPv4Socket*>>* s_map;
+    if (!s_map)
+        s_map = new Lockable<HashMap<word, IPv4Socket*>>;
+    return *s_map;
+}
+
+Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_tcp_port()
+{
+    static Lockable<HashMap<word, IPv4Socket*>>* s_map;
+    if (!s_map)
+        s_map = new Lockable<HashMap<word, IPv4Socket*>>;
+    return *s_map;
+}
+
 Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
 Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
 {
 {
     static Lockable<HashTable<IPv4Socket*>>* s_table;
     static Lockable<HashTable<IPv4Socket*>>* s_table;
@@ -100,6 +116,25 @@ bool IPv4Socket::can_write(SocketRole role) const
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
+void IPv4Socket::allocate_source_port_if_needed()
+{
+    if (m_source_port)
+        return;
+    if (type() == SOCK_DGRAM) {
+        // This is not a very efficient allocation algorithm.
+        // FIXME: Replace it with a bitmap or some other fast-paced looker-upper.
+        LOCKER(sockets_by_udp_port().lock());
+        for (word port = 2000; port < 60000; ++port) {
+            auto it = sockets_by_udp_port().resource().find(port);
+            if (it == sockets_by_udp_port().resource().end()) {
+                m_source_port = port;
+                return;
+            }
+        }
+        ASSERT_NOT_REACHED();
+    }
+}
+
 ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length)
 ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length)
 {
 {
     (void)flags;
     (void)flags;
@@ -121,7 +156,7 @@ ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, cons
     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
     m_destination_port = ntohs(ia.sin_port);
     m_destination_port = ntohs(ia.sin_port);
 
 
-    m_source_port = 2413;
+    allocate_source_port_if_needed();
 
 
     kprintf("sendto: destination=%s:%u\n", m_destination_address.to_string().characters(), m_destination_port);
     kprintf("sendto: destination=%s:%u\n", m_destination_address.to_string().characters(), m_destination_port);
 
 

+ 5 - 0
Kernel/IPv4Socket.h

@@ -3,6 +3,7 @@
 #include <Kernel/Socket.h>
 #include <Kernel/Socket.h>
 #include <Kernel/DoubleBuffer.h>
 #include <Kernel/DoubleBuffer.h>
 #include <Kernel/IPv4.h>
 #include <Kernel/IPv4.h>
+#include <AK/HashMap.h>
 #include <AK/Lock.h>
 #include <AK/Lock.h>
 #include <AK/SinglyLinkedList.h>
 #include <AK/SinglyLinkedList.h>
 
 
@@ -12,6 +13,8 @@ public:
     virtual ~IPv4Socket() override;
     virtual ~IPv4Socket() override;
 
 
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
+    static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_udp_port();
+    static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_tcp_port();
 
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult connect(const sockaddr*, socklen_t) override;
     virtual KResult connect(const sockaddr*, socklen_t) override;
@@ -36,6 +39,8 @@ private:
     IPv4Socket(int type, int protocol);
     IPv4Socket(int type, int protocol);
     virtual bool is_ipv4() const override { return true; }
     virtual bool is_ipv4() const override { return true; }
 
 
+    void allocate_source_port_if_needed();
+
     bool m_bound { false };
     bool m_bound { false };
     int m_attached_fds { 0 };
     int m_attached_fds { 0 };
     IPv4Address m_destination_address;
     IPv4Address m_destination_address;