diff --git a/Libraries/LibCore/CoreIPCClient.h b/Libraries/LibCore/CoreIPCClient.h index 490cf485033..cf23731bc8c 100644 --- a/Libraries/LibCore/CoreIPCClient.h +++ b/Libraries/LibCore/CoreIPCClient.h @@ -1,122 +1,99 @@ #pragma once -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include -#include #include #include -#include #include +#include #include +#include //#define CIPC_DEBUG -namespace IPC -{ +namespace IPC { namespace Client { -class Event : public CEvent { -public: - enum Type { - Invalid = 2000, - PostProcess, - }; - Event() {} - explicit Event(Type type) - : CEvent(type) - { - } -}; - -class PostProcessEvent : public Event { -public: - explicit PostProcessEvent(int client_id) - : Event(PostProcess) - , m_client_id(client_id) - { - } - - int client_id() const { return m_client_id; } - -private: - int m_client_id { 0 }; -}; - -template -class Connection : public CObject { - C_OBJECT(Connection) -public: - Connection(const StringView& address) - : m_notifier(CNotifier(m_connection.fd(), CNotifier::Read)) - { - // We want to rate-limit our clients - m_connection.set_blocking(true); - m_notifier.on_ready_to_read = [this] { - drain_messages_from_server(); - CEventLoop::current().post_event(*this, make(m_connection.fd())); + class Event : public CEvent { + public: + enum Type { + Invalid = 2000, + PostProcess, }; - - int retries = 1000; - while (retries) { - if (m_connection.connect(CSocketAddress::local(address))) { - break; - } - - dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno)); - sleep(1); - --retries; + Event() {} + explicit Event(Type type) + : CEvent(type) + { } - ASSERT(m_connection.is_connected()); - } + }; - virtual void handshake() = 0; - - - virtual void event(CEvent& event) override - { - if (event.type() == Event::PostProcess) { - postprocess_bundles(m_unprocessed_bundles); - } else { - CObject::event(event); + class PostProcessEvent : public Event { + public: + explicit PostProcessEvent(int client_id) + : Event(PostProcess) + , m_client_id(client_id) + { } - } - void set_server_pid(pid_t pid) { m_server_pid = pid; } - pid_t server_pid() const { return m_server_pid; } - void set_my_client_id(int id) { m_my_client_id = id; } - int my_client_id() const { return m_my_client_id; } + int client_id() const { return m_client_id; } - template - bool wait_for_specific_event(MessageType type, ServerMessage& event) - { - // Double check we don't already have the event waiting for us. - // Otherwise we might end up blocked for a while for no reason. - for (ssize_t i = 0; i < m_unprocessed_bundles.size(); ++i) { - if (m_unprocessed_bundles[i].message.type == type) { - event = move(m_unprocessed_bundles[i].message); - m_unprocessed_bundles.remove(i); + private: + int m_client_id { 0 }; + }; + + template + class Connection : public CObject { + C_OBJECT(Connection) + public: + Connection(const StringView& address) + : m_notifier(CNotifier(m_connection.fd(), CNotifier::Read)) + { + // We want to rate-limit our clients + m_connection.set_blocking(true); + m_notifier.on_ready_to_read = [this] { + drain_messages_from_server(); CEventLoop::current().post_event(*this, make(m_connection.fd())); - return true; + }; + + int retries = 1000; + while (retries) { + if (m_connection.connect(CSocketAddress::local(address))) { + break; + } + + dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno)); + sleep(1); + --retries; + } + ASSERT(m_connection.is_connected()); + } + + virtual void handshake() = 0; + + virtual void event(CEvent& event) override + { + if (event.type() == Event::PostProcess) { + postprocess_bundles(m_unprocessed_bundles); + } else { + CObject::event(event); } } - for (;;) { - fd_set rfds; - FD_ZERO(&rfds); - FD_SET(m_connection.fd(), &rfds); - int rc = CSyscallUtils::safe_syscall(select, m_connection.fd() + 1, &rfds, nullptr, nullptr, nullptr); - if (rc < 0) { - perror("select"); - } - ASSERT(rc > 0); - ASSERT(FD_ISSET(m_connection.fd(), &rfds)); - bool success = drain_messages_from_server(); - if (!success) - return false; + + void set_server_pid(pid_t pid) { m_server_pid = pid; } + pid_t server_pid() const { return m_server_pid; } + void set_my_client_id(int id) { m_my_client_id = id; } + int my_client_id() const { return m_my_client_id; } + + template + bool wait_for_specific_event(MessageType type, ServerMessage& event) + { + // Double check we don't already have the event waiting for us. + // Otherwise we might end up blocked for a while for no reason. for (ssize_t i = 0; i < m_unprocessed_bundles.size(); ++i) { if (m_unprocessed_bundles[i].message.type == type) { event = move(m_unprocessed_bundles[i].message); @@ -125,106 +102,127 @@ public: return true; } } + for (;;) { + fd_set rfds; + FD_ZERO(&rfds); + FD_SET(m_connection.fd(), &rfds); + int rc = CSyscallUtils::safe_syscall(select, m_connection.fd() + 1, &rfds, nullptr, nullptr, nullptr); + if (rc < 0) { + perror("select"); + } + ASSERT(rc > 0); + ASSERT(FD_ISSET(m_connection.fd(), &rfds)); + bool success = drain_messages_from_server(); + if (!success) + return false; + for (ssize_t i = 0; i < m_unprocessed_bundles.size(); ++i) { + if (m_unprocessed_bundles[i].message.type == type) { + event = move(m_unprocessed_bundles[i].message); + m_unprocessed_bundles.remove(i); + CEventLoop::current().post_event(*this, make(m_connection.fd())); + return true; + } + } + } } - } - bool post_message_to_server(const ClientMessage& message, const ByteBuffer&& extra_data = {}) - { + bool post_message_to_server(const ClientMessage& message, const ByteBuffer&& extra_data = {}) + { #if defined(CIPC_DEBUG) - dbg() << "C: -> S " << int(message.type) << " extra " << extra_data.size(); + dbg() << "C: -> S " << int(message.type) << " extra " << extra_data.size(); #endif - if (!extra_data.is_empty()) - const_cast(message).extra_size = extra_data.size(); + if (!extra_data.is_empty()) + const_cast(message).extra_size = extra_data.size(); - struct iovec iov[2]; - int iov_count = 1; - iov[0].iov_base = const_cast(&message); - iov[0].iov_len = sizeof(message); + struct iovec iov[2]; + int iov_count = 1; + iov[0].iov_base = const_cast(&message); + iov[0].iov_len = sizeof(message); - if (!extra_data.is_empty()) { - iov[1].iov_base = const_cast(extra_data.data()); - iov[1].iov_len = extra_data.size(); - ++iov_count; + if (!extra_data.is_empty()) { + iov[1].iov_base = const_cast(extra_data.data()); + iov[1].iov_len = extra_data.size(); + ++iov_count; + } + + int nwritten = writev(m_connection.fd(), iov, iov_count); + if (nwritten < 0) { + perror("writev"); + ASSERT_NOT_REACHED(); + } + ASSERT((size_t)nwritten == sizeof(message) + extra_data.size()); + + return true; } - int nwritten = writev(m_connection.fd(), iov, iov_count); - if (nwritten < 0) { - perror("writev"); - ASSERT_NOT_REACHED(); + template + ServerMessage sync_request(const ClientMessage& request, MessageType response_type) + { + bool success = post_message_to_server(request); + ASSERT(success); + + ServerMessage response; + success = wait_for_specific_event(response_type, response); + ASSERT(success); + return response; } - ASSERT((size_t)nwritten == sizeof(message) + extra_data.size()); - return true; - } - - template - ServerMessage sync_request(const ClientMessage& request, MessageType response_type) - { - bool success = post_message_to_server(request); - ASSERT(success); - - ServerMessage response; - success = wait_for_specific_event(response_type, response); - ASSERT(success); - return response; - } - -protected: - struct IncomingMessageBundle { - ServerMessage message; - ByteBuffer extra_data; - }; - - virtual void postprocess_bundles(Vector& new_bundles) - { - dbg() << "Client::Connection: " << " warning: discarding " << new_bundles.size() << " unprocessed bundles; this may not be what you want"; - new_bundles.clear(); - } - -private: - bool drain_messages_from_server() - { - for (;;) { + protected: + struct IncomingMessageBundle { ServerMessage message; - ssize_t nread = recv(m_connection.fd(), &message, sizeof(ServerMessage), MSG_DONTWAIT); - if (nread < 0) { - if (errno == EAGAIN) { - return true; - } - perror("read"); - exit(1); - return false; - } - if (nread == 0) { - dbgprintf("EOF on IPC fd\n"); - exit(1); - return false; - } - ASSERT(nread == sizeof(message)); ByteBuffer extra_data; - if (message.extra_size) { - extra_data = ByteBuffer::create_uninitialized(message.extra_size); - int extra_nread = read(m_connection.fd(), extra_data.data(), extra_data.size()); - if (extra_nread < 0) { - perror("read"); - ASSERT_NOT_REACHED(); - } - ASSERT((size_t)extra_nread == message.extra_size); - } -#if defined(CIPC_DEBUG) - dbg() << "C: <- S " << int(message.type) << " extra " << extra_data.size(); -#endif - m_unprocessed_bundles.append({ move(message), move(extra_data) }); - } - } + }; - CLocalSocket m_connection; - CNotifier m_notifier; - Vector m_unprocessed_bundles; - int m_server_pid; - int m_my_client_id; -}; + virtual void postprocess_bundles(Vector& new_bundles) + { + dbg() << "Client::Connection: " + << " warning: discarding " << new_bundles.size() << " unprocessed bundles; this may not be what you want"; + new_bundles.clear(); + } + + private: + bool drain_messages_from_server() + { + for (;;) { + ServerMessage message; + ssize_t nread = recv(m_connection.fd(), &message, sizeof(ServerMessage), MSG_DONTWAIT); + if (nread < 0) { + if (errno == EAGAIN) { + return true; + } + perror("read"); + exit(1); + return false; + } + if (nread == 0) { + dbgprintf("EOF on IPC fd\n"); + exit(1); + return false; + } + ASSERT(nread == sizeof(message)); + ByteBuffer extra_data; + if (message.extra_size) { + extra_data = ByteBuffer::create_uninitialized(message.extra_size); + int extra_nread = read(m_connection.fd(), extra_data.data(), extra_data.size()); + if (extra_nread < 0) { + perror("read"); + ASSERT_NOT_REACHED(); + } + ASSERT((size_t)extra_nread == message.extra_size); + } +#if defined(CIPC_DEBUG) + dbg() << "C: <- S " << int(message.type) << " extra " << extra_data.size(); +#endif + m_unprocessed_bundles.append({ move(message), move(extra_data) }); + } + } + + CLocalSocket m_connection; + CNotifier m_notifier; + Vector m_unprocessed_bundles; + int m_server_pid; + int m_my_client_id; + }; } // Client } // IPC - diff --git a/Libraries/LibCore/CoreIPCServer.h b/Libraries/LibCore/CoreIPCServer.h index 9849c80329a..97db89a38a0 100644 --- a/Libraries/LibCore/CoreIPCServer.h +++ b/Libraries/LibCore/CoreIPCServer.h @@ -1,221 +1,219 @@ #pragma once -#include #include #include #include #include +#include #include -#include -#include -#include -#include #include +#include +#include +#include +#include //#define CIPC_DEBUG -namespace IPC -{ +namespace IPC { namespace Server { -class Event : public CEvent { -public: - enum Type { - Invalid = 2000, - Disconnected, - }; - Event() {} - explicit Event(Type type) - : CEvent(type) - { - } -}; - -class DisconnectedEvent : public Event { -public: - explicit DisconnectedEvent(int client_id) - : Event(Disconnected) - , m_client_id(client_id) - { - } - - int client_id() const { return m_client_id; } - -private: - int m_client_id { 0 }; -}; - -template -T* new_connection_for_client(Args&& ... args) -{ - auto conn = new T(AK::forward(args)...) /* arghs */; - conn->send_greeting(); - return conn; -}; - -template -class Connection : public CObject { - C_OBJECT(Connection) -public: - Connection(int fd, int client_id) - : m_socket(fd) - , m_notifier(CNotifier(fd, CNotifier::Read)) - , m_client_id(client_id) - { - m_notifier.on_ready_to_read = [this] { drain_client(); }; -#if defined(CIPC_DEBUG) - dbg() << "S: Created new Connection " << fd << client_id << " and said hello"; -#endif - } - - ~Connection() - { -#if defined(CIPC_DEBUG) - dbg() << "S: Destroyed Connection " << m_socket.fd() << client_id(); -#endif - } - - void post_message(const ServerMessage& message, const ByteBuffer& extra_data = {}) - { -#if defined(CIPC_DEBUG) - dbg() << "S: -> C " << int(message.type) << " extra " << extra_data.size(); -#endif - if (!extra_data.is_empty()) - const_cast(message).extra_size = extra_data.size(); - - struct iovec iov[2]; - int iov_count = 1; - - iov[0].iov_base = const_cast(&message); - iov[0].iov_len = sizeof(message); - - if (!extra_data.is_empty()) { - iov[1].iov_base = const_cast(extra_data.data()); - iov[1].iov_len = extra_data.size(); - ++iov_count; - } - - int nwritten = writev(m_socket.fd(), iov, iov_count); - if (nwritten < 0) { - switch (errno) { - case EPIPE: - dbgprintf("Connection::post_message: Disconnected from peer.\n"); - delete_later(); - return; - break; - case EAGAIN: - dbgprintf("Connection::post_message: Client buffer overflowed.\n"); - did_misbehave(); - return; - break; - default: - perror("Connection::post_message writev"); - ASSERT_NOT_REACHED(); - } - } - - ASSERT(nwritten == (int)(sizeof(message) + extra_data.size())); - } - - void drain_client() - { - unsigned messages_received = 0; - for (;;) { - ClientMessage message; - // FIXME: Don't go one message at a time, that's so much context switching, oof. - ssize_t nread = recv(m_socket.fd(), &message, sizeof(ClientMessage), MSG_DONTWAIT); - if (nread == 0 || (nread == -1 && errno == EAGAIN)) { - if (!messages_received) { - // TODO: is delete_later() sufficient? - CEventLoop::current().post_event(*this, make(client_id())); - } - break; - } - if (nread < 0) { - perror("recv"); - ASSERT_NOT_REACHED(); - } - ByteBuffer extra_data; - if (message.extra_size) { - if (message.extra_size >= 32768) { - dbgprintf("message.extra_size is way too large\n"); - return did_misbehave(); - } - extra_data = ByteBuffer::create_uninitialized(message.extra_size); - // FIXME: We should allow this to time out. Maybe use a socket timeout? - int extra_nread = read(m_socket.fd(), extra_data.data(), extra_data.size()); - if (extra_nread != (int)message.extra_size) { - dbgprintf("extra_nread(%d) != extra_size(%d)\n", extra_nread, extra_data.size()); - if (extra_nread < 0) - perror("read"); - return did_misbehave(); - } - } -#if defined(CIPC_DEBUG) - dbg() << "S: <- C " << int(message.type) << " extra " << extra_data.size(); -#endif - if (!handle_message(message, move(extra_data))) - return; - ++messages_received; - } - } - - void did_misbehave() - { - dbgprintf("Connection{%p} (id=%d, pid=%d) misbehaved, disconnecting.\n", this, client_id(), m_pid); - delete_later(); - m_notifier.set_enabled(false); - } - - int client_id() const { return m_client_id; } - pid_t client_pid() const { return m_pid; } - void set_client_pid(pid_t pid) { m_pid = pid; } - - // ### having this public is sad - virtual void send_greeting() = 0; - -protected: - void event(CEvent& event) - { - if (event.type() == Event::Disconnected) { - int client_id = static_cast(event).client_id(); - dbgprintf("Connection: Client disconnected: %d\n", client_id); - delete this; - return; - } - - CObject::event(event); - } - - virtual bool handle_message(const ClientMessage&, const ByteBuffer&& = {}) = 0; - -private: - // TODO: A way to create some kind of CIODevice with an open FD would be nice. - class COpenedSocket : public CIODevice { - C_OBJECT(COpenedSocket) + class Event : public CEvent { public: - COpenedSocket(int fd) + enum Type { + Invalid = 2000, + Disconnected, + }; + Event() {} + explicit Event(Type type) + : CEvent(type) + { + } + }; + + class DisconnectedEvent : public Event { + public: + explicit DisconnectedEvent(int client_id) + : Event(Disconnected) + , m_client_id(client_id) { - set_fd(fd); - set_mode(CIODevice::OpenMode::ReadWrite); } - bool open(CIODevice::OpenMode) override + int client_id() const { return m_client_id; } + + private: + int m_client_id { 0 }; + }; + + template + T* new_connection_for_client(Args&&... args) + { + auto conn = new T(AK::forward(args)...) /* arghs */; + conn->send_greeting(); + return conn; + }; + + template + class Connection : public CObject { + C_OBJECT(Connection) + public: + Connection(int fd, int client_id) + : m_socket(fd) + , m_notifier(CNotifier(fd, CNotifier::Read)) + , m_client_id(client_id) { - ASSERT_NOT_REACHED(); - return true; + m_notifier.on_ready_to_read = [this] { drain_client(); }; +#if defined(CIPC_DEBUG) + dbg() << "S: Created new Connection " << fd << client_id << " and said hello"; +#endif + } + + ~Connection() + { +#if defined(CIPC_DEBUG) + dbg() << "S: Destroyed Connection " << m_socket.fd() << client_id(); +#endif + } + + void post_message(const ServerMessage& message, const ByteBuffer& extra_data = {}) + { +#if defined(CIPC_DEBUG) + dbg() << "S: -> C " << int(message.type) << " extra " << extra_data.size(); +#endif + if (!extra_data.is_empty()) + const_cast(message).extra_size = extra_data.size(); + + struct iovec iov[2]; + int iov_count = 1; + + iov[0].iov_base = const_cast(&message); + iov[0].iov_len = sizeof(message); + + if (!extra_data.is_empty()) { + iov[1].iov_base = const_cast(extra_data.data()); + iov[1].iov_len = extra_data.size(); + ++iov_count; + } + + int nwritten = writev(m_socket.fd(), iov, iov_count); + if (nwritten < 0) { + switch (errno) { + case EPIPE: + dbgprintf("Connection::post_message: Disconnected from peer.\n"); + delete_later(); + return; + break; + case EAGAIN: + dbgprintf("Connection::post_message: Client buffer overflowed.\n"); + did_misbehave(); + return; + break; + default: + perror("Connection::post_message writev"); + ASSERT_NOT_REACHED(); + } + } + + ASSERT(nwritten == (int)(sizeof(message) + extra_data.size())); + } + + void drain_client() + { + unsigned messages_received = 0; + for (;;) { + ClientMessage message; + // FIXME: Don't go one message at a time, that's so much context switching, oof. + ssize_t nread = recv(m_socket.fd(), &message, sizeof(ClientMessage), MSG_DONTWAIT); + if (nread == 0 || (nread == -1 && errno == EAGAIN)) { + if (!messages_received) { + // TODO: is delete_later() sufficient? + CEventLoop::current().post_event(*this, make(client_id())); + } + break; + } + if (nread < 0) { + perror("recv"); + ASSERT_NOT_REACHED(); + } + ByteBuffer extra_data; + if (message.extra_size) { + if (message.extra_size >= 32768) { + dbgprintf("message.extra_size is way too large\n"); + return did_misbehave(); + } + extra_data = ByteBuffer::create_uninitialized(message.extra_size); + // FIXME: We should allow this to time out. Maybe use a socket timeout? + int extra_nread = read(m_socket.fd(), extra_data.data(), extra_data.size()); + if (extra_nread != (int)message.extra_size) { + dbgprintf("extra_nread(%d) != extra_size(%d)\n", extra_nread, extra_data.size()); + if (extra_nread < 0) + perror("read"); + return did_misbehave(); + } + } +#if defined(CIPC_DEBUG) + dbg() << "S: <- C " << int(message.type) << " extra " << extra_data.size(); +#endif + if (!handle_message(message, move(extra_data))) + return; + ++messages_received; + } + } + + void did_misbehave() + { + dbgprintf("Connection{%p} (id=%d, pid=%d) misbehaved, disconnecting.\n", this, client_id(), m_pid); + delete_later(); + m_notifier.set_enabled(false); + } + + int client_id() const { return m_client_id; } + pid_t client_pid() const { return m_pid; } + void set_client_pid(pid_t pid) { m_pid = pid; } + + // ### having this public is sad + virtual void send_greeting() = 0; + + protected: + void event(CEvent& event) + { + if (event.type() == Event::Disconnected) { + int client_id = static_cast(event).client_id(); + dbgprintf("Connection: Client disconnected: %d\n", client_id); + delete this; + return; + } + + CObject::event(event); + } + + virtual bool handle_message(const ClientMessage&, const ByteBuffer&& = {}) = 0; + + private: + // TODO: A way to create some kind of CIODevice with an open FD would be nice. + class COpenedSocket : public CIODevice { + C_OBJECT(COpenedSocket) + public: + COpenedSocket(int fd) + { + set_fd(fd); + set_mode(CIODevice::OpenMode::ReadWrite); + } + + bool open(CIODevice::OpenMode) override + { + ASSERT_NOT_REACHED(); + return true; + }; + + int fd() const { return CIODevice::fd(); } }; - int fd() const { return CIODevice::fd(); } + COpenedSocket m_socket; + CNotifier m_notifier; + int m_client_id; + int m_pid; }; - COpenedSocket m_socket; - CNotifier m_notifier; - int m_client_id; - int m_pid; -}; - } // Server } // IPC -