Преглед изворни кода

LibSQL+SQLServer+SQLStudio+sql: Allocate per-statement-execution IDs

In order to execute a prepared statement multiple times, and track each
execution's results, clients will need to be provided an execution ID.
This will create a monotonically increasing ID each time a prepared
statement is executed for this purpose.
Timothy Flynn пре 2 година
родитељ
комит
aec75d749a

+ 3 - 3
Userland/DevTools/SQLStudio/MainWidget.cpp

@@ -214,13 +214,13 @@ MainWidget::MainWidget()
     m_statusbar->segment(2).set_fixed_width(font().width("Ln 0000, Col 000"sv) + font().max_glyph_width());
 
     m_sql_client = SQL::SQLClient::try_create().release_value_but_fixme_should_propagate_errors();
-    m_sql_client->on_execution_success = [this](auto, auto, auto, auto, auto) {
+    m_sql_client->on_execution_success = [this](auto, auto, auto, auto, auto, auto) {
         read_next_sql_statement_of_editor();
     };
-    m_sql_client->on_next_result = [this](auto, auto const& row) {
+    m_sql_client->on_next_result = [this](auto, auto, auto const& row) {
         m_results.append(row);
     };
-    m_sql_client->on_results_exhausted = [this](auto, auto) {
+    m_sql_client->on_results_exhausted = [this](auto, auto, auto) {
         if (m_results.size() == 0)
             return;
         if (m_results[0].size() == 0)

+ 8 - 8
Userland/Libraries/LibSQL/SQLClient.cpp

@@ -29,26 +29,26 @@ void SQLClient::connection_error(u64 connection_id, SQLErrorCode const& code, De
         warnln("Connection error for connection_id {}: {} ({})", connection_id, message, to_underlying(code));
 }
 
-void SQLClient::execution_error(u64 statement_id, SQLErrorCode const& code, DeprecatedString const& message)
+void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message)
 {
     if (on_execution_error)
-        on_execution_error(statement_id, code, message);
+        on_execution_error(statement_id, execution_id, code, message);
     else
         warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code));
 }
 
-void SQLClient::execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted)
+void SQLClient::execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted)
 {
     if (on_execution_success)
-        on_execution_success(statement_id, has_results, created, updated, deleted);
+        on_execution_success(statement_id, execution_id, has_results, created, updated, deleted);
     else
         outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted);
 }
 
-void SQLClient::next_result(u64 statement_id, Vector<DeprecatedString> const& row)
+void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const& row)
 {
     if (on_next_result) {
-        on_next_result(statement_id, row);
+        on_next_result(statement_id, execution_id, row);
         return;
     }
     bool first = true;
@@ -61,10 +61,10 @@ void SQLClient::next_result(u64 statement_id, Vector<DeprecatedString> const& ro
     outln();
 }
 
-void SQLClient::results_exhausted(u64 statement_id, size_t total_rows)
+void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows)
 {
     if (on_results_exhausted)
-        on_results_exhausted(statement_id, total_rows);
+        on_results_exhausted(statement_id, execution_id, total_rows);
     else
         outln("{} total row(s)", total_rows);
 }

+ 8 - 8
Userland/Libraries/LibSQL/SQLClient.h

@@ -23,10 +23,10 @@ class SQLClient
     Function<void(u64, DeprecatedString const&)> on_connected;
     Function<void(u64)> on_disconnected;
     Function<void(u64, SQLErrorCode, DeprecatedString const&)> on_connection_error;
-    Function<void(u64, SQLErrorCode, DeprecatedString const&)> on_execution_error;
-    Function<void(u64, bool, size_t, size_t, size_t)> on_execution_success;
-    Function<void(u64, Vector<DeprecatedString> const&)> on_next_result;
-    Function<void(u64, size_t)> on_results_exhausted;
+    Function<void(u64, u64, SQLErrorCode, DeprecatedString const&)> on_execution_error;
+    Function<void(u64, u64, bool, size_t, size_t, size_t)> on_execution_success;
+    Function<void(u64, u64, Vector<DeprecatedString> const&)> on_next_result;
+    Function<void(u64, u64, size_t)> on_results_exhausted;
 
 private:
     SQLClient(NonnullOwnPtr<Core::Stream::LocalSocket> socket)
@@ -36,10 +36,10 @@ private:
 
     virtual void connected(u64 connection_id, DeprecatedString const& connected_to_database) override;
     virtual void connection_error(u64 connection_id, SQLErrorCode const& code, DeprecatedString const& message) override;
