LibSQL: Partially implement the UPDATE command

This implements enough to update rows filtered by a WHERE clause.
This commit is contained in:
Timothy Flynn 2022-12-05 07:55:21 -05:00 committed by Andreas Kling
parent 1574f2c3f6
commit 53f8d62ea4
Notes: sideshowbarker 2024-07-17 20:58:35 +09:00
7 changed files with 203 additions and 12 deletions

View file

@ -886,4 +886,123 @@ TEST_CASE(delete_all_rows)
}
}
TEST_CASE(update_single_row)
{
ScopeGuard guard([]() { unlink(db_name); });
{
auto database = SQL::Database::construct(db_name);
EXPECT(!database->open().is_error());
create_table(database);
for (auto count = 0; count < 10; ++count) {
auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
EXPECT_EQ(result.size(), 1u);
}
execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456 WHERE (TextColumn = 'T3');");
auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
EXPECT_EQ(result.size(), 10u);
for (auto i = 0u; i < 10; ++i) {
if (i < 3)
EXPECT_EQ(result[i].row[0], i);
else if (i < 9)
EXPECT_EQ(result[i].row[0], i + 1);
else
EXPECT_EQ(result[i].row[0], 123456);
}
}
{
auto database = SQL::Database::construct(db_name);
EXPECT(!database->open().is_error());
auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
EXPECT_EQ(result.size(), 10u);
for (auto i = 0u; i < 10; ++i) {
if (i < 3)
EXPECT_EQ(result[i].row[0], i);
else if (i < 9)
EXPECT_EQ(result[i].row[0], i + 1);
else
EXPECT_EQ(result[i].row[0], 123456);
}
}
}
TEST_CASE(update_multiple_rows)
{
ScopeGuard guard([]() { unlink(db_name); });
{
auto database = SQL::Database::construct(db_name);
EXPECT(!database->open().is_error());
create_table(database);
for (auto count = 0; count < 10; ++count) {
auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
EXPECT_EQ(result.size(), 1u);
}
execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456 WHERE (IntColumn > 4);");
auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
EXPECT_EQ(result.size(), 10u);
for (auto i = 0u; i < 10; ++i) {
if (i < 5)
EXPECT_EQ(result[i].row[0], i);
else
EXPECT_EQ(result[i].row[0], 123456);
}
}
{
auto database = SQL::Database::construct(db_name);
EXPECT(!database->open().is_error());
auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
EXPECT_EQ(result.size(), 10u);
for (auto i = 0u; i < 10; ++i) {
if (i < 5)
EXPECT_EQ(result[i].row[0], i);
else
EXPECT_EQ(result[i].row[0], 123456);
}
}
}
TEST_CASE(update_all_rows)
{
ScopeGuard guard([]() { unlink(db_name); });
{
auto database = SQL::Database::construct(db_name);
EXPECT(!database->open().is_error());
create_table(database);
for (auto count = 0; count < 10; ++count) {
auto result = execute(database, DeprecatedString::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
EXPECT_EQ(result.size(), 1u);
}
execute(database, "UPDATE TestSchema.TestTable SET IntColumn=123456;");
auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
EXPECT_EQ(result.size(), 10u);
for (auto i = 0u; i < 10; ++i)
EXPECT_EQ(result[i].row[0], 123456);
}
{
auto database = SQL::Database::construct(db_name);
EXPECT(!database->open().is_error());
auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable ORDER BY IntColumn;");
EXPECT_EQ(result.size(), 10u);
for (auto i = 0u; i < 10; ++i)
EXPECT_EQ(result[i].row[0], 123456);
}
}
}

View file

@ -992,6 +992,8 @@ public:
RefPtr<Expression> const& where_clause() const { return m_where_clause; }
RefPtr<ReturningClause> const& returning_clause() const { return m_returning_clause; }
virtual ResultOr<ResultSet> execute(ExecutionContext&) const override;
private:
RefPtr<CommonTableExpressionList> m_common_table_expression_list;
ConflictResolution m_conflict_resolution;

View file

