CoreIPCClient.h 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. #pragma once
  2. #include <LibAudio/ASAPI.h>
  3. #include <LibCore/CEvent.h>
  4. #include <LibCore/CEventLoop.h>
  5. #include <LibCore/CLocalSocket.h>
  6. #include <LibCore/CNotifier.h>
  7. #include <LibCore/CSyscallUtils.h>
  8. #include <stdio.h>
  9. #include <sys/select.h>
  10. #include <sys/socket.h>
  11. #include <sys/types.h>
  12. #include <sys/uio.h>
  13. #include <unistd.h>
  14. //#define CIPC_DEBUG
  15. namespace IPC {
  16. namespace Client {
  17. class Event : public CEvent {
  18. public:
  19. enum Type {
  20. Invalid = 2000,
  21. PostProcess,
  22. };
  23. Event() {}
  24. explicit Event(Type type)
  25. : CEvent(type)
  26. {
  27. }
  28. };
  29. class PostProcessEvent : public Event {
  30. public:
  31. explicit PostProcessEvent(int client_id)
  32. : Event(PostProcess)
  33. , m_client_id(client_id)
  34. {
  35. }
  36. int client_id() const { return m_client_id; }
  37. private:
  38. int m_client_id { 0 };
  39. };
  40. template<typename ServerMessage, typename ClientMessage>
  41. class Connection : public CObject {
  42. C_OBJECT(Connection)
  43. public:
  44. Connection(const StringView& address)
  45. : m_notifier(CNotifier(m_connection.fd(), CNotifier::Read))
  46. {
  47. // We want to rate-limit our clients
  48. m_connection.set_blocking(true);
  49. m_notifier.on_ready_to_read = [this] {
  50. drain_messages_from_server();
  51. CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
  52. };
  53. int retries = 1000;
  54. while (retries) {
  55. if (m_connection.connect(CSocketAddress::local(address))) {
  56. break;
  57. }
  58. dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno));
  59. sleep(1);
  60. --retries;
  61. }
  62. ASSERT(m_connection.is_connected());
  63. }
  64. virtual void handshake() = 0;
  65. virtual void event(CEvent& event) override
  66. {
  67. if (event.type() == Event::PostProcess) {
  68. postprocess_bundles(m_unprocessed_bundles);
  69. } else {
  70. CObject::event(event);
  71. }
  72. }
  73. void set_server_pid(pid_t pid) { m_server_pid = pid; }
  74. pid_t server_pid() const { return m_server_pid; }
  75. void set_my_client_id(int id) { m_my_client_id = id; }
  76. int my_client_id() const { return m_my_client_id; }
  77. template<typename MessageType>
  78. bool wait_for_specific_event(MessageType type, ServerMessage& event)
  79. {
  80. // Double check we don't already have the event waiting for us.
  81. // Otherwise we might end up blocked for a while for no reason.
  82. for (ssize_t i = 0; i < m_unprocessed_bundles.size(); ++i) {
  83. if (m_unprocessed_bundles[i].message.type == type) {
  84. event = move(m_unprocessed_bundles[i].message);
  85. m_unprocessed_bundles.remove(i);
  86. CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
  87. return true;
  88. }
  89. }
  90. for (;;) {
  91. fd_set rfds;
  92. FD_ZERO(&rfds);
  93. FD_SET(m_connection.fd(), &rfds);
  94. int rc = CSyscallUtils::safe_syscall(select, m_connection.fd() + 1, &rfds, nullptr, nullptr, nullptr);
  95. if (rc < 0) {
  96. perror("select");
  97. }
  98. ASSERT(rc > 0);
  99. ASSERT(FD_ISSET(m_connection.fd(), &rfds));
  100. bool success = drain_messages_from_server();
  101. if (!success)
  102. return false;
  103. for (ssize_t i = 0; i < m_unprocessed_bundles.size(); ++i) {
  104. if (m_unprocessed_bundles[i].message.type == type) {
  105. event = move(m_unprocessed_bundles[i].message);
  106. m_unprocessed_bundles.remove(i);
  107. CEventLoop::current().post_event(*this, make<PostProcessEvent>(m_connection.fd()));
  108. return true;
  109. }
  110. }
  111. }
  112. }
  113. bool post_message_to_server(const ClientMessage& message, const ByteBuffer&& extra_data = {})
  114. {
  115. #if defined(CIPC_DEBUG)
  116. dbg() << "C: -> S " << int(message.type) << " extra " << extra_data.size();
  117. #endif
  118. if (!extra_data.is_empty())
  119. const_cast<ClientMessage&>(message).extra_size = extra_data.size();
  120. struct iovec iov[2];
  121. int iov_count = 1;
  122. iov[0].iov_base = const_cast<ClientMessage*>(&message);
  123. iov[0].iov_len = sizeof(message);
  124. if (!extra_data.is_empty()) {
  125. iov[1].iov_base = const_cast<u8*>(extra_data.data());
  126. iov[1].iov_len = extra_data.size();
  127. ++iov_count;
  128. }
  129. int nwritten = writev(m_connection.fd(), iov, iov_count);
  130. if (nwritten < 0) {
  131. perror("writev");
  132. ASSERT_NOT_REACHED();
  133. }
  134. ASSERT((size_t)nwritten == sizeof(message) + extra_data.size());
  135. return true;
  136. }
  137. template<typename MessageType>
  138. ServerMessage sync_request(const ClientMessage& request, MessageType response_type)
  139. {
  140. bool success = post_message_to_server(request);
  141. ASSERT(success);
  142. ServerMessage response;
  143. success = wait_for_specific_event(response_type, response);
  144. ASSERT(success);
  145. return response;
  146. }
  147. template<typename RequestType, typename... Args>
  148. typename RequestType::ResponseType send_sync(Args&&... args)
  149. {
  150. bool success = post_message_to_server(RequestType(forward<Args>(args)...));
  151. ASSERT(success);
  152. ServerMessage response;
  153. success = wait_for_specific_event(RequestType::ResponseType::message_type(), response);
  154. ASSERT(success);
  155. return response;
  156. }
  157. protected:
  158. struct IncomingMessageBundle {
  159. ServerMessage message;
  160. ByteBuffer extra_data;
  161. };
  162. virtual void postprocess_bundles(Vector<IncomingMessageBundle>& new_bundles)
  163. {
  164. dbg() << "Client::Connection: "
  165. << " warning: discarding " << new_bundles.size() << " unprocessed bundles; this may not be what you want";
  166. new_bundles.clear();
  167. }
  168. private:
  169. bool drain_messages_from_server()
  170. {
  171. for (;;) {
  172. ServerMessage message;
  173. ssize_t nread = recv(m_connection.fd(), &message, sizeof(ServerMessage), MSG_DONTWAIT);
  174. if (nread < 0) {
  175. if (errno == EAGAIN) {
  176. return true;
  177. }
  178. perror("read");
  179. exit(1);
  180. return false;
  181. }
  182. if (nread == 0) {
  183. dbgprintf("EOF on IPC fd\n");
  184. exit(1);
  185. return false;
  186. }
  187. ASSERT(nread == sizeof(message));
  188. ByteBuffer extra_data;
  189. if (message.extra_size) {
  190. extra_data = ByteBuffer::create_uninitialized(message.extra_size);
  191. int extra_nread = read(m_connection.fd(), extra_data.data(), extra_data.size());
  192. if (extra_nread < 0) {
  193. perror("read");
  194. ASSERT_NOT_REACHED();
  195. }
  196. ASSERT((size_t)extra_nread == message.extra_size);
  197. }
  198. #if defined(CIPC_DEBUG)
  199. dbg() << "C: <- S " << int(message.type) << " extra " << extra_data.size();
  200. #endif
  201. m_unprocessed_bundles.append({ move(message), move(extra_data) });
  202. }
  203. }
  204. CLocalSocket m_connection;
  205. CNotifier m_notifier;
  206. Vector<IncomingMessageBundle> m_unprocessed_bundles;
  207. int m_server_pid;
  208. int m_my_client_id;
  209. };
  210. } // Client
  211. } // IPC