SQLStatement.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <LibCore/EventReceiver.h>
  7. #include <LibSQL/AST/Parser.h>
  8. #include <SQLServer/ConnectionFromClient.h>
  9. #include <SQLServer/DatabaseConnection.h>
  10. #include <SQLServer/SQLStatement.h>
  11. namespace SQLServer {
  12. static HashMap<SQL::StatementID, NonnullRefPtr<SQLStatement>> s_statements;
  13. static SQL::StatementID s_next_statement_id = 0;
  14. RefPtr<SQLStatement> SQLStatement::statement_for(SQL::StatementID statement_id)
  15. {
  16. if (s_statements.contains(statement_id))
  17. return *s_statements.get(statement_id).value();
  18. dbgln_if(SQLSERVER_DEBUG, "Invalid statement_id {}", statement_id);
  19. return nullptr;
  20. }
  21. SQL::ResultOr<NonnullRefPtr<SQLStatement>> SQLStatement::create(DatabaseConnection& connection, StringView sql)
  22. {
  23. auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql));
  24. auto statement = parser.next_statement();
  25. if (parser.has_errors())
  26. return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() };
  27. return TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQLStatement(connection, move(statement))));
  28. }
  29. SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr<SQL::AST::Statement> statement)
  30. : m_connection(connection)
  31. , m_statement_id(s_next_statement_id++)
  32. , m_statement(move(statement))
  33. {
  34. dbgln_if(SQLSERVER_DEBUG, "SQLStatement({})", connection.connection_id());
  35. s_statements.set(m_statement_id, *this);
  36. }
  37. void SQLStatement::report_error(SQL::Result result, SQL::ExecutionID execution_id)
  38. {
  39. dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string());
  40. auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
  41. s_statements.remove(statement_id());
  42. if (client_connection)
  43. client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string());
  44. else
  45. warnln("Cannot return execution error. Client disconnected");
  46. }
  47. Optional<SQL::ExecutionID> SQLStatement::execute(Vector<SQL::Value> placeholder_values)
  48. {
  49. dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id());
  50. auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
  51. if (!client_connection) {
  52. warnln("Cannot yield next result. Client disconnected");
  53. return {};
  54. }
  55. auto execution_id = m_next_execution_id++;
  56. m_ongoing_executions.set(execution_id);
  57. Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), placeholder_values = move(placeholder_values), execution_id] {
  58. auto execution_result = m_statement->execute(connection().database(), placeholder_values);
  59. m_ongoing_executions.remove(execution_id);
  60. if (execution_result.is_error()) {
  61. report_error(execution_result.release_error(), execution_id);
  62. return;
  63. }
  64. auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
  65. if (!client_connection) {
  66. warnln("Cannot return statement execution results. Client disconnected");
  67. return;
  68. }
  69. auto result = execution_result.release_value();
  70. if (should_send_result_rows(result)) {
  71. client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), true, 0, 0, 0);
  72. auto result_size = result.size();
  73. next(execution_id, move(result), result_size);
  74. } else {
  75. if (result.command() == SQL::SQLCommand::Insert)
  76. client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result.size(), 0, 0);
  77. else if (result.command() == SQL::SQLCommand::Update)
  78. client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result.size(), 0);
  79. else if (result.command() == SQL::SQLCommand::Delete)
  80. client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result.size());
  81. else
  82. client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, 0);
  83. }
  84. });
  85. return execution_id;
  86. }
  87. bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const
  88. {
  89. if (result.is_empty())
  90. return false;
  91. switch (result.command()) {
  92. case SQL::SQLCommand::Describe:
  93. case SQL::SQLCommand::Select:
  94. return true;
  95. default:
  96. return false;
  97. }
  98. }
  99. void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size)
  100. {
  101. auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
  102. if (!client_connection) {
  103. warnln("Cannot yield next result. Client disconnected");
  104. return;
  105. }
  106. if (!result.is_empty()) {
  107. auto result_row = result.take_first();
  108. client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data());
  109. Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), execution_id, result = move(result), result_size]() mutable {
  110. next(execution_id, move(result), result_size);
  111. });
  112. } else {
  113. client_connection->async_results_exhausted(statement_id(), execution_id, result_size);
  114. }
  115. }
  116. }