SQLStatement.cpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <LibCore/Object.h>
  7. #include <LibSQL/AST/Parser.h>
  8. #include <SQLServer/ClientConnection.h>
  9. #include <SQLServer/DatabaseConnection.h>
  10. #include <SQLServer/SQLStatement.h>
  11. namespace SQLServer {
  12. static HashMap<int, NonnullRefPtr<SQLStatement>> s_statements;
  13. RefPtr<SQLStatement> SQLStatement::statement_for(int statement_id)
  14. {
  15. if (s_statements.contains(statement_id))
  16. return *s_statements.get(statement_id).value();
  17. dbgln_if(SQLSERVER_DEBUG, "Invalid statement_id {}", statement_id);
  18. return nullptr;
  19. }
  20. static int s_next_statement_id = 0;
  21. SQLStatement::SQLStatement(DatabaseConnection& connection, String sql)
  22. : Core::Object(&connection)
  23. , m_statement_id(s_next_statement_id++)
  24. , m_sql(move(sql))
  25. {
  26. dbgln_if(SQLSERVER_DEBUG, "SQLStatement({}, {})", connection.connection_id(), sql);
  27. s_statements.set(m_statement_id, *this);
  28. }
  29. void SQLStatement::report_error(SQL::SQLError error)
  30. {
  31. dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), error.to_string());
  32. auto client_connection = ClientConnection::client_connection_for(connection()->client_id());
  33. m_statement = nullptr;
  34. m_result = nullptr;
  35. remove_from_parent();
  36. s_statements.remove(statement_id());
  37. if (!client_connection) {
  38. warnln("Cannot return execution error. Client disconnected");
  39. warnln("SQLStatement::report_error(statement_id {}, error {}", statement_id(), error.to_string());
  40. m_result = nullptr;
  41. return;
  42. }
  43. client_connection->async_execution_error(statement_id(), (int)error.code, error.to_string());
  44. m_result = nullptr;
  45. }
  46. void SQLStatement::execute()
  47. {
  48. dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id());
  49. auto client_connection = ClientConnection::client_connection_for(connection()->client_id());
  50. if (!client_connection) {
  51. warnln("Cannot yield next result. Client disconnected");
  52. return;
  53. }
  54. deferred_invoke([&](Object&) {
  55. auto maybe_error = parse();
  56. if (maybe_error.has_value()) {
  57. report_error(maybe_error.value());
  58. return;
  59. }
  60. VERIFY(!connection()->database().is_null());
  61. SQL::AST::ExecutionContext context { connection()->database().release_nonnull() };
  62. m_result = m_statement->execute(context);
  63. if (m_result->error().code != SQL::SQLErrorCode::NoError) {
  64. report_error(m_result->error());
  65. return;
  66. }
  67. auto client_connection = ClientConnection::client_connection_for(connection()->client_id());
  68. if (!client_connection) {
  69. warnln("Cannot return statement execution results. Client disconnected");
  70. return;
  71. }
  72. client_connection->async_execution_success(statement_id(), m_result->has_results(), m_result->updated(), m_result->inserted(), m_result->deleted());
  73. if (m_result->has_results()) {
  74. m_index = 0;
  75. next();
  76. }
  77. });
  78. }
  79. Optional<SQL::SQLError> SQLStatement::parse()
  80. {
  81. auto parser = SQL::AST::Parser(SQL::AST::Lexer(m_sql));
  82. m_statement = parser.next_statement();
  83. if (parser.has_errors()) {
  84. return SQL::SQLError { SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_string() };
  85. }
  86. return {};
  87. }
  88. void SQLStatement::next()
  89. {
  90. VERIFY(m_result->has_results());
  91. auto client_connection = ClientConnection::client_connection_for(connection()->client_id());
  92. if (!client_connection) {
  93. warnln("Cannot yield next result. Client disconnected");
  94. return;
  95. }
  96. if (m_index < m_result->results().size()) {
  97. auto& tuple = m_result->results()[m_index++];
  98. client_connection->async_next_result(statement_id(), tuple.to_string_vector());
  99. deferred_invoke([&](Object&) {
  100. next();
  101. });
  102. } else {
  103. client_connection->async_results_exhausted(statement_id(), (int)m_index);
  104. }
  105. }
  106. }