Kaynağa Gözat

SQLServer: Remove Core::EventReceiver parent from SQLStatement

This relationship was only used to provide factory methods and a parent-
child relationship between SQLStatement and DatabaseConnection.
Timothy Flynn 2 yıl önce
ebeveyn
işleme
08d77ca6b1

+ 1 - 1
Userland/Services/SQLServer/ConnectionFromClient.cpp

@@ -87,7 +87,7 @@ Messages::SQLServer::ExecuteStatementResponse ConnectionFromClient::execute_stat
     dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::execute_query_statement(statement_id: {})", statement_id);
 
     auto statement = SQLStatement::statement_for(statement_id);
-    if (statement && statement->connection()->client_id() == client_id()) {
+    if (statement && statement->connection().client_id() == client_id()) {
         // FIXME: Support taking parameters from IPC requests.
         return statement->execute(move(const_cast<Vector<SQL::Value>&>(placeholder_values)));
     }

+ 8 - 9
Userland/Services/SQLServer/SQLStatement.cpp

@@ -35,7 +35,7 @@ SQL::ResultOr<NonnullRefPtr<SQLStatement>> SQLStatement::create(DatabaseConnecti
 }
 
 SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr<SQL::AST::Statement> statement)
-    : Core::EventReceiver(&connection)
+    : m_connection(connection)
     , m_statement_id(s_next_statement_id++)
     , m_statement(move(statement))
 {
@@ -47,10 +47,9 @@ void SQLStatement::report_error(SQL::Result result, SQL::ExecutionID execution_i
 {
     dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string());
 
-    auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
+    auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
 
     s_statements.remove(statement_id());
-    remove_from_parent();
 
     if (client_connection)
         client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string());
@@ -62,7 +61,7 @@ Optional<SQL::ExecutionID> SQLStatement::execute(Vector<SQL::Value> placeholder_
 {
     dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id());
 
-    auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
+    auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
     if (!client_connection) {
         warnln("Cannot yield next result. Client disconnected");
         return {};
@@ -71,8 +70,8 @@ Optional<SQL::ExecutionID> SQLStatement::execute(Vector<SQL::Value> placeholder_
     auto execution_id = m_next_execution_id++;
     m_ongoing_executions.set(execution_id);
 
-    deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] {
-        auto execution_result = m_statement->execute(connection()->database(), placeholder_values);
+    Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), placeholder_values = move(placeholder_values), execution_id] {
+        auto execution_result = m_statement->execute(connection().database(), placeholder_values);
         m_ongoing_executions.remove(execution_id);
 
         if (execution_result.is_error()) {
@@ -80,7 +79,7 @@ Optional<SQL::ExecutionID> SQLStatement::execute(Vector<SQL::Value> placeholder_
             return;
         }
 
-        auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
+        auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
         if (!client_connection) {
             warnln("Cannot return statement execution results. Client disconnected");
             return;
@@ -124,7 +123,7 @@ bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const
 
 void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size)
 {
-    auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
+    auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id());
     if (!client_connection) {
         warnln("Cannot yield next result. Client disconnected");
         return;
@@ -134,7 +133,7 @@ void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, si
         auto result_row = result.take_first();
         client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data());
 
-        deferred_invoke([this, execution_id, result = move(result), result_size]() mutable {
+        Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), execution_id, result = move(result), result_size]() mutable {
             next(execution_id, move(result), result_size);
         });
     } else {

+ 4 - 6
Userland/Services/SQLServer/SQLStatement.h

@@ -7,8 +7,8 @@
 #pragma once
 
 #include <AK/NonnullRefPtr.h>
+#include <AK/RefCounted.h>
 #include <AK/Vector.h>
-#include <LibCore/EventReceiver.h>
 #include <LibSQL/AST/AST.h>
 #include <LibSQL/Result.h>
 #include <LibSQL/ResultSet.h>
@@ -18,16 +18,13 @@
 
 namespace SQLServer {
 
-class SQLStatement final : public Core::EventReceiver {
-    C_OBJECT_ABSTRACT(SQLStatement)
-
+class SQLStatement final : public RefCounted<SQLStatement> {
 public:
     static SQL::ResultOr<NonnullRefPtr<SQLStatement>> create(DatabaseConnection&, StringView sql);
-    ~SQLStatement() override = default;
 
     static RefPtr<SQLStatement> statement_for(SQL::StatementID statement_id);
     SQL::StatementID statement_id() const { return m_statement_id; }
-    DatabaseConnection* connection() { return dynamic_cast<DatabaseConnection*>(parent()); }
+    DatabaseConnection& connection() { return m_connection; }
     Optional<SQL::ExecutionID> execute(Vector<SQL::Value> placeholder_values);
 
 private:
@@ -37,6 +34,7 @@ private:
     void next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size);
     void report_error(SQL::Result, SQL::ExecutionID execution_id);
 
+    DatabaseConnection& m_connection;
     SQL::StatementID m_statement_id { 0 };
 
     HashTable<SQL::ExecutionID> m_ongoing_executions;