DatabaseConnection.cpp 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <AK/LexicalPath.h>
  7. #include <SQLServer/ConnectionFromClient.h>
  8. #include <SQLServer/DatabaseConnection.h>
  9. #include <SQLServer/SQLStatement.h>
  10. namespace SQLServer {
  11. static HashMap<int, NonnullRefPtr<DatabaseConnection>> s_connections;
  12. RefPtr<DatabaseConnection> DatabaseConnection::connection_for(int connection_id)
  13. {
  14. if (s_connections.contains(connection_id))
  15. return *s_connections.get(connection_id).value();
  16. dbgln_if(SQLSERVER_DEBUG, "Invalid connection_id {}", connection_id);
  17. return nullptr;
  18. }
  19. static int s_next_connection_id = 0;
  20. DatabaseConnection::DatabaseConnection(String database_name, int client_id)
  21. : Object()
  22. , m_database_name(move(database_name))
  23. , m_connection_id(s_next_connection_id++)
  24. , m_client_id(client_id)
  25. {
  26. if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) {
  27. auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
  28. client_connection->async_connection_error(m_connection_id, (int)SQL::SQLErrorCode::InvalidDatabaseName, m_database_name);
  29. return;
  30. }
  31. dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name);
  32. s_connections.set(m_connection_id, *this);
  33. deferred_invoke([this]() {
  34. m_database = SQL::Database::construct(String::formatted("/home/anon/sql/{}.db", m_database_name));
  35. auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
  36. if (auto maybe_error = m_database->open(); maybe_error.is_error()) {
  37. client_connection->async_connection_error(m_connection_id, to_underlying(maybe_error.error().error()), maybe_error.error().error_string());
  38. return;
  39. }
  40. m_accept_statements = true;
  41. if (client_connection)
  42. client_connection->async_connected(m_connection_id, m_database_name);
  43. else
  44. warnln("Cannot notify client of database connection. Client disconnected");
  45. });
  46. }
  47. void DatabaseConnection::disconnect()
  48. {
  49. dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::disconnect(connection_id {}, database '{}'", connection_id(), m_database_name);
  50. m_accept_statements = false;
  51. deferred_invoke([this]() {
  52. m_database = nullptr;
  53. s_connections.remove(m_connection_id);
  54. auto client_connection = ConnectionFromClient::client_connection_for(client_id());
  55. if (client_connection)
  56. client_connection->async_disconnected(m_connection_id);
  57. else
  58. warnln("Cannot notify client of database disconnection. Client disconnected");
  59. });
  60. }
  61. int DatabaseConnection::prepare_statement(String const& sql)
  62. {
  63. dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
  64. auto client_connection = ConnectionFromClient::client_connection_for(client_id());
  65. if (!client_connection) {
  66. warnln("Cannot notify client of database disconnection. Client disconnected");
  67. return -1;
  68. }
  69. if (!m_accept_statements) {
  70. client_connection->async_execution_error(-1, (int)SQL::SQLErrorCode::DatabaseUnavailable, m_database_name);
  71. return -1;
  72. }
  73. auto statement = SQLStatement::construct(*this, sql);
  74. return statement->statement_id();
  75. }
  76. }