diff --git a/Tests/LibSQL/TestSqlStatementExecution.cpp b/Tests/LibSQL/TestSqlStatementExecution.cpp index dadcb8d100d..87228dad590 100644 --- a/Tests/LibSQL/TestSqlStatementExecution.cpp +++ b/Tests/LibSQL/TestSqlStatementExecution.cpp @@ -886,4 +886,123 @@ TEST_CASE(delete_all_rows) } } +TEST_CASE(update_single_row) +{ + ScopeGuard guard([]() { unlink(db_name); }); + { + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + + create_table(database); + for (auto count = 0; count < 10; ++count) { + auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count)); + EXPECT_EQ(result.size(), 1u); + } + + execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456 WHERE (TextColumn = 'T3');"); + + auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;"); + EXPECT_EQ(result.size(), 10u); + + for (auto i = 0u; i < 10; ++i) { + if (i < 3) + EXPECT_EQ(result[i].row[0], i); + else if (i < 9) + EXPECT_EQ(result[i].row[0], i + 1); + else + EXPECT_EQ(result[i].row[0], 123456); + } + } + { + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + + auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;"); + EXPECT_EQ(result.size(), 10u); + + for (auto i = 0u; i < 10; ++i) { + if (i < 3) + EXPECT_EQ(result[i].row[0], i); + else if (i < 9) + EXPECT_EQ(result[i].row[0], i + 1); + else + EXPECT_EQ(result[i].row[0], 123456); + } + } +} + +TEST_CASE(update_multiple_rows) +{ + ScopeGuard guard([]() { unlink(db_name); }); + { + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + + create_table(database); + for (auto count = 0; count < 10; ++count) { + auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count)); + EXPECT_EQ(result.size(), 1u); + } + + execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456 WHERE (IntColumn > 4);"); + + auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;"); + EXPECT_EQ(result.size(), 10u); + + for (auto i = 0u; i < 10; ++i) { + if (i < 5) + EXPECT_EQ(result[i].row[0], i); + else + EXPECT_EQ(result[i].row[0], 123456); + } + } + { + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + + auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;"); + EXPECT_EQ(result.size(), 10u); + + for (auto i = 0u; i < 10; ++i) { + if (i < 5) + EXPECT_EQ(result[i].row[0], i); + else + EXPECT_EQ(result[i].row[0], 123456); + } + } +} + +TEST_CASE(update_all_rows) +{ + ScopeGuard guard([]() { unlink(db_name); }); + { + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + + create_table(database); + for (auto count = 0; count < 10; ++count) { + auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count)); + EXPECT_EQ(result.size(), 1u); + } + + execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456;"); + + auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;"); + EXPECT_EQ(result.size(), 10u); + + for (auto i = 0u; i < 10; ++i) + EXPECT_EQ(result[i].row[0], 123456); + } + { + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + + auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;"); + EXPECT_EQ(result.size(), 10u); + + for (auto i = 0u; i < 10; ++i) + EXPECT_EQ(result[i].row[0], 123456); + } +} + } diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h index e1d89956ab3..4d64c9d48ab 100644 --- a/Userland/Libraries/LibSQL/AST/AST.h +++ b/Userland/Libraries/LibSQL/AST/AST.h @@ -992,6 +992,8 @@ public: RefPtr const& where_clause() const { return m_where_clause; } RefPtr const& returning_clause() const { return m_returning_clause; } + virtual ResultOr execute(ExecutionContext&) const override; + private: RefPtr m_common_table_expression_list; ConflictResolution m_conflict_resolution; diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp index bda6c6fa7d9..76be3de7639 100644 --- a/Userland/Libraries/LibSQL/AST/Insert.cpp +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -12,15 +12,6 @@ namespace SQL::AST { -static bool does_value_data_type_match(SQLType expected, SQLType actual) -{ - if (actual == SQLType::Null) - return false; - if (expected == SQLType::Integer) - return actual == SQLType::Integer || actual == SQLType::Float; - return expected == actual; -} - ResultOr Insert::execute(ExecutionContext& context) const { auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name)); @@ -49,13 +40,12 @@ ResultOr Insert::execute(ExecutionContext& context) const return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, DeprecatedString::empty() }; for (auto ix = 0u; ix < values.size(); ix++) { - auto input_value_type = values[ix].type(); auto& tuple_descriptor = *row.descriptor(); // In case of having column names, this must succeed since we checked for every column name for existence in the table. auto element_index = m_column_names.is_empty() ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index(); auto element_type = tuple_descriptor[element_index].type; - if (!does_value_data_type_match(element_type, input_value_type)) + if (!values[ix].is_type_compatible_with(element_type)) return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() }; row[element_index] = move(values[ix]); diff --git a/Userland/Libraries/LibSQL/AST/Update.cpp b/Userland/Libraries/LibSQL/AST/Update.cpp new file mode 100644 index 00000000000..de65951f413 --- /dev/null +++ b/Userland/Libraries/LibSQL/AST/Update.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022, Tim Flynn + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include +#include +#include +#include + +namespace SQL::AST { + +ResultOr Update::execute(ExecutionContext& context) const +{ + auto const& schema_name = m_qualified_table_name->schema_name(); + auto const& table_name = m_qualified_table_name->table_name(); + auto table_def = TRY(context.database->get_table(schema_name, table_name)); + + Vector matched_rows; + + for (auto& table_row : TRY(context.database->select_all(*table_def))) { + context.current_row = &table_row; + + if (auto const& where_clause = this->where_clause()) { + auto where_result = TRY(where_clause->evaluate(context)).to_bool(); + if (!where_result.has_value() || !where_result.value()) + continue; + } + + TRY(matched_rows.try_append(move(table_row))); + } + + ResultSet result { SQLCommand::Update }; + + for (auto& update_column : m_update_columns) { + auto row_value = TRY(update_column.expression->evaluate(context)); + + for (auto& table_row : matched_rows) { + auto& row_descriptor = *table_row.descriptor(); + + for (auto const& column_name : update_column.column_names) { + if (!table_row.has(column_name)) + return Result { SQLCommand::Update, SQLErrorCode::ColumnDoesNotExist, column_name }; + + auto column_index = row_descriptor.find_if([&](auto element) { return element.name == column_name; }).index(); + auto column_type = row_descriptor[column_index].type; + + if (!row_value.is_type_compatible_with(column_type)) + return Result { SQLCommand::Update, SQLErrorCode::InvalidValueType, column_name }; + + table_row[column_index] = row_value; + } + + TRY(context.database->update(table_row)); + result.insert_row(table_row, {}); + } + } + + return result; +} + +} diff --git a/Userland/Libraries/LibSQL/CMakeLists.txt b/Userland/Libraries/LibSQL/CMakeLists.txt index 4a726c45487..aa92d3c0ae6 100644 --- a/Userland/Libraries/LibSQL/CMakeLists.txt +++ b/Userland/Libraries/LibSQL/CMakeLists.txt @@ -11,6 +11,7 @@ set(SOURCES AST/Statement.cpp AST/SyntaxHighlighter.cpp AST/Token.cpp + AST/Update.cpp BTree.cpp BTreeIterator.cpp Database.cpp @@ -26,7 +27,7 @@ set(SOURCES TreeNode.cpp Tuple.cpp Value.cpp - ) +) if (SERENITYOS) list(APPEND SOURCES SQLClient.cpp) diff --git a/Userland/Libraries/LibSQL/Value.cpp b/Userland/Libraries/LibSQL/Value.cpp index 7f7f5c841b2..c7657cc145f 100644 --- a/Userland/Libraries/LibSQL/Value.cpp +++ b/Userland/Libraries/LibSQL/Value.cpp @@ -100,6 +100,21 @@ StringView Value::type_name() const } } +bool Value::is_type_compatible_with(SQLType other_type) const +{ + switch (type()) { + case SQLType::Null: + return false; + case SQLType::Integer: + case SQLType::Float: + return other_type == SQLType::Integer || other_type == SQLType::Float; + default: + break; + } + + return type() == other_type; +} + bool Value::is_null() const { return !m_value.has_value(); diff --git a/Userland/Libraries/LibSQL/Value.h b/Userland/Libraries/LibSQL/Value.h index 8ae98a41026..0cc5d3658b3 100644 --- a/Userland/Libraries/LibSQL/Value.h +++ b/Userland/Libraries/LibSQL/Value.h @@ -47,6 +47,7 @@ public: [[nodiscard]] SQLType type() const; [[nodiscard]] StringView type_name() const; + [[nodiscard]] bool is_type_compatible_with(SQLType) const; [[nodiscard]] bool is_null() const; [[nodiscard]] DeprecatedString to_deprecated_string() const;