CoreIPCClient.h 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. #pragma once
  2. #include <LibCore/CEvent.h>
  3. #include <LibCore/CEventLoop.h>
  4. #include <LibCore/CLocalSocket.h>
  5. #include <LibCore/CNotifier.h>
  6. #include <LibCore/CSyscallUtils.h>
  7. #include <LibIPC/IMessage.h>
  8. #include <sched.h>
  9. #include <stdio.h>
  10. #include <stdlib.h>
  11. #include <sys/select.h>
  12. #include <sys/socket.h>
  13. #include <sys/types.h>
  14. #include <sys/uio.h>
  15. #include <unistd.h>
  16. //#define CIPC_DEBUG
  17. namespace IPC {
  18. namespace Client {
  19. template<typename LocalEndpoint, typename PeerEndpoint>
  20. class ConnectionNG : public CObject {
  21. public:
  22. ConnectionNG(LocalEndpoint& local_endpoint, const StringView& address)
  23. : m_local_endpoint(local_endpoint)
  24. , m_connection(CLocalSocket::construct(this))
  25. , m_notifier(CNotifier::construct(m_connection->fd(), CNotifier::Read, this))
  26. {
  27. // We want to rate-limit our clients
  28. m_connection->set_blocking(true);
  29. m_notifier->on_ready_to_read = [this] {
  30. drain_messages_from_server();
  31. };
  32. int retries = 100000;
  33. while (retries) {
  34. if (m_connection->connect(CSocketAddress::local(address))) {
  35. break;
  36. }
  37. dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno));
  38. usleep(10000);
  39. --retries;
  40. }
  41. ASSERT(m_connection->is_connected());
  42. }
  43. virtual void handshake() = 0;
  44. void set_server_pid(pid_t pid) { m_server_pid = pid; }
  45. pid_t server_pid() const { return m_server_pid; }
  46. void set_my_client_id(int id) { m_my_client_id = id; }
  47. int my_client_id() const { return m_my_client_id; }
  48. template<typename MessageType>
  49. OwnPtr<MessageType> wait_for_specific_message()
  50. {
  51. // Double check we don't already have the event waiting for us.
  52. // Otherwise we might end up blocked for a while for no reason.
  53. for (ssize_t i = 0; i < m_unprocessed_messages.size(); ++i) {
  54. if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) {
  55. auto message = move(m_unprocessed_messages[i]);
  56. m_unprocessed_messages.remove(i);
  57. return message;
  58. }
  59. }
  60. for (;;) {
  61. fd_set rfds;
  62. FD_ZERO(&rfds);
  63. FD_SET(m_connection->fd(), &rfds);
  64. int rc = CSyscallUtils::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr);
  65. if (rc < 0) {
  66. perror("select");
  67. }
  68. ASSERT(rc > 0);
  69. ASSERT(FD_ISSET(m_connection->fd(), &rfds));
  70. if (!drain_messages_from_server())
  71. return nullptr;
  72. for (ssize_t i = 0; i < m_unprocessed_messages.size(); ++i) {
  73. if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) {
  74. auto message = move(m_unprocessed_messages[i]);
  75. m_unprocessed_messages.remove(i);
  76. return message;
  77. }
  78. }
  79. }
  80. }
  81. bool post_message(const IMessage& message)
  82. {
  83. auto buffer = message.encode();
  84. int nwritten = write(m_connection->fd(), buffer.data(), (size_t)buffer.size());
  85. if (nwritten < 0) {
  86. perror("write");
  87. ASSERT_NOT_REACHED();
  88. return false;
  89. }
  90. ASSERT(nwritten == buffer.size());
  91. return true;
  92. }
  93. template<typename RequestType, typename... Args>
  94. OwnPtr<typename RequestType::ResponseType> send_sync(Args&&... args)
  95. {
  96. bool success = post_message(RequestType(forward<Args>(args)...));
  97. ASSERT(success);
  98. auto response = wait_for_specific_message<typename RequestType::ResponseType>();
  99. ASSERT(response);
  100. return response;
  101. }
  102. private:
  103. bool drain_messages_from_server()
  104. {
  105. Vector<u8> bytes;
  106. for (;;) {
  107. u8 buffer[4096];
  108. ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
  109. if (nread < 0) {
  110. if (errno == EAGAIN)
  111. break;
  112. perror("read");
  113. exit(1);
  114. return false;
  115. }
  116. if (nread == 0) {
  117. dbg() << "EOF on IPC fd";
  118. // FIXME: Dying is definitely not always appropriate!
  119. exit(1);
  120. return false;
  121. }
  122. bytes.append(buffer, nread);
  123. }
  124. size_t decoded_bytes = 0;
  125. for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) {
  126. auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
  127. if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
  128. m_local_endpoint.handle(*message);
  129. } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
  130. m_unprocessed_messages.append(move(message));
  131. } else {
  132. ASSERT_NOT_REACHED();
  133. }
  134. ASSERT(decoded_bytes);
  135. }
  136. return true;
  137. }
  138. LocalEndpoint& m_local_endpoint;
  139. RefPtr<CLocalSocket> m_connection;
  140. RefPtr<CNotifier> m_notifier;
  141. Vector<OwnPtr<IMessage>> m_unprocessed_messages;
  142. int m_server_pid { -1 };
  143. int m_my_client_id { -1 };
  144. };
  145. } // Client
  146. } // IPC