-    virtual void execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted) override;
-    virtual void next_result(u64 statement_id, Vector<DeprecatedString> const&) override;
-    virtual void results_exhausted(u64 statement_id, size_t total_rows) override;
-    virtual void execution_error(u64 statement_id, SQLErrorCode const& code, DeprecatedString const& message) override;
+    virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override;
+    virtual void next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const&) override;
+    virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override;
+    virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override;
     virtual void disconnected(u64 connection_id) override;
 };
 

+ 7 - 5
Userland/Services/SQLServer/ConnectionFromClient.cpp

@@ -71,17 +71,19 @@ Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_stat
     return { result.value() };
 }
 
-void ConnectionFromClient::execute_statement(u64 statement_id, Vector<SQL::Value> const& placeholder_values)
+Messages::SQLServer::ExecuteStatementResponse ConnectionFromClient::execute_statement(u64 statement_id, Vector<SQL::Value> const& placeholder_values)
 {
     dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::execute_query_statement(statement_id: {})", statement_id);
+
     auto statement = SQLStatement::statement_for(statement_id);
     if (statement && statement->connection()->client_id() == client_id()) {
         // FIXME: Support taking parameters from IPC requests.
-        statement->execute(move(const_cast<Vector<SQL::Value>&>(placeholder_values)));
-    } else {
-        dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared");
-        async_execution_error(statement_id, SQL::SQLErrorCode::StatementUnavailable, DeprecatedString::formatted("{}", statement_id));
+        return statement->execute(move(const_cast<Vector<SQL::Value>&>(placeholder_values)));
     }
+
+    dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared");
+    async_execution_error(statement_id, -1, SQL::SQLErrorCode::StatementUnavailable, DeprecatedString::formatted("{}", statement_id));
+    return { {} };
 }
 
 }

+ 1 - 1
Userland/Services/SQLServer/ConnectionFromClient.h

@@ -30,7 +30,7 @@ private:
 
     virtual Messages::SQLServer::ConnectResponse connect(DeprecatedString const&) override;
     virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(u64, DeprecatedString const&) override;
-    virtual void execute_statement(u64, Vector<SQL::Value> const& placeholder_values) override;
+    virtual Messages::SQLServer::ExecuteStatementResponse execute_statement(u64, Vector<SQL::Value> const& placeholder_values) override;
     virtual void disconnect(u64) override;
 };
 

+ 4 - 4
Userland/Services/SQLServer/SQLClient.ipc

@@ -4,9 +4,9 @@ endpoint SQLClient
 {
     connected(u64 connection_id, DeprecatedString connected_to_database) =|
     connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =|
-    execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted) =|
-    next_result(u64 statement_id, Vector<DeprecatedString> row) =|
-    results_exhausted(u64 statement_id, size_t total_rows) =|
-    execution_error(u64 statement_id, SQL::SQLErrorCode code, DeprecatedString message) =|
+    execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) =|
+    next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> row) =|
+    results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) =|
+    execution_error(u64 statement_id, u64 execution_id, SQL::SQLErrorCode code, DeprecatedString message) =|
     disconnected(u64 connection_id) =|
 }

+ 1 - 1
Userland/Services/SQLServer/SQLServer.ipc

@@ -4,6 +4,6 @@ endpoint SQLServer
 {
     connect(DeprecatedString name) => (u64 connection_id)
     prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional<u64> statement_id)
-    execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) =|
+    execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) => (Optional<u64> execution_id)
     disconnect(u64 connection_id) =|
 }

+ 23 - 14
Userland/Services/SQLServer/SQLStatement.cpp

@@ -43,7 +43,7 @@ SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr<SQL::AS
     s_statements.set(m_statement_id, *this);
 }
 
-void SQLStatement::report_error(SQL::Result result)
+void SQLStatement::report_error(SQL::Result result, u64 execution_id)
 {
     dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string());
 
@@ -53,29 +53,34 @@ void SQLStatement::report_error(SQL::Result result)
     remove_from_parent();
 
     if (client_connection)
-        client_connection->async_execution_error(statement_id(), result.error(), result.error_string());
+        client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string());
     else
         warnln("Cannot return execution error. Client disconnected");
 
     m_result = {};
 }
 
-void SQLStatement::execute(Vector<SQL::Value> placeholder_values)
+Optional<u64> SQLStatement::execute(Vector<SQL::Value> placeholder_values)
 {
     dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id());
 
     auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
     if (!client_connection) {
         warnln("Cannot yield next result. Client disconnected");
-        return;
+        return {};
     }
 
