DatabaseConnection.cpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <AK/LexicalPath.h>
  7. #include <SQLServer/ConnectionFromClient.h>
  8. #include <SQLServer/DatabaseConnection.h>
  9. #include <SQLServer/SQLStatement.h>
  10. namespace SQLServer {
  11. static HashMap<int, NonnullRefPtr<DatabaseConnection>> s_connections;
  12. RefPtr<DatabaseConnection> DatabaseConnection::connection_for(int connection_id)
  13. {
  14. if (s_connections.contains(connection_id))
  15. return *s_connections.get(connection_id).value();
  16. dbgln_if(SQLSERVER_DEBUG, "Invalid connection_id {}", connection_id);
  17. return nullptr;
  18. }
  19. static int s_next_connection_id = 0;
  20. DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int client_id)
  21. : Object()
  22. , m_database_name(move(database_name))
  23. , m_connection_id(s_next_connection_id++)
  24. , m_client_id(client_id)
  25. {
  26. if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) {
  27. auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
  28. client_connection->async_connection_error(m_connection_id, (int)SQL::SQLErrorCode::InvalidDatabaseName, m_database_name);
  29. return;
  30. }
  31. dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name);
  32. s_connections.set(m_connection_id, *this);
  33. deferred_invoke([this]() {
  34. m_database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", m_database_name));
  35. auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
  36. if (auto maybe_error = m_database->open(); maybe_error.is_error()) {
  37. client_connection->async_connection_error(m_connection_id, to_underlying(maybe_error.error().error()), maybe_error.error().error_string());
  38. return;
  39. }
  40. m_accept_statements = true;
  41. if (client_connection)
  42. client_connection->async_connected(m_connection_id, m_database_name);
  43. else
  44. warnln("Cannot notify client of database connection. Client disconnected");
  45. });
  46. }
  47. void DatabaseConnection::disconnect()
  48. {
  49. dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::disconnect(connection_id {}, database '{}'", connection_id(), m_database_name);
  50. m_accept_statements = false;
  51. deferred_invoke([this]() {
  52. m_database = nullptr;
  53. s_connections.remove(m_connection_id);
  54. auto client_connection = ConnectionFromClient::client_connection_for(client_id());
  55. if (client_connection)
  56. client_connection->async_disconnected(m_connection_id);
  57. else
  58. warnln("Cannot notify client of database disconnection. Client disconnected");
  59. });
  60. }
  61. SQL::ResultOr<int> DatabaseConnection::prepare_statement(StringView sql)
  62. {
  63. dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
  64. if (!m_accept_statements)
  65. return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable };
  66. auto client_connection = ConnectionFromClient::client_connection_for(client_id());
  67. if (!client_connection) {
  68. warnln("Cannot notify client of database disconnection. Client disconnected");
  69. return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv };
  70. }
  71. auto statement = TRY(SQLStatement::create(*this, sql));
  72. return statement->statement_id();
  73. }
  74. }