Ver Fonte

LibSQL: Partially implement the UPDATE command

This implements enough to update rows filtered by a WHERE clause.
Timothy Flynn há 2 anos atrás
pai
commit
53f8d62ea4

+ 119 - 0
Tests/LibSQL/TestSqlStatementExecution.cpp

@@ -886,4 +886,123 @@ TEST_CASE(delete_all_rows)
     }
     }
 }
 }
 
 
+TEST_CASE(update_single_row)
+{
+    ScopeGuard guard([]() { unlink(db_name); });
+    {
+        auto database = SQL::Database::construct(db_name);
+        EXPECT(!database->open().is_error());
+
+        create_table(database);
+        for (auto count = 0; count < 10; ++count) {
+            auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
+            EXPECT_EQ(result.size(), 1u);
+        }
+
+        execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456 WHERE (TextColumn = 'T3');");
+
+        auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
+        EXPECT_EQ(result.size(), 10u);
+
+        for (auto i = 0u; i < 10; ++i) {
+            if (i < 3)
+                EXPECT_EQ(result[i].row[0], i);
+            else if (i < 9)
+                EXPECT_EQ(result[i].row[0], i + 1);
+            else
+                EXPECT_EQ(result[i].row[0], 123456);
+        }
+    }
+    {
+        auto database = SQL::Database::construct(db_name);
+        EXPECT(!database->open().is_error());
+
+        auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
+        EXPECT_EQ(result.size(), 10u);
+
+        for (auto i = 0u; i < 10; ++i) {
+            if (i < 3)
+                EXPECT_EQ(result[i].row[0], i);
+            else if (i < 9)
+                EXPECT_EQ(result[i].row[0], i + 1);
+            else
+                EXPECT_EQ(result[i].row[0], 123456);
+        }
+    }
+}
+
+TEST_CASE(update_multiple_rows)
+{
+    ScopeGuard guard([]() { unlink(db_name); });
+    {
+        auto database = SQL::Database::construct(db_name);
+        EXPECT(!database->open().is_error());
+
+        create_table(database);
+        for (auto count = 0; count < 10; ++count) {
+            auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
+            EXPECT_EQ(result.size(), 1u);
+        }
+
+        execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456 WHERE (IntColumn > 4);");
+
+        auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
+        EXPECT_EQ(result.size(), 10u);
+
+        for (auto i = 0u; i < 10; ++i) {
+            if (i < 5)
+                EXPECT_EQ(result[i].row[0], i);
+            else
+                EXPECT_EQ(result[i].row[0], 123456);
+        }
+    }
+    {
+        auto database = SQL::Database::construct(db_name);
+        EXPECT(!database->open().is_error());
+
+        auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
+        EXPECT_EQ(result.size(), 10u);
+
+        for (auto i = 0u; i < 10; ++i) {
+            if (i < 5)
+                EXPECT_EQ(result[i].row[0], i);
+            else
+                EXPECT_EQ(result[i].row[0], 123456);
+        }
+    }
+}
+
+TEST_CASE(update_all_rows)
+{
+    ScopeGuard guard([]() { unlink(db_name); });
+    {
+        auto database = SQL::Database::construct(db_name);
+        EXPECT(!database->open().is_error());
+
+        create_table(database);
+        for (auto count = 0; count < 10; ++count) {
+            auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
+            EXPECT_EQ(result.size(), 1u);
+        }
+
+        execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456;");
+
+        auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
+        EXPECT_EQ(result.size(), 10u);
+
+        for (auto i = 0u; i < 10; ++i)
+            EXPECT_EQ(result[i].row[0], 123456);
+    }
+    {
+        auto database = SQL::Database::construct(db_name);
+        EXPECT(!database->open().is_error());
+
+        auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
+        EXPECT_EQ(result.size(), 10u);
+
+        for (auto i = 0u; i < 10; ++i)
+            EXPECT_EQ(result[i].row[0], 123456);
+    }
+}
+
 }
 }

+ 2 - 0
Userland/Libraries/LibSQL/AST/AST.h

@@ -992,6 +992,8 @@ public:
     RefPtr<Expression> const& where_clause() const { return m_where_clause; }
     RefPtr<Expression> const& where_clause() const { return m_where_clause; }
     RefPtr<ReturningClause> const& returning_clause() const { return m_returning_clause; }
     RefPtr<ReturningClause> const& returning_clause() const { return m_returning_clause; }
 
 
+    virtual ResultOr<ResultSet> execute(ExecutionContext&) const override;
+
 private:
 private:
     RefPtr<CommonTableExpressionList> m_common_table_expression_list;
     RefPtr<CommonTableExpressionList> m_common_table_expression_list;
     ConflictResolution m_conflict_resolution;
     ConflictResolution m_conflict_resolution;

+ 1 - 11
Userland/Libraries/LibSQL/AST/Insert.cpp

@@ -12,15 +12,6 @@
 
 
 namespace SQL::AST {
 namespace SQL::AST {
 
 
-static bool does_value_data_type_match(SQLType expected, SQLType actual)
-{
-    if (actual == SQLType::Null)
-        return false;
-    if (expected == SQLType::Integer)
-        return actual == SQLType::Integer || actual == SQLType::Float;
-    return expected == actual;
-}
-
 ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
 ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
 {
 {
     auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name));
     auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name));