-    deferred_invoke([this, placeholder_values = move(placeholder_values)] {
+    auto execution_id = m_next_execution_id++;
+    m_ongoing_executions.set(execution_id);
+
+    deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] {
         VERIFY(!connection()->database().is_null());
 
         auto execution_result = m_statement->execute(connection()->database().release_nonnull(), placeholder_values);
+        m_ongoing_executions.remove(execution_id);
+
         if (execution_result.is_error()) {
-            report_error(execution_result.release_error());
+            report_error(execution_result.release_error(), execution_id);
             return;
         }
 
@@ -88,13 +93,15 @@ void SQLStatement::execute(Vector<SQL::Value> placeholder_values)
         m_result = execution_result.release_value();
 
         if (should_send_result_rows()) {
-            client_connection->async_execution_success(statement_id(), true, 0, 0, 0);
+            client_connection->async_execution_success(statement_id(), execution_id, true, 0, 0, 0);
             m_index = 0;
-            next();
+            next(execution_id);
         } else {
-            client_connection->async_execution_success(statement_id(), false, 0, m_result->size(), 0);
+            client_connection->async_execution_success(statement_id(), execution_id, false, 0, m_result->size(), 0);
         }
     });
+
+    return execution_id;
 }
 
 bool SQLStatement::should_send_result_rows() const
@@ -113,22 +120,24 @@ bool SQLStatement::should_send_result_rows() const
     }
 }
 
-void SQLStatement::next()
+void SQLStatement::next(u64 execution_id)
 {
     VERIFY(!m_result->is_empty());
+
     auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
     if (!client_connection) {
         warnln("Cannot yield next result. Client disconnected");
         return;
     }
+
     if (m_index < m_result->size()) {
         auto& tuple = m_result->at(m_index++).row;
-        client_connection->async_next_result(statement_id(), tuple.to_deprecated_string_vector());
-        deferred_invoke([this]() {
-            next();
+        client_connection->async_next_result(statement_id(), execution_id, tuple.to_deprecated_string_vector());
+        deferred_invoke([this, execution_id]() {
+            next(execution_id);
         });
     } else {
-        client_connection->async_results_exhausted(statement_id(), m_index);
+        client_connection->async_results_exhausted(statement_id(), execution_id, m_index);
     }
 }
 

+ 7 - 3
Userland/Services/SQLServer/SQLStatement.h

@@ -28,17 +28,21 @@ public:
     static RefPtr<SQLStatement> statement_for(u64 statement_id);
     u64 statement_id() const { return m_statement_id; }
     DatabaseConnection* connection() { return dynamic_cast<DatabaseConnection*>(parent()); }
-    void execute(Vector<SQL::Value> placeholder_values);
+    Optional<u64> execute(Vector<SQL::Value> placeholder_values);
 
 private:
     SQLStatement(DatabaseConnection&, NonnullRefPtr<SQL::AST::Statement> statement);
 
     bool should_send_result_rows() const;
-    void next();
-    void report_error(SQL::Result);
+    void next(u64 execution_id);
+    void report_error(SQL::Result, u64 execution_id);
 
     u64 m_statement_id { 0 };
     size_t m_index { 0 };
+
+    HashTable<u64> m_ongoing_executions;
+    u64 m_next_execution_id { 0 };
+
     NonnullRefPtr<SQL::AST::Statement> m_statement;
     Optional<SQL::ResultSet> m_result {};
 };

+ 4 - 4
Userland/Utilities/sql.cpp

@@ -84,7 +84,7 @@ public:
             read_sql();
         };
 
-        m_sql_client->on_execution_success = [this](auto, auto has_results, auto updated, auto created, auto deleted) {
+        m_sql_client->on_execution_success = [this](auto, auto, auto has_results, auto updated, auto created, auto deleted) {
             if (updated != 0 || created != 0 || deleted != 0) {
                 outln("{} row(s) updated, {} created, {} deleted", updated, created, deleted);
             }
@@ -93,13 +93,13 @@ public:
             }
         };
 
-        m_sql_client->on_next_result = [](auto, auto const& row) {
+        m_sql_client->on_next_result = [](auto, auto, auto const& row) {
             StringBuilder builder;
             builder.join(", "sv, row);
             outln("{}", builder.build());
         };
 
-        m_sql_client->on_results_exhausted = [this](auto, auto total_rows) {
+        m_sql_client->on_results_exhausted = [this](auto, auto, auto total_rows) {
             outln("{} row(s)", total_rows);
             read_sql();
         };
@@ -109,7 +109,7 @@ public:
             m_loop.quit(to_underlying(code));
         };
 
-        m_sql_client->on_execution_error = [this](auto, auto, auto const& message) {
+        m_sql_client->on_execution_error = [this](auto, auto, auto, auto const& message) {
             outln("\033[33;1mExecution error:\033[0m {}", message);
             read_sql();
         };