Pārlūkot izejas kodu

SQLServer: Parse SQL a single time to actually "prepare" the statement

One of the benefits of prepared statements is that the SQL string is
parsed just once and re-used. This updates SQLStatement to do just that
and store the parsed result.
Timothy Flynn 2 gadi atpakaļ
vecāks
revīzija
b13527b8b2

+ 11 - 5
Userland/Services/SQLServer/ConnectionFromClient.cpp

@@ -54,15 +54,21 @@ void ConnectionFromClient::disconnect(int connection_id)
 Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_statement(int connection_id, DeprecatedString const& sql)
 Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_statement(int connection_id, DeprecatedString const& sql)
 {
 {
     dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement(connection_id: {}, sql: '{}')", connection_id, sql);
     dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement(connection_id: {}, sql: '{}')", connection_id, sql);
+
     auto database_connection = DatabaseConnection::connection_for(connection_id);
     auto database_connection = DatabaseConnection::connection_for(connection_id);
-    if (database_connection) {
-        auto statement_id = database_connection->prepare_statement(sql);
-        dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement -> statement_id = {}", statement_id);
-        return { statement_id };
-    } else {
+    if (!database_connection) {
         dbgln("Database connection has disappeared");
         dbgln("Database connection has disappeared");
         return { -1 };
         return { -1 };
     }
     }
+
+    auto result = database_connection->prepare_statement(sql);
+    if (result.is_error()) {
+        dbgln_if(SQLSERVER_DEBUG, "Could not parse SQL statement: {}", result.error().error_string());
+        return { -1 };
+    }
+
+    dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement -> statement_id = {}", result.value());
+    return { result.value() };
 }
 }
 
 
 void ConnectionFromClient::execute_statement(int statement_id)
 void ConnectionFromClient::execute_statement(int statement_id)

+ 8 - 7
Userland/Services/SQLServer/DatabaseConnection.cpp

@@ -67,19 +67,20 @@ void DatabaseConnection::disconnect()
     });
     });
 }
 }
 
 
-int DatabaseConnection::prepare_statement(DeprecatedString const& sql)
+SQL::ResultOr<int> DatabaseConnection::prepare_statement(StringView sql)
 {
 {
     dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
     dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
+
+    if (!m_accept_statements)
+        return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable };
+
     auto client_connection = ConnectionFromClient::client_connection_for(client_id());
     auto client_connection = ConnectionFromClient::client_connection_for(client_id());
     if (!client_connection) {
     if (!client_connection) {
         warnln("Cannot notify client of database disconnection. Client disconnected");
         warnln("Cannot notify client of database disconnection. Client disconnected");
-        return -1;
+        return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv };
     }
     }
-    if (!m_accept_statements) {
-        client_connection->async_execution_error(-1, (int)SQL::SQLErrorCode::DatabaseUnavailable, m_database_name);
-        return -1;
-    }
-    auto statement = SQLStatement::construct(*this, sql);
+
+    auto statement = TRY(SQLStatement::create(*this, sql));
     return statement->statement_id();
     return statement->statement_id();
 }
 }
 
 

+ 2 - 1
Userland/Services/SQLServer/DatabaseConnection.h

@@ -8,6 +8,7 @@
 
 
 #include <LibCore/Object.h>
 #include <LibCore/Object.h>
 #include <LibSQL/Database.h>
 #include <LibSQL/Database.h>
+#include <LibSQL/Result.h>
 #include <SQLServer/Forward.h>
 #include <SQLServer/Forward.h>
 
 
 namespace SQLServer {
 namespace SQLServer {
@@ -23,7 +24,7 @@ public:
     int client_id() const { return m_client_id; }
     int client_id() const { return m_client_id; }
     RefPtr<SQL::Database> database() { return m_database; }
     RefPtr<SQL::Database> database() { return m_database; }
     void disconnect();
     void disconnect();
-    int prepare_statement(DeprecatedString const& sql);
+    SQL::ResultOr<int> prepare_statement(StringView sql);
 
 
 private:
 private:
     DatabaseConnection(DeprecatedString database_name, int client_id);
     DatabaseConnection(DeprecatedString database_name, int client_id);

+ 14 - 20
Userland/Services/SQLServer/SQLStatement.cpp

@@ -24,12 +24,23 @@ RefPtr<SQLStatement> SQLStatement::statement_for(int statement_id)
 
 
 static int s_next_statement_id = 0;
 static int s_next_statement_id = 0;
 
 
-SQLStatement::SQLStatement(DatabaseConnection& connection, DeprecatedString sql)
+SQL::ResultOr<NonnullRefPtr<SQLStatement>> SQLStatement::create(DatabaseConnection& connection, StringView sql)
+{
+    auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql));
+    auto statement = parser.next_statement();
+
+    if (parser.has_errors())
+        return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() };
+
+    return TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQLStatement(connection, move(statement))));
+}
+
+SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr<SQL::AST::Statement> statement)
     : Core::Object(&connection)
     : Core::Object(&connection)
     , m_statement_id(s_next_statement_id++)
     , m_statement_id(s_next_statement_id++)
