فهرست منبع

Kernel: Migrate UDP socket table locking to ProtectedValue

Jean-Baptiste Boric 4 سال پیش
والد
کامیت
9517100672
2فایلهای تغییر یافته به همراه37 افزوده شده و 36 حذف شده
  1. 35 34
      Kernel/Net/UDPSocket.cpp
  2. 2 2
      Kernel/Net/UDPSocket.h

+ 35 - 34
Kernel/Net/UDPSocket.cpp

@@ -6,7 +6,6 @@
 
 #include <AK/Singleton.h>
 #include <Kernel/Devices/RandomDevice.h>
-#include <Kernel/Locking/Mutex.h>
 #include <Kernel/Net/NetworkAdapter.h>
 #include <Kernel/Net/Routing.h>
 #include <Kernel/Net/UDP.h>
@@ -18,30 +17,29 @@ namespace Kernel {
 
 void UDPSocket::for_each(Function<void(const UDPSocket&)> callback)
 {
-    MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared);
-    for (auto it : sockets_by_port().resource())
-        callback(*it.value);
+    sockets_by_port().for_each_shared([&](const auto& socket) {
+        callback(*socket.value);
+    });
 }
 
-static AK::Singleton<Lockable<HashMap<u16, UDPSocket*>>> s_map;
+static AK::Singleton<ProtectedValue<HashMap<u16, UDPSocket*>>> s_map;
 
-Lockable<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
+ProtectedValue<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
 {
     return *s_map;
 }
 
 SocketHandle<UDPSocket> UDPSocket::from_port(u16 port)
 {
-    RefPtr<UDPSocket> socket;
-    {
-        MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared);
-        auto it = sockets_by_port().resource().find(port);
-        if (it == sockets_by_port().resource().end())
+    return sockets_by_port().with_shared([&](const auto& table) -> SocketHandle<UDPSocket> {
+        RefPtr<UDPSocket> socket;
+        auto it = table.find(port);
+        if (it == table.end())
             return {};
         socket = (*it).value;
         VERIFY(socket);
-    }
-    return { *socket };
+        return { *socket };
+    });
 }
 
 UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@@ -51,8 +49,9 @@ UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
 
 UDPSocket::~UDPSocket()
 {
-    MutexLocker locker(sockets_by_port().lock());
-    sockets_by_port().resource().remove(local_port());
+    sockets_by_port().with_exclusive([&](auto& table) {
+        table.remove(local_port());
+    });
 }
 
 KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@@ -113,30 +112,32 @@ KResultOr<u16> UDPSocket::protocol_allocate_local_port()
     constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
     u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
 
-    MutexLocker locker(sockets_by_port().lock());
-    for (u16 port = first_scan_port;;) {
-        auto it = sockets_by_port().resource().find(port);
-        if (it == sockets_by_port().resource().end()) {
-            set_local_port(port);
-            sockets_by_port().resource().set(port, this);
-            return port;
+    return sockets_by_port().with_exclusive([&](auto& table) -> KResultOr<u16> {
+        for (u16 port = first_scan_port;;) {
+            auto it = table.find(port);
+            if (it == table.end()) {
+                set_local_port(port);
+                table.set(port, this);
+                return port;
+            }
+            ++port;
+            if (port > last_ephemeral_port)
+                port = first_ephemeral_port;
+            if (port == first_scan_port)
+                break;
         }
-        ++port;
-        if (port > last_ephemeral_port)
-            port = first_ephemeral_port;
-        if (port == first_scan_port)
-            break;
-    }
-    return EADDRINUSE;
+        return EADDRINUSE;
+    });
 }
 
 KResult UDPSocket::protocol_bind()
 {
-    MutexLocker locker(sockets_by_port().lock());
-    if (sockets_by_port().resource().contains(local_port()))
-        return EADDRINUSE;
-    sockets_by_port().resource().set(local_port(), this);
-    return KSuccess;
+    return sockets_by_port().with_exclusive([&](auto& table) -> KResult {
+        if (table.contains(local_port()))
+            return EADDRINUSE;
+        table.set(local_port(), this);
+        return KSuccess;
+    });
 }
 
 }

+ 2 - 2
Kernel/Net/UDPSocket.h

@@ -7,7 +7,7 @@
 #pragma once
 
 #include <Kernel/KResult.h>
-#include <Kernel/Locking/Lockable.h>
+#include <Kernel/Locking/ProtectedValue.h>
 #include <Kernel/Net/IPv4Socket.h>
 
 namespace Kernel {
@@ -23,7 +23,7 @@ public:
 private:
     explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
     virtual StringView class_name() const override { return "UDPSocket"; }
-    static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
+    static ProtectedValue<HashMap<u16, UDPSocket*>>& sockets_by_port();
 
     virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
     virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;