SQLStatement.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <LibCore/Object.h>
  7. #include <LibSQL/AST/Parser.h>
  8. #include <SQLServer/ConnectionFromClient.h>
  9. #include <SQLServer/DatabaseConnection.h>
  10. #include <SQLServer/SQLStatement.h>
  11. namespace SQLServer {
  12. static HashMap<int, NonnullRefPtr<SQLStatement>> s_statements;
  13. RefPtr<SQLStatement> SQLStatement::statement_for(int statement_id)
  14. {
  15. if (s_statements.contains(statement_id))
  16. return *s_statements.get(statement_id).value();
  17. dbgln_if(SQLSERVER_DEBUG, "Invalid statement_id {}", statement_id);
  18. return nullptr;
  19. }
  20. static int s_next_statement_id = 0;
  21. SQLStatement::SQLStatement(DatabaseConnection& connection, String sql)
  22. : Core::Object(&connection)
  23. , m_statement_id(s_next_statement_id++)
  24. , m_sql(move(sql))
  25. {
  26. dbgln_if(SQLSERVER_DEBUG, "SQLStatement({}, {})", connection.connection_id(), sql);
  27. s_statements.set(m_statement_id, *this);
  28. }
  29. void SQLStatement::report_error(SQL::Result result)
  30. {
  31. dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string());
  32. auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
  33. s_statements.remove(statement_id());
  34. remove_from_parent();
  35. if (client_connection)
  36. client_connection->async_execution_error(statement_id(), (int)result.error(), result.error_string());
  37. else
  38. warnln("Cannot return execution error. Client disconnected");
  39. m_statement = nullptr;
  40. m_result = {};
  41. }
  42. void SQLStatement::execute()
  43. {
  44. dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id());
  45. auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
  46. if (!client_connection) {
  47. warnln("Cannot yield next result. Client disconnected");
  48. return;
  49. }
  50. deferred_invoke([this]() mutable {
  51. auto parse_result = parse();
  52. if (parse_result.is_error()) {
  53. report_error(parse_result.release_error());
  54. return;
  55. }
  56. VERIFY(!connection()->database().is_null());
  57. auto execution_result = m_statement->execute(connection()->database().release_nonnull());
  58. if (execution_result.is_error()) {
  59. report_error(execution_result.release_error());
  60. return;
  61. }
  62. auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
  63. if (!client_connection) {
  64. warnln("Cannot return statement execution results. Client disconnected");
  65. return;
  66. }
  67. m_result = execution_result.release_value();
  68. if (should_send_result_rows()) {
  69. client_connection->async_execution_success(statement_id(), true, 0, 0, 0);
  70. m_index = 0;
  71. next();
  72. } else {
  73. client_connection->async_execution_success(statement_id(), false, 0, m_result->size(), 0);
  74. }
  75. });
  76. }
  77. SQL::ResultOr<void> SQLStatement::parse()
  78. {
  79. auto parser = SQL::AST::Parser(SQL::AST::Lexer(m_sql));
  80. m_statement = parser.next_statement();
  81. if (parser.has_errors())
  82. return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_string() };
  83. return {};
  84. }
  85. bool SQLStatement::should_send_result_rows() const
  86. {
  87. VERIFY(m_result.has_value());
  88. if (m_result->is_empty())
  89. return false;
  90. switch (m_result->command()) {
  91. case SQL::SQLCommand::Describe:
  92. case SQL::SQLCommand::Select:
  93. return true;
  94. default:
  95. return false;
  96. }
  97. }
  98. void SQLStatement::next()
  99. {
  100. VERIFY(!m_result->is_empty());
  101. auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
  102. if (!client_connection) {
  103. warnln("Cannot yield next result. Client disconnected");
  104. return;
  105. }
  106. if (m_index < m_result->size()) {
  107. auto& tuple = m_result->at(m_index++).row;
  108. client_connection->async_next_result(statement_id(), tuple.to_string_vector());
  109. deferred_invoke([this]() {
  110. next();
  111. });
  112. } else {
  113. client_connection->async_results_exhausted(statement_id(), (int)m_index);
  114. }
  115. }
  116. }