-    , m_sql(move(sql))
+    , m_statement(move(statement))
 {
 {
-    dbgln_if(SQLSERVER_DEBUG, "SQLStatement({}, {})", connection.connection_id(), sql);
+    dbgln_if(SQLSERVER_DEBUG, "SQLStatement({})", connection.connection_id());
     s_statements.set(m_statement_id, *this);
     s_statements.set(m_statement_id, *this);
 }
 }
 
 
@@ -47,7 +58,6 @@ void SQLStatement::report_error(SQL::Result result)
     else
     else
         warnln("Cannot return execution error. Client disconnected");
         warnln("Cannot return execution error. Client disconnected");
 
 
-    m_statement = nullptr;
     m_result = {};
     m_result = {};
 }
 }
 
 
@@ -61,12 +71,6 @@ void SQLStatement::execute()
     }
     }
 
 
     deferred_invoke([this] {
     deferred_invoke([this] {
-        auto parse_result = parse();
-        if (parse_result.is_error()) {
-            report_error(parse_result.release_error());
-            return;
-        }
-
         VERIFY(!connection()->database().is_null());
         VERIFY(!connection()->database().is_null());
 
 
         auto execution_result = m_statement->execute(connection()->database().release_nonnull());
         auto execution_result = m_statement->execute(connection()->database().release_nonnull());
@@ -93,16 +97,6 @@ void SQLStatement::execute()
     });
     });
 }
 }
 
 
-SQL::ResultOr<void> SQLStatement::parse()
-{
-    auto parser = SQL::AST::Parser(SQL::AST::Lexer(m_sql));
-    m_statement = parser.next_statement();
-
-    if (parser.has_errors())
-        return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() };
-    return {};
-}
-
 bool SQLStatement::should_send_result_rows() const
 bool SQLStatement::should_send_result_rows() const
 {
 {
     VERIFY(m_result.has_value());
     VERIFY(m_result.has_value());

+ 5 - 6
Userland/Services/SQLServer/SQLStatement.h

@@ -18,28 +18,27 @@
 namespace SQLServer {
 namespace SQLServer {
 
 
 class SQLStatement final : public Core::Object {
 class SQLStatement final : public Core::Object {
-    C_OBJECT(SQLStatement)
+    C_OBJECT_ABSTRACT(SQLStatement)
 
 
 public:
 public:
+    static SQL::ResultOr<NonnullRefPtr<SQLStatement>> create(DatabaseConnection&, StringView sql);
     ~SQLStatement() override = default;
     ~SQLStatement() override = default;
 
 
     static RefPtr<SQLStatement> statement_for(int statement_id);
     static RefPtr<SQLStatement> statement_for(int statement_id);
     int statement_id() const { return m_statement_id; }
     int statement_id() const { return m_statement_id; }
-    DeprecatedString const& sql() const { return m_sql; }
     DatabaseConnection* connection() { return dynamic_cast<DatabaseConnection*>(parent()); }
     DatabaseConnection* connection() { return dynamic_cast<DatabaseConnection*>(parent()); }
     void execute();
     void execute();
 
 
 private:
 private:
-    SQLStatement(DatabaseConnection&, DeprecatedString sql);
-    SQL::ResultOr<void> parse();
+    SQLStatement(DatabaseConnection&, NonnullRefPtr<SQL::AST::Statement> statement);
+
     bool should_send_result_rows() const;
     bool should_send_result_rows() const;
     void next();
     void next();
     void report_error(SQL::Result);
     void report_error(SQL::Result);
 
 
     int m_statement_id;
     int m_statement_id;
-    DeprecatedString m_sql;
     size_t m_index { 0 };
     size_t m_index { 0 };
-    RefPtr<SQL::AST::Statement> m_statement { nullptr };
+    NonnullRefPtr<SQL::AST::Statement> m_statement;
     Optional<SQL::ResultSet> m_result {};
     Optional<SQL::ResultSet> m_result {};
 };
 };