SQLClient.cpp 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. * Copyright (c) 2022, the SerenityOS developers.
  4. *
  5. * SPDX-License-Identifier: BSD-2-Clause
  6. */
  7. #include <AK/ByteString.h>
  8. #include <AK/ScopeGuard.h>
  9. #include <AK/String.h>
  10. #include <LibSQL/SQLClient.h>
  11. #if !defined(AK_OS_SERENITY)
  12. # include <LibCore/Directory.h>
  13. # include <LibCore/Environment.h>
  14. # include <LibCore/SocketAddress.h>
  15. # include <LibCore/StandardPaths.h>
  16. # include <LibCore/System.h>
  17. # include <LibFileSystem/FileSystem.h>
  18. # include <signal.h>
  19. #endif
  20. namespace SQL {
  21. #if !defined(AK_OS_SERENITY)
  22. // This is heavily based on how SystemServer's Service creates its socket.
  23. static ErrorOr<int> create_database_socket(ByteString const& socket_path)
  24. {
  25. if (FileSystem::exists(socket_path))
  26. TRY(Core::System::unlink(socket_path));
  27. # ifdef SOCK_NONBLOCK
  28. auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
  29. # else
  30. auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM, 0));
  31. int option = 1;
  32. TRY(Core::System::ioctl(socket_fd, FIONBIO, &option));
  33. TRY(Core::System::fcntl(socket_fd, F_SETFD, FD_CLOEXEC));
  34. # endif
  35. # if !defined(AK_OS_BSD_GENERIC) && !defined(AK_OS_GNU_HURD)
  36. TRY(Core::System::fchmod(socket_fd, 0600));
  37. # endif
  38. auto socket_address = Core::SocketAddress::local(socket_path);
  39. auto socket_address_un = socket_address.to_sockaddr_un().release_value();
  40. TRY(Core::System::bind(socket_fd, reinterpret_cast<sockaddr*>(&socket_address_un), sizeof(socket_address_un)));
  41. TRY(Core::System::listen(socket_fd, 16));
  42. return socket_fd;
  43. }
  44. static ErrorOr<void> launch_server(ByteString const& socket_path, ByteString const& pid_path, Vector<ByteString> candidate_server_paths)
  45. {
  46. auto server_fd_or_error = create_database_socket(socket_path);
  47. if (server_fd_or_error.is_error()) {
  48. warnln("Failed to create a database socket at {}: {}", socket_path, server_fd_or_error.error());
  49. return server_fd_or_error.release_error();
  50. }
  51. auto server_fd = server_fd_or_error.value();
  52. sigset_t original_set;
  53. sigset_t setting_set;
  54. sigfillset(&setting_set);
  55. (void)pthread_sigmask(SIG_BLOCK, &setting_set, &original_set);
  56. auto server_pid = TRY(Core::System::fork());
  57. if (server_pid == 0) {
  58. (void)pthread_sigmask(SIG_SETMASK, &original_set, nullptr);
  59. TRY(Core::System::setsid());
  60. TRY(Core::System::signal(SIGCHLD, SIG_IGN));
  61. server_pid = TRY(Core::System::fork());
  62. if (server_pid != 0) {
  63. auto server_pid_file = TRY(Core::File::open(pid_path, Core::File::OpenMode::Write));
  64. TRY(server_pid_file->write_until_depleted(ByteString::number(server_pid).bytes()));
  65. TRY(Core::System::kill(getpid(), SIGTERM));
  66. }
  67. server_fd = TRY(Core::System::dup(server_fd));
  68. auto takeover_string = ByteString::formatted("SQLServer:{}", server_fd);
  69. TRY(Core::Environment::set("SOCKET_TAKEOVER"sv, takeover_string, Core::Environment::Overwrite::Yes));
  70. ErrorOr<void> result;
  71. for (auto const& server_path : candidate_server_paths) {
  72. auto arguments = Array {
  73. server_path.view(),
  74. "--pid-file"sv,
  75. pid_path,
  76. };
  77. result = Core::System::exec(arguments[0], arguments, Core::System::SearchInPath::Yes);
  78. if (!result.is_error())
  79. break;
  80. }
  81. if (result.is_error()) {
  82. warnln("Could not launch any of {}: {}", candidate_server_paths, result.error());
  83. TRY(Core::System::unlink(pid_path));
  84. }
  85. VERIFY_NOT_REACHED();
  86. }
  87. VERIFY(server_pid > 0);
  88. auto wait_err = Core::System::waitpid(server_pid);
  89. (void)pthread_sigmask(SIG_SETMASK, &original_set, nullptr);
  90. if (wait_err.is_error())
  91. return wait_err.release_error();
  92. return {};
  93. }
  94. static ErrorOr<bool> should_launch_server(ByteString const& pid_path)
  95. {
  96. if (!FileSystem::exists(pid_path))
  97. return true;
  98. Optional<pid_t> pid;
  99. {
  100. auto server_pid_file = Core::File::open(pid_path, Core::File::OpenMode::Read);
  101. if (server_pid_file.is_error()) {
  102. warnln("Could not open SQLServer PID file '{}': {}", pid_path, server_pid_file.error());
  103. return server_pid_file.release_error();
  104. }
  105. auto contents = server_pid_file.value()->read_until_eof();
  106. if (contents.is_error()) {
  107. warnln("Could not read SQLServer PID file '{}': {}", pid_path, contents.error());
  108. return contents.release_error();
  109. }
  110. pid = StringView { contents.value() }.to_number<pid_t>();
  111. }
  112. if (!pid.has_value()) {
  113. warnln("SQLServer PID file '{}' exists, but with an invalid PID", pid_path);
  114. TRY(Core::System::unlink(pid_path));
  115. return true;
  116. }
  117. if (kill(*pid, 0) < 0) {
  118. warnln("SQLServer PID file '{}' exists with PID {}, but process cannot be found", pid_path, *pid);
  119. TRY(Core::System::unlink(pid_path));
  120. return true;
  121. }
  122. return false;
  123. }
  124. ErrorOr<NonnullRefPtr<SQLClient>> SQLClient::launch_server_and_create_client(Vector<ByteString> candidate_server_paths)
  125. {
  126. auto runtime_directory = TRY(Core::StandardPaths::runtime_directory());
  127. auto socket_path = ByteString::formatted("{}/SQLServer.socket", runtime_directory);
  128. auto pid_path = ByteString::formatted("{}/SQLServer.pid", runtime_directory);
  129. if (TRY(should_launch_server(pid_path)))
  130. TRY(launch_server(socket_path, pid_path, move(candidate_server_paths)));
  131. auto socket = TRY(Core::LocalSocket::connect(move(socket_path)));
  132. TRY(socket->set_blocking(true));
  133. return adopt_nonnull_ref_or_enomem(new (nothrow) SQLClient(move(socket)));
  134. }
  135. #endif
  136. void SQLClient::execution_success(u64 statement_id, u64 execution_id, Vector<ByteString> const& column_names, bool has_results, size_t created, size_t updated, size_t deleted)
  137. {
  138. if (!on_execution_success) {
  139. outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted);
  140. return;
  141. }
  142. ExecutionSuccess success {
  143. .statement_id = statement_id,
  144. .execution_id = execution_id,
  145. .column_names = move(const_cast<Vector<ByteString>&>(column_names)),
  146. .has_results = has_results,
  147. .rows_created = created,
  148. .rows_updated = updated,
  149. .rows_deleted = deleted,
  150. };
  151. on_execution_success(move(success));
  152. }
  153. void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, ByteString const& message)
  154. {
  155. if (!on_execution_error) {
  156. warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code));
  157. return;
  158. }
  159. ExecutionError error {
  160. .statement_id = statement_id,
  161. .execution_id = execution_id,
  162. .error_code = code,
  163. .error_message = move(const_cast<ByteString&>(message)),
  164. };
  165. on_execution_error(move(error));
  166. }
  167. void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector<Value> const& row)
  168. {
  169. ScopeGuard guard { [&]() { async_ready_for_next_result(statement_id, execution_id); } };
  170. if (!on_next_result) {
  171. StringBuilder builder;
  172. builder.join(", "sv, row, "\"{}\""sv);
  173. outln("{}", builder.string_view());
  174. return;
  175. }
  176. ExecutionResult result {
  177. .statement_id = statement_id,
  178. .execution_id = execution_id,
  179. .values = move(const_cast<Vector<Value>&>(row)),
  180. };
  181. on_next_result(move(result));
  182. }
  183. void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows)
  184. {
  185. if (!on_results_exhausted) {
  186. outln("{} total row(s)", total_rows);
  187. return;
  188. }
  189. ExecutionComplete success {
  190. .statement_id = statement_id,
  191. .execution_id = execution_id,
  192. .total_rows = total_rows,
  193. };
  194. on_results_exhausted(move(success));
  195. }
  196. }