Kernel: Migrate UDP socket table locking to ProtectedValue

This commit is contained in:
Jean-Baptiste Boric 2021-07-18 12:24:34 +02:00 committed by Andreas Kling
parent 9216c72bfe
commit 9517100672
Notes: sideshowbarker 2024-07-18 07:20:11 +09:00
2 changed files with 37 additions and 36 deletions

View file

@ -6,7 +6,6 @@
#include <AK/Singleton.h>
#include <Kernel/Devices/RandomDevice.h>
#include <Kernel/Locking/Mutex.h>
#include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Net/Routing.h>
#include <Kernel/Net/UDP.h>
@ -18,30 +17,29 @@ namespace Kernel {
void UDPSocket::for_each(Function<void(const UDPSocket&)> callback)
{
MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared);
for (auto it : sockets_by_port().resource())
callback(*it.value);
sockets_by_port().for_each_shared([&](const auto& socket) {
callback(*socket.value);
});
}
static AK::Singleton<Lockable<HashMap<u16, UDPSocket*>>> s_map;
static AK::Singleton<ProtectedValue<HashMap<u16, UDPSocket*>>> s_map;
Lockable<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
ProtectedValue<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
{
return *s_map;
}
SocketHandle<UDPSocket> UDPSocket::from_port(u16 port)
{
RefPtr<UDPSocket> socket;
{
MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared);
auto it = sockets_by_port().resource().find(port);
if (it == sockets_by_port().resource().end())
return sockets_by_port().with_shared([&](const auto& table) -> SocketHandle<UDPSocket> {
RefPtr<UDPSocket> socket;
auto it = table.find(port);
if (it == table.end())
return {};
socket = (*it).value;
VERIFY(socket);
}
return { *socket };
return { *socket };
});
}
UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@ -51,8 +49,9 @@ UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
UDPSocket::~UDPSocket()
{
MutexLocker locker(sockets_by_port().lock());
sockets_by_port().resource().remove(local_port());
sockets_by_port().with_exclusive([&](auto& table) {
table.remove(local_port());
});
}
KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@ -113,30 +112,32 @@ KResultOr<u16> UDPSocket::protocol_allocate_local_port()
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
MutexLocker locker(sockets_by_port().lock());
for (u16 port = first_scan_port;;) {
auto it = sockets_by_port().resource().find(port);
if (it == sockets_by_port().resource().end()) {
set_local_port(port);
sockets_by_port().resource().set(port, this);
return port;
return sockets_by_port().with_exclusive([&](auto& table) -> KResultOr<u16> {
for (u16 port = first_scan_port;;) {
auto it = table.find(port);
if (it == table.end()) {
set_local_port(port);
table.set(port, this);
return port;
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return EADDRINUSE;
return EADDRINUSE;
});
}
KResult UDPSocket::protocol_bind()
{
MutexLocker locker(sockets_by_port().lock());
if (sockets_by_port().resource().contains(local_port()))
return EADDRINUSE;
sockets_by_port().resource().set(local_port(), this);
return KSuccess;
return sockets_by_port().with_exclusive([&](auto& table) -> KResult {
if (table.contains(local_port()))
return EADDRINUSE;
table.set(local_port(), this);
return KSuccess;
});
}
}

View file

@ -7,7 +7,7 @@
#pragma once
#include <Kernel/KResult.h>
#include <Kernel/Locking/Lockable.h>
#include <Kernel/Locking/ProtectedValue.h>
#include <Kernel/Net/IPv4Socket.h>
namespace Kernel {
@ -23,7 +23,7 @@ public:
private:
explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
virtual StringView class_name() const override { return "UDPSocket"; }
static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
static ProtectedValue<HashMap<u16, UDPSocket*>>& sockets_by_port();
virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;