Browse Source

Net: Store all the LocalSockets in an InlineLinkedList

Sergey Bugaev 6 years ago
parent
commit
66e5d0bdf3
2 changed files with 30 additions and 1 deletions
  1. 21 0
      Kernel/Net/LocalSocket.cpp
  2. 9 1
      Kernel/Net/LocalSocket.h

+ 21 - 0
Kernel/Net/LocalSocket.cpp

@@ -1,3 +1,4 @@
+#include <AK/StringBuilder.h>
 #include <Kernel/FileSystem/FileDescription.h>
 #include <Kernel/FileSystem/FileDescription.h>
 #include <Kernel/FileSystem/VirtualFileSystem.h>
 #include <Kernel/FileSystem/VirtualFileSystem.h>
 #include <Kernel/Net/LocalSocket.h>
 #include <Kernel/Net/LocalSocket.h>
@@ -7,6 +8,21 @@
 
 
 //#define DEBUG_LOCAL_SOCKET
 //#define DEBUG_LOCAL_SOCKET
 
 
+Lockable<InlineLinkedList<LocalSocket>>& LocalSocket::all_sockets()
+{
+    static Lockable<InlineLinkedList<LocalSocket>>* s_list;
+    if (!s_list)
+        s_list = new Lockable<InlineLinkedList<LocalSocket>>();
+    return *s_list;
+}
+
+void LocalSocket::for_each(Function<void(LocalSocket&)> callback)
+{
+    LOCKER(all_sockets().lock());
+    for (auto& socket : all_sockets().resource())
+        callback(socket);
+}
+
 NonnullRefPtr<LocalSocket> LocalSocket::create(int type)
 NonnullRefPtr<LocalSocket> LocalSocket::create(int type)
 {
 {
     return adopt(*new LocalSocket(type));
     return adopt(*new LocalSocket(type));
@@ -15,6 +31,8 @@ NonnullRefPtr<LocalSocket> LocalSocket::create(int type)
 LocalSocket::LocalSocket(int type)
 LocalSocket::LocalSocket(int type)
     : Socket(AF_LOCAL, type, 0)
     : Socket(AF_LOCAL, type, 0)
 {
 {
+    LOCKER(all_sockets().lock());
+    all_sockets().resource().append(this);
 #ifdef DEBUG_LOCAL_SOCKET
 #ifdef DEBUG_LOCAL_SOCKET
     kprintf("%s(%u) LocalSocket{%p} created with type=%u\n", current->process().name().characters(), current->pid(), this, type);
     kprintf("%s(%u) LocalSocket{%p} created with type=%u\n", current->process().name().characters(), current->pid(), this, type);
 #endif
 #endif
@@ -22,6 +40,8 @@ LocalSocket::LocalSocket(int type)
 
 
 LocalSocket::~LocalSocket()
 LocalSocket::~LocalSocket()
 {
 {
+    LOCKER(all_sockets().lock());
+    all_sockets().resource().remove(this);
 }
 }
 
 
 bool LocalSocket::get_local_address(sockaddr* address, socklen_t* address_size)
 bool LocalSocket::get_local_address(sockaddr* address, socklen_t* address_size)
@@ -91,6 +111,7 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre
     auto description_or_error = VFS::the().open(safe_address, 0, 0, current->process().current_directory());
     auto description_or_error = VFS::the().open(safe_address, 0, 0, current->process().current_directory());
     if (description_or_error.is_error())
     if (description_or_error.is_error())
         return KResult(-ECONNREFUSED);
         return KResult(-ECONNREFUSED);
+
     m_file = move(description_or_error.value());
     m_file = move(description_or_error.value());
 
 
     ASSERT(m_file->inode());
     ASSERT(m_file->inode());

+ 9 - 1
Kernel/Net/LocalSocket.h

@@ -1,15 +1,18 @@
 #pragma once
 #pragma once
 
 
+#include <AK/InlineLinkedList.h>
 #include <Kernel/DoubleBuffer.h>
 #include <Kernel/DoubleBuffer.h>
 #include <Kernel/Net/Socket.h>
 #include <Kernel/Net/Socket.h>
 
 
 class FileDescription;
 class FileDescription;
 
 
-class LocalSocket final : public Socket {
+class LocalSocket final : public Socket, public InlineLinkedListNode<LocalSocket> {
+    friend class InlineLinkedListNode<LocalSocket>;
 public:
 public:
     static NonnullRefPtr<LocalSocket> create(int type);
     static NonnullRefPtr<LocalSocket> create(int type);
     virtual ~LocalSocket() override;
     virtual ~LocalSocket() override;
 
 
+    static void for_each(Function<void(LocalSocket&)>);
 
 
     StringView socket_path() const;
     StringView socket_path() const;
     // ^Socket
     // ^Socket
@@ -30,6 +33,7 @@ private:
     virtual const char* class_name() const override { return "LocalSocket"; }
     virtual const char* class_name() const override { return "LocalSocket"; }
     virtual bool is_local() const override { return true; }
     virtual bool is_local() const override { return true; }
     bool has_attached_peer(const FileDescription&) const;
     bool has_attached_peer(const FileDescription&) const;
+    static Lockable<InlineLinkedList<LocalSocket>>& all_sockets();
 
 
     // An open socket file on the filesystem.
     // An open socket file on the filesystem.
     RefPtr<FileDescription> m_file;
     RefPtr<FileDescription> m_file;
@@ -54,4 +58,8 @@ private:
 
 
     DoubleBuffer m_for_client;
     DoubleBuffer m_for_client;
     DoubleBuffer m_for_server;
     DoubleBuffer m_for_server;
+
+    // for InlineLinkedList
+    LocalSocket* m_prev { nullptr };
+    LocalSocket* m_next { nullptr };
 };
 };