瀏覽代碼

LibCore: Convert CLocalSocket to ObjectPtr

Andreas Kling 5 年之前
父節點
當前提交
c83da29a9d

+ 10 - 9
DevTools/Inspector/RemoteProcess.cpp

@@ -8,6 +8,7 @@
 RemoteProcess::RemoteProcess(pid_t pid)
     : m_pid(pid)
     , m_object_graph_model(RemoteObjectGraphModel::create(*this))
+    , m_socket(CLocalSocket::construct())
 {
 }
 
@@ -65,13 +66,13 @@ void RemoteProcess::send_request(const JsonObject& request)
 {
     auto serialized = request.to_string();
     i32 length = serialized.length();
-    m_socket.write((const u8*)&length, sizeof(length));
-    m_socket.write(serialized);
+    m_socket->write((const u8*)&length, sizeof(length));
+    m_socket->write(serialized);
 }
 
 void RemoteProcess::update()
 {
-    m_socket.on_connected = [this] {
+    m_socket->on_connected = [this] {
         dbg() << "Connected to PID " << m_pid;
 
         {
@@ -87,18 +88,18 @@ void RemoteProcess::update()
         }
     };
 
-    m_socket.on_ready_to_read = [this] {
-        if (m_socket.eof()) {
+    m_socket->on_ready_to_read = [this] {
+        if (m_socket->eof()) {
             dbg() << "Disconnected from PID " << m_pid;
-            m_socket.close();
+            m_socket->close();
             return;
         }
 
         i32 length;
-        int nread = m_socket.read((u8*)&length, sizeof(length));
+        int nread = m_socket->read((u8*)&length, sizeof(length));
         ASSERT(nread == sizeof(length));
 
-        auto data = m_socket.read(length);
+        auto data = m_socket->read(length);
         ASSERT(data.size() == length);
 
         dbg() << "Got packet size " << length << " and read that many bytes";
@@ -125,7 +126,7 @@ void RemoteProcess::update()
         }
     };
 
-    auto success = m_socket.connect(CSocketAddress::local(String::format("/tmp/rpc.%d", m_pid)));
+    auto success = m_socket->connect(CSocketAddress::local(String::format("/tmp/rpc.%d", m_pid)));
     if (!success) {
         fprintf(stderr, "Couldn't connect to PID %d\n", m_pid);
         exit(1);

+ 1 - 1
DevTools/Inspector/RemoteProcess.h

@@ -32,6 +32,6 @@ private:
     pid_t m_pid { -1 };
     String m_process_name;
     NonnullRefPtr<RemoteObjectGraphModel> m_object_graph_model;
-    CLocalSocket m_socket;
+    ObjectPtr<CLocalSocket> m_socket;
     NonnullOwnPtrVector<RemoteObject> m_roots;
 };

+ 11 - 12
Libraries/LibCore/CEventLoop.cpp

@@ -34,21 +34,20 @@ CLocalServer CEventLoop::s_rpc_server;
 class RPCClient : public CObject {
     C_OBJECT(RPCClient)
 public:
-    explicit RPCClient(CLocalSocket& socket)
-        : m_socket(socket)
+    explicit RPCClient(ObjectPtr<CLocalSocket> socket)
+        : m_socket(move(socket))
     {
-        add_child(socket);
-
-        m_socket.on_ready_to_read = [this] {
+        add_child(*m_socket);
+        m_socket->on_ready_to_read = [this] {
             i32 length;
-            int nread = m_socket.read((u8*)&length, sizeof(length));
+            int nread = m_socket->read((u8*)&length, sizeof(length));
             if (nread == 0) {
                 dbg() << "RPC client disconnected";
                 delete_later();
                 return;
             }
             ASSERT(nread == sizeof(length));
-            auto request = m_socket.read(length);
+            auto request = m_socket->read(length);
 
             auto request_json = JsonValue::from_string(request);
             if (!request_json.is_object()) {
@@ -68,8 +67,8 @@ public:
     {
         auto serialized = response.to_string();
         i32 length = serialized.length();
-        m_socket.write((const u8*)&length, sizeof(length));
-        m_socket.write(serialized);
+        m_socket->write((const u8*)&length, sizeof(length));
+        m_socket->write(serialized);
     }
 
     void handle_request(const JsonObject& request)
@@ -118,7 +117,7 @@ public:
     }
 
 private:
-    CLocalSocket& m_socket;
+    ObjectPtr<CLocalSocket> m_socket;
 };
 
 CEventLoop::CEventLoop()
@@ -145,9 +144,9 @@ CEventLoop::CEventLoop()
         ASSERT(listening);
 
         s_rpc_server.on_ready_to_accept = [&] {
-            auto* client_socket = s_rpc_server.accept();
+            auto client_socket = s_rpc_server.accept();
             ASSERT(client_socket);
-            new RPCClient(*client_socket);
+            new RPCClient(move(client_socket));
         };
     }
 

+ 2 - 2
Libraries/LibCore/CLocalServer.cpp

@@ -39,7 +39,7 @@ bool CLocalServer::listen(const String& address)
     return true;
 }
 
-CLocalSocket* CLocalServer::accept()
+ObjectPtr<CLocalSocket> CLocalServer::accept()
 {
     ASSERT(m_listening);
     sockaddr_un un;
@@ -50,5 +50,5 @@ CLocalSocket* CLocalServer::accept()
         return nullptr;
     }
 
-    return new CLocalSocket({}, accepted_fd);
+    return CLocalSocket::construct(accepted_fd);
 }

+ 1 - 1
Libraries/LibCore/CLocalServer.h

@@ -14,7 +14,7 @@ public:
     bool is_listening() const { return m_listening; }
     bool listen(const String& address);
 
-    CLocalSocket* accept();
+    ObjectPtr<CLocalSocket> accept();
 
     Function<void()> on_ready_to_accept;
 

+ 1 - 1
Libraries/LibCore/CLocalSocket.cpp

@@ -2,7 +2,7 @@
 #include <sys/socket.h>
 #include <errno.h>
 
-CLocalSocket::CLocalSocket(Badge<CLocalServer>, int fd, CObject* parent)
+CLocalSocket::CLocalSocket(int fd, CObject* parent)
     : CSocket(CSocket::Type::Local, parent)
 {
     set_fd(fd);

+ 4 - 2
Libraries/LibCore/CLocalSocket.h

@@ -8,7 +8,9 @@ class CLocalServer;
 class CLocalSocket final : public CSocket {
     C_OBJECT(CLocalSocket)
 public:
-    explicit CLocalSocket(CObject* parent = nullptr);
-    CLocalSocket(Badge<CLocalServer>, int fd, CObject* parent = nullptr);
     virtual ~CLocalSocket() override;
+
+private:
+    explicit CLocalSocket(CObject* parent = nullptr);
+    CLocalSocket(int fd, CObject* parent = nullptr);
 };

+ 30 - 29
Libraries/LibCore/CoreIPCClient.h

@@ -7,6 +7,7 @@
 #include <LibCore/CSyscallUtils.h>
 #include <LibIPC/IMessage.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <sys/select.h>
 #include <sys/socket.h>
 #include <sys/types.h>
@@ -49,19 +50,19 @@ namespace Client {
     class Connection : public CObject {
     public:
         Connection(const StringView& address)
-            : m_connection(this)
-            , m_notifier(CNotifier::create(m_connection.fd(), CNotifier::Read, this))
+            : m_connection(CLocalSocket::construct(this))
+            , m_notifier(CNotifier::create(m_connection->fd(), CNotifier::Read, this))
         {
             // We want to rate-limit our clients
-            m_connection.set_blocking(true);
+            m_connection->set_blocking(true);
             m_notifier->on_ready_to_read = [this] {
                 drain_messages_from_server();
-                CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
+                CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection->fd()));
             };
 
             int retries = 1000;
             while (retries) {
-                if (m_connection.connect(CSocketAddress::local(address))) {
+                if (m_connection->connect(CSocketAddress::local(address))) {
                     break;
                 }
 
@@ -69,7 +70,7 @@ namespace Client {
                 sleep(1);
                 --retries;
             }
-            ASSERT(m_connection.is_connected());
+            ASSERT(m_connection->is_connected());
         }
 
         virtual void handshake() = 0;
@@ -97,20 +98,20 @@ namespace Client {
                 if (m_unprocessed_bundles[i].message.type == type) {
                     event = move(m_unprocessed_bundles[i].message);
                     m_unprocessed_bundles.remove(i);
-                    CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
+                    CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection->fd()));
                     return true;
                 }
             }
             for (;;) {
                 fd_set rfds;
                 FD_ZERO(&rfds);
-                FD_SET(m_connection.fd(), &rfds);
-                int rc = CSyscallUtils::safe_syscall(select, m_connection.fd() + 1, &rfds, nullptr, nullptr, nullptr);
+                FD_SET(m_connection->fd(), &rfds);
+                int rc = CSyscallUtils::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr);
                 if (rc < 0) {
                     perror("select");
                 }
                 ASSERT(rc > 0);
-                ASSERT(FD_ISSET(m_connection.fd(), &rfds));
+                ASSERT(FD_ISSET(m_connection->fd(), &rfds));
                 bool success = drain_messages_from_server();
                 if (!success)
                     return false;
@@ -118,7 +119,7 @@ namespace Client {
                     if (m_unprocessed_bundles[i].message.type == type) {
                         event = move(m_unprocessed_bundles[i].message);
                         m_unprocessed_bundles.remove(i);
-                        CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
+                        CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection->fd()));
                         return true;
                     }
                 }
@@ -144,7 +145,7 @@ namespace Client {
                 ++iov_count;
             }
 
-            int nwritten = writev(m_connection.fd(), iov, iov_count);
+            int nwritten = writev(m_connection->fd(), iov, iov_count);
             if (nwritten < 0) {
                 perror("writev");
                 ASSERT_NOT_REACHED();
@@ -196,7 +197,7 @@ namespace Client {
         {
             for (;;) {
                 ServerMessage message;
-                ssize_t nread = recv(m_connection.fd(), &message, sizeof(ServerMessage), MSG_DONTWAIT);
+                ssize_t nread = recv(m_connection->fd(), &message, sizeof(ServerMessage), MSG_DONTWAIT);
                 if (nread < 0) {
                     if (errno == EAGAIN) {
                         return true;
@@ -214,7 +215,7 @@ namespace Client {
                 ByteBuffer extra_data;
                 if (message.extra_size) {
                     extra_data = ByteBuffer::create_uninitialized(message.extra_size);
-                    int extra_nread = read(m_connection.fd(), extra_data.data(), extra_data.size());
+                    int extra_nread = read(m_connection->fd(), extra_data.data(), extra_data.size());
                     if (extra_nread < 0) {
                         perror("read");
                         ASSERT_NOT_REACHED();
@@ -228,7 +229,7 @@ namespace Client {
             }
         }
 
-        CLocalSocket m_connection;
+        ObjectPtr<CLocalSocket> m_connection;
         ObjectPtr<CNotifier> m_notifier;
         Vector<IncomingMessageBundle> m_unprocessed_bundles;
         int m_server_pid { -1 };
@@ -239,19 +240,19 @@ namespace Client {
     class ConnectionNG : public CObject {
     public:
         ConnectionNG(const StringView& address)
-            : m_connection(this)
-            , m_notifier(CNotifier::create(m_connection.fd(), CNotifier::Read, this))
+            : m_connection(CLocalSocket::construct(this))
+            , m_notifier(CNotifier::create(m_connection->fd(), CNotifier::Read, this))
         {
             // We want to rate-limit our clients
-            m_connection.set_blocking(true);
+            m_connection->set_blocking(true);
             m_notifier->on_ready_to_read = [this] {
                 drain_messages_from_server();
-                CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
+                CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection->fd()));
             };
 
             int retries = 1000;
             while (retries) {
-                if (m_connection.connect(CSocketAddress::local(address))) {
+                if (m_connection->connect(CSocketAddress::local(address))) {
                     break;
                 }
 
@@ -259,7 +260,7 @@ namespace Client {
                 sleep(1);
                 --retries;
             }
-            ASSERT(m_connection.is_connected());
+            ASSERT(m_connection->is_connected());
         }
 
         virtual void handshake() = 0;
@@ -287,20 +288,20 @@ namespace Client {
                 if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) {
                     auto message = move(m_unprocessed_messages[i]);
                     m_unprocessed_messages.remove(i);
-                    CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
+                    CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection->fd()));
                     return message;
                 }
             }
             for (;;) {
                 fd_set rfds;
                 FD_ZERO(&rfds);
-                FD_SET(m_connection.fd(), &rfds);
-                int rc = CSyscallUtils::safe_syscall(select, m_connection.fd() + 1, &rfds, nullptr, nullptr, nullptr);
+                FD_SET(m_connection->fd(), &rfds);
+                int rc = CSyscallUtils::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr);
                 if (rc < 0) {
                     perror("select");
                 }
                 ASSERT(rc > 0);
-                ASSERT(FD_ISSET(m_connection.fd(), &rfds));
+                ASSERT(FD_ISSET(m_connection->fd(), &rfds));
                 bool success = drain_messages_from_server();
                 if (!success)
                     return nullptr;
@@ -308,7 +309,7 @@ namespace Client {
                     if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) {
                         auto message = move(m_unprocessed_messages[i]);
                         m_unprocessed_messages.remove(i);
-                        CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
+                        CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection->fd()));
                         return message;
                     }
                 }
@@ -318,7 +319,7 @@ namespace Client {
         bool post_message_to_server(const IMessage& message)
         {
             auto buffer = message.encode();
-            int nwritten = write(m_connection.fd(), buffer.data(), (size_t)buffer.size());
+            int nwritten = write(m_connection->fd(), buffer.data(), (size_t)buffer.size());
             if (nwritten < 0) {
                 perror("write");
                 ASSERT_NOT_REACHED();
@@ -349,7 +350,7 @@ namespace Client {
         {
             for (;;) {
                 u8 buffer[4096];
-                ssize_t nread = recv(m_connection.fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
+                ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
                 if (nread < 0) {
                     if (errno == EAGAIN) {
                         return true;
@@ -371,7 +372,7 @@ namespace Client {
             }
         }
 
-        CLocalSocket m_connection;
+        ObjectPtr<CLocalSocket> m_connection;
         ObjectPtr<CNotifier> m_notifier;
         Vector<OwnPtr<IMessage>> m_unprocessed_messages;
         int m_server_pid { -1 };

+ 12 - 12
Libraries/LibCore/CoreIPCServer.h

@@ -70,7 +70,7 @@ namespace Server {
             , m_client_id(client_id)
         {
             add_child(socket);
-            m_socket.on_ready_to_read = [this] { drain_client(); };
+            m_socket->on_ready_to_read = [this] { drain_client(); };
 #if defined(CIPC_DEBUG)
             dbg() << "S: Created new Connection " << fd << client_id << " and said hello";
 #endif
@@ -79,7 +79,7 @@ namespace Server {
         ~Connection()
         {
 #if defined(CIPC_DEBUG)
-            dbg() << "S: Destroyed Connection " << m_socket.fd() << client_id();
+            dbg() << "S: Destroyed Connection " << m_socket->fd() << client_id();
 #endif
         }
 
@@ -103,7 +103,7 @@ namespace Server {
                 ++iov_count;
             }
 
-            int nwritten = writev(m_socket.fd(), iov, iov_count);
+            int nwritten = writev(m_socket->fd(), iov, iov_count);
             if (nwritten < 0) {
                 switch (errno) {
                 case EPIPE:
@@ -131,7 +131,7 @@ namespace Server {
             for (;;) {
                 ClientMessage message;
                 // FIXME: Don't go one message at a time, that's so much context switching, oof.
-                ssize_t nread = recv(m_socket.fd(), &message, sizeof(ClientMessage), MSG_DONTWAIT);
+                ssize_t nread = recv(m_socket->fd(), &message, sizeof(ClientMessage), MSG_DONTWAIT);
                 if (nread == 0 || (nread == -1 && errno == EAGAIN)) {
                     if (!messages_received) {
                         // TODO: is delete_later() sufficient?
@@ -151,7 +151,7 @@ namespace Server {
                     }
                     extra_data = ByteBuffer::create_uninitialized(message.extra_size);
                     // FIXME: We should allow this to time out. Maybe use a socket timeout?
-                    int extra_nread = read(m_socket.fd(), extra_data.data(), extra_data.size());
+                    int extra_nread = read(m_socket->fd(), extra_data.data(), extra_data.size());
                     if (extra_nread != (int)message.extra_size) {
                         dbgprintf("extra_nread(%d) != extra_size(%d)\n", extra_nread, extra_data.size());
                         if (extra_nread < 0)
@@ -171,7 +171,7 @@ namespace Server {
         void did_misbehave()
         {
             dbgprintf("Connection{%p} (id=%d, pid=%d) misbehaved, disconnecting.\n", this, client_id(), m_client_pid);
-            m_socket.close();
+            m_socket->close();
             delete_later();
         }
 
@@ -198,7 +198,7 @@ namespace Server {
         virtual bool handle_message(const ClientMessage&, const ByteBuffer&& = {}) = 0;
 
     private:
-        CLocalSocket& m_socket;
+        ObjectPtr<CLocalSocket> m_socket;
         int m_client_id { -1 };
         int m_client_pid { -1 };
     };
@@ -212,7 +212,7 @@ namespace Server {
             , m_client_id(client_id)
         {
             add_child(socket);
-            m_socket.on_ready_to_read = [this] { drain_client(); };
+            m_socket->on_ready_to_read = [this] { drain_client(); };
         }
 
         virtual ~ConnectionNG() override
@@ -223,7 +223,7 @@ namespace Server {
         {
             auto buffer = message.encode();
 
-            int nwritten = write(m_socket.fd(), buffer.data(), (size_t)buffer.size());
+            int nwritten = write(m_socket->fd(), buffer.data(), (size_t)buffer.size());
             if (nwritten < 0) {
                 switch (errno) {
                 case EPIPE:
@@ -249,7 +249,7 @@ namespace Server {
             unsigned messages_received = 0;
             for (;;) {
                 u8 buffer[4096];
-                ssize_t nread = recv(m_socket.fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
+                ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
                 if (nread == 0 || (nread == -1 && errno == EAGAIN)) {
                     if (!messages_received) {
                         // TODO: is delete_later() sufficient?
@@ -278,7 +278,7 @@ namespace Server {
         void did_misbehave()
         {
             dbg() << "Connection{" << this << "} (id=" << m_client_id << ", pid=" << m_client_pid << ") misbehaved, disconnecting.";
-            m_socket.close();
+            m_socket->close();
             delete_later();
         }
 
@@ -301,7 +301,7 @@ namespace Server {
 
     private:
         Endpoint& m_endpoint;
-        CLocalSocket& m_socket;
+        ObjectPtr<CLocalSocket> m_socket;
         int m_client_id { -1 };
         int m_client_pid { -1 };
     };

+ 1 - 0
Libraries/LibCore/ObjectPtr.h

@@ -9,6 +9,7 @@ class ObjectPtr {
 public:
     ObjectPtr() {}
     ObjectPtr(T* ptr) : m_ptr(ptr) {}
+    ObjectPtr(T& ptr) : m_ptr(&ptr) {}
     ~ObjectPtr()
     {
         if (m_ptr && !m_ptr->parent())

+ 1 - 1
Servers/AudioServer/ASEventLoop.cpp

@@ -9,7 +9,7 @@ ASEventLoop::ASEventLoop()
     unlink("/tmp/asportal");
     m_server_sock.listen("/tmp/asportal");
     m_server_sock.on_ready_to_accept = [this] {
-        auto* client_socket = m_server_sock.accept();
+        auto client_socket = m_server_sock.accept();
         if (!client_socket) {
             dbg() << "AudioServer: accept failed.";
             return;

+ 1 - 1
Servers/WindowServer/WSEventLoop.cpp

@@ -30,7 +30,7 @@ WSEventLoop::WSEventLoop()
     m_server_sock.listen("/tmp/wsportal");
 
     m_server_sock.on_ready_to_accept = [this] {
-        auto* client_socket = m_server_sock.accept();
+        auto client_socket = m_server_sock.accept();
         if (!client_socket) {
             dbg() << "WindowServer: accept failed.";
             return;

+ 8 - 8
Userland/rpcdump.cpp

@@ -16,27 +16,27 @@ int main(int argc, char** argv)
 
     int pid = atoi(argv[1]);
 
-    CLocalSocket socket;
+    auto socket = CLocalSocket::construct();
 
-    socket.on_connected = [&] {
+    socket->on_connected = [&] {
         dbg() << "Connected to PID " << pid;
 
         JsonObject request;
         request.set("type", "GetAllObjects");
         auto serialized = request.to_string();
         i32 length = serialized.length();
-        socket.write((const u8*)&length, sizeof(length));
-        socket.write(serialized);
+        socket->write((const u8*)&length, sizeof(length));
+        socket->write(serialized);
     };
 
-    socket.on_ready_to_read = [&] {
-        if (socket.eof()) {
+    socket->on_ready_to_read = [&] {
+        if (socket->eof()) {
             dbg() << "Disconnected from PID " << pid;
             loop.quit(0);
             return;
         }
 
-        auto data = socket.read_all();
+        auto data = socket->read_all();
 
         for (int i = 0; i < data.size(); ++i)
             putchar(data[i]);
@@ -45,7 +45,7 @@ int main(int argc, char** argv)
         loop.quit(0);
     };
 
-    auto success = socket.connect(CSocketAddress::local(String::format("/tmp/rpc.%d", pid)));
+    auto success = socket->connect(CSocketAddress::local(String::format("/tmp/rpc.%d", pid)));
     if (!success) {
         fprintf(stderr, "Couldn't connect to PID %d\n", pid);
         return 1;