@ -12,15 +12,6 @@
namespace SQL::AST {
static bool does_value_data_type_match(SQLType expected, SQLType actual)
{
if (actual == SQLType::Null)
return false;
if (expected == SQLType::Integer)
return actual == SQLType::Integer || actual == SQLType::Float;
return expected == actual;
}
ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
{
auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name));
@ -49,13 +40,12 @@ ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, DeprecatedString::empty() };
for (auto ix = 0u; ix < values.size(); ix++) {
auto input_value_type = values[ix].type();
auto& tuple_descriptor = *row.descriptor();
// In case of having column names, this must succeed since we checked for every column name for existence in the table.
auto element_index = m_column_names.is_empty() ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index();
auto element_type = tuple_descriptor[element_index].type;
if (!does_value_data_type_match(element_type, input_value_type))
if (!values[ix].is_type_compatible_with(element_type))
return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() };
row[element_index] = move(values[ix]);

View file

@ -0,0 +1,63 @@
/*
* Copyright (c) 2022, Tim Flynn <trflynn89@serenityos.org>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <LibSQL/AST/AST.h>
#include <LibSQL/Database.h>
#include <LibSQL/Meta.h>
#include <LibSQL/Row.h>
namespace SQL::AST {
ResultOr<ResultSet> Update::execute(ExecutionContext& context) const
{
auto const& schema_name = m_qualified_table_name->schema_name();
auto const& table_name = m_qualified_table_name->table_name();
auto table_def = TRY(context.database->get_table(schema_name, table_name));
Vector<Row> matched_rows;
for (auto& table_row : TRY(context.database->select_all(*table_def))) {
context.current_row = &table_row;
if (auto const& where_clause = this->where_clause()) {
auto where_result = TRY(where_clause->evaluate(context)).to_bool();
if (!where_result.has_value() || !where_result.value())
continue;
}
TRY(matched_rows.try_append(move(table_row)));
}
ResultSet result { SQLCommand::Update };
for (auto& update_column : m_update_columns) {
auto row_value = TRY(update_column.expression->evaluate(context));
for (auto& table_row : matched_rows) {
auto& row_descriptor = *table_row.descriptor();
for (auto const& column_name : update_column.column_names) {
if (!table_row.has(column_name))
return Result { SQLCommand::Update, SQLErrorCode::ColumnDoesNotExist, column_name };
auto column_index = row_descriptor.find_if([&](auto element) { return element.name == column_name; }).index();
auto column_type = row_descriptor[column_index].type;
if (!row_value.is_type_compatible_with(column_type))
return Result { SQLCommand::Update, SQLErrorCode::InvalidValueType, column_name };
table_row[column_index] = row_value;
}
TRY(context.database->update(table_row));
result.insert_row(table_row, {});
}
}
return result;
}
}

View file

@ -11,6 +11,7 @@ set(SOURCES
AST/Statement.cpp
AST/SyntaxHighlighter.cpp
AST/Token.cpp
AST/Update.cpp
BTree.cpp
BTreeIterator.cpp
Database.cpp
@ -26,7 +27,7 @@ set(SOURCES
TreeNode.cpp
Tuple.cpp
Value.cpp
)
)
if (SERENITYOS)
list(APPEND SOURCES SQLClient.cpp)

View file

@ -100,6 +100,21 @@ StringView Value::type_name() const
}
}
bool Value::is_type_compatible_with(SQLType other_type) const
{
switch (type()) {
case SQLType::Null:
return false;
case SQLType::Integer:
case SQLType::Float:
return other_type == SQLType::Integer || other_type == SQLType::Float;
default:
break;
}
return type() == other_type;
}
bool Value::is_null() const
{
return !m_value.has_value();

View file

@ -47,6 +47,7 @@ public:
[[nodiscard]] SQLType type() const;
[[nodiscard]] StringView type_name() const;
[[nodiscard]] bool is_type_compatible_with(SQLType) const;
[[nodiscard]] bool is_null() const;
[[nodiscard]] DeprecatedString to_deprecated_string() const;