@@ -49,13 +40,12 @@ ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
             return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, DeprecatedString::empty() };
             return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, DeprecatedString::empty() };
 
 
         for (auto ix = 0u; ix < values.size(); ix++) {
         for (auto ix = 0u; ix < values.size(); ix++) {
-            auto input_value_type = values[ix].type();
             auto& tuple_descriptor = *row.descriptor();
             auto& tuple_descriptor = *row.descriptor();
             // In case of having column names, this must succeed since we checked for every column name for existence in the table.
             // In case of having column names, this must succeed since we checked for every column name for existence in the table.
             auto element_index = m_column_names.is_empty() ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index();
             auto element_index = m_column_names.is_empty() ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index();
             auto element_type = tuple_descriptor[element_index].type;
             auto element_type = tuple_descriptor[element_index].type;
 
 
-            if (!does_value_data_type_match(element_type, input_value_type))
+            if (!values[ix].is_type_compatible_with(element_type))
                 return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() };
                 return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() };
 
 
             row[element_index] = move(values[ix]);
             row[element_index] = move(values[ix]);

+ 63 - 0
Userland/Libraries/LibSQL/AST/Update.cpp

@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2022, Tim Flynn <trflynn89@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include <LibSQL/AST/AST.h>
+#include <LibSQL/Database.h>
+#include <LibSQL/Meta.h>
+#include <LibSQL/Row.h>
+
+namespace SQL::AST {
+
+ResultOr<ResultSet> Update::execute(ExecutionContext& context) const
+{
+    auto const& schema_name = m_qualified_table_name->schema_name();
+    auto const& table_name = m_qualified_table_name->table_name();
+    auto table_def = TRY(context.database->get_table(schema_name, table_name));
+
+    Vector<Row> matched_rows;
+
+    for (auto& table_row : TRY(context.database->select_all(*table_def))) {
+        context.current_row = &table_row;
+
+        if (auto const& where_clause = this->where_clause()) {
+            auto where_result = TRY(where_clause->evaluate(context)).to_bool();
+            if (!where_result.has_value() || !where_result.value())
+                continue;
+        }
+
+        TRY(matched_rows.try_append(move(table_row)));
+    }
+
+    ResultSet result { SQLCommand::Update };
+
+    for (auto& update_column : m_update_columns) {
+        auto row_value = TRY(update_column.expression->evaluate(context));
+
+        for (auto& table_row : matched_rows) {
+            auto& row_descriptor = *table_row.descriptor();
+
+            for (auto const& column_name : update_column.column_names) {
+                if (!table_row.has(column_name))
+                    return Result { SQLCommand::Update, SQLErrorCode::ColumnDoesNotExist, column_name };
+
+                auto column_index = row_descriptor.find_if([&](auto element) { return element.name == column_name; }).index();
+                auto column_type = row_descriptor[column_index].type;
+
+                if (!row_value.is_type_compatible_with(column_type))
+                    return Result { SQLCommand::Update, SQLErrorCode::InvalidValueType, column_name };
+
+                table_row[column_index] = row_value;
+            }
+
+            TRY(context.database->update(table_row));
+            result.insert_row(table_row, {});
+        }
+    }
+
+    return result;
+}
+
+}

+ 2 - 1
Userland/Libraries/LibSQL/CMakeLists.txt

@@ -11,6 +11,7 @@ set(SOURCES
     AST/Statement.cpp
     AST/Statement.cpp
     AST/SyntaxHighlighter.cpp
     AST/SyntaxHighlighter.cpp
     AST/Token.cpp
     AST/Token.cpp
+    AST/Update.cpp
     BTree.cpp
     BTree.cpp
     BTreeIterator.cpp
     BTreeIterator.cpp
     Database.cpp
     Database.cpp
@@ -26,7 +27,7 @@ set(SOURCES
     TreeNode.cpp
     TreeNode.cpp
     Tuple.cpp
     Tuple.cpp
     Value.cpp
     Value.cpp
-    )
+)
 
 
 if (SERENITYOS)
 if (SERENITYOS)
     list(APPEND SOURCES SQLClient.cpp)
     list(APPEND SOURCES SQLClient.cpp)

+ 15 - 0
Userland/Libraries/LibSQL/Value.cpp

@@ -100,6 +100,21 @@ StringView Value::type_name() const
     }
     }
 }
 }
 
 
+bool Value::is_type_compatible_with(SQLType other_type) const
+{
+    switch (type()) {
+    case SQLType::Null:
+        return false;
+    case SQLType::Integer:
+    case SQLType::Float:
+        return other_type == SQLType::Integer || other_type == SQLType::Float;
+    default:
+        break;
+    }
+
+    return type() == other_type;
+}
+
 bool Value::is_null() const
 bool Value::is_null() const
 {
 {
     return !m_value.has_value();
     return !m_value.has_value();

+ 1 - 0
Userland/Libraries/LibSQL/Value.h

@@ -47,6 +47,7 @@ public:
 
 
     [[nodiscard]] SQLType type() const;
     [[nodiscard]] SQLType type() const;
     [[nodiscard]] StringView type_name() const;
     [[nodiscard]] StringView type_name() const;
+    [[nodiscard]] bool is_type_compatible_with(SQLType) const;
     [[nodiscard]] bool is_null() const;
     [[nodiscard]] bool is_null() const;
 
 
     [[nodiscard]] DeprecatedString to_deprecated_string() const;
     [[nodiscard]] DeprecatedString to_deprecated_string() const;