ソースを参照

LibSQL: Convert binary SQL operations to be fallible

Now that expression evaluation can use TRY, we can allow binary operator
methods to fail as well. This also fixes a few instances of converting a
Value to a double when we meant to convert to an integer.
Timothy Flynn 3 年 前
コミット
bfe1bd9726

+ 125 - 0
Tests/LibSQL/TestSqlStatementExecution.cpp

@@ -7,6 +7,7 @@
 
 #include <unistd.h>
 
+#include <AK/QuickSort.h>
 #include <AK/ScopeGuard.h>
 #include <LibSQL/AST/Parser.h>
 #include <LibSQL/Database.h>
@@ -629,4 +630,128 @@ TEST_CASE(describe_table)
     EXPECT_EQ(result[1].row[1].to_string(), "int");
 }
 
+TEST_CASE(binary_operator_execution)
+{
+    ScopeGuard guard([]() { unlink(db_name); });
+    auto database = SQL::Database::construct(db_name);
+    EXPECT(!database->open().is_error());
+    create_table(database);
+
+    for (auto count = 0; count < 10; ++count) {
+        auto result = execute(database, String::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
+        EXPECT_EQ(result.size(), 1u);
+    }
+
+    auto compare_result = [](SQL::ResultSet const& result, Vector<int> const& expected) {
+        EXPECT_EQ(result.command(), SQL::SQLCommand::Select);
+        EXPECT_EQ(result.size(), expected.size());
+
+        Vector<int> result_values;
+        result_values.ensure_capacity(result.size());
+
+        for (size_t i = 0; i < result.size(); ++i) {
+            auto const& result_row = result.at(i).row;
+            EXPECT_EQ(result_row.size(), 1u);
+
+            auto result_column = result_row[0].to_int();
+            result_values.append(result_column.value());
+        }
+
+        quick_sort(result_values);
+        EXPECT_EQ(result_values, expected);
+    };
+
+    auto result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn + 1) < 5);");
+    compare_result(result, { 0, 1, 2, 3 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn + 1) <= 5);");
+    compare_result(result, { 0, 1, 2, 3, 4 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn - 1) > 4);");
+    compare_result(result, { 6, 7, 8, 9 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn - 1) >= 4);");
+    compare_result(result, { 5, 6, 7, 8, 9 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn * 2) < 10);");
+    compare_result(result, { 0, 1, 2, 3, 4 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn * 2) <= 10);");
+    compare_result(result, { 0, 1, 2, 3, 4, 5 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn / 3) > 2);");
+    compare_result(result, { 7, 8, 9 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn / 3) >= 2);");
+    compare_result(result, { 6, 7, 8, 9 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn % 2) = 0);");
+    compare_result(result, { 0, 2, 4, 6, 8 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn % 2) = 1);");
+    compare_result(result, { 1, 3, 5, 7, 9 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((1 << IntColumn) <= 32);");
+    compare_result(result, { 0, 1, 2, 3, 4, 5 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((1024 >> IntColumn) >= 32);");
+    compare_result(result, { 0, 1, 2, 3, 4, 5 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn | 1) != IntColumn);");
+    compare_result(result, { 0, 2, 4, 6, 8 });
+
+    result = execute(database, "SELECT IntColumn FROM TestSchema.TestTable WHERE ((IntColumn & 1) = 1);");
+    compare_result(result, { 1, 3, 5, 7, 9 });
+}
+
+TEST_CASE(binary_operator_failure)
+{
+    ScopeGuard guard([]() { unlink(db_name); });
+    auto database = SQL::Database::construct(db_name);
+    EXPECT(!database->open().is_error());
+    create_table(database);
+
+    for (auto count = 0; count < 10; ++count) {
+        auto result = execute(database, String::formatted("INSERT INTO TestSchema.TestTable VALUES ( 'T{}', {} );", count, count));
+        EXPECT_EQ(result.size(), 1u);
+    }
+
+    auto expect_failure = [](auto result, auto op) {
+        EXPECT(result.is_error());
+
+        auto error = result.release_error();
+        EXPECT_EQ(error.error(), SQL::SQLErrorCode::NumericOperatorTypeMismatch);
+
+        auto message = String::formatted("NumericOperatorTypeMismatch: Cannot apply '{}' operator to non-numeric operands", op);
+        EXPECT_EQ(error.error_string(), message);
+    };
+
+    auto result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn + TextColumn) < 5);");
+    expect_failure(move(result), '+');
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn - TextColumn) < 5);");
+    expect_failure(move(result), '-');
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn * TextColumn) < 5);");
+    expect_failure(move(result), '*');
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn / TextColumn) < 5);");
+    expect_failure(move(result), '/');
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn % TextColumn) < 5);");
+    expect_failure(move(result), '%');
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn << TextColumn) < 5);");
+    expect_failure(move(result), "<<"sv);
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn >> TextColumn) < 5);");
+    expect_failure(move(result), ">>"sv);
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn | TextColumn) < 5);");
+    expect_failure(move(result), '|');
+
+    result = try_execute(database, "SELECT * FROM TestSchema.TestTable WHERE ((IntColumn & TextColumn) < 5);");
+    expect_failure(move(result), '&');
+}
+
 }

+ 33 - 50
Userland/Libraries/LibSQL/Value.cpp

@@ -4,6 +4,7 @@
  * SPDX-License-Identifier: BSD-2-Clause
  */
 
+#include <LibSQL/AST/AST.h>
 #include <LibSQL/Serializer.h>
 #include <LibSQL/Value.h>
 #include <math.h>
@@ -390,135 +391,117 @@ bool Value::operator>=(Value const& other) const
     return compare(other) >= 0;
 }
 
-Value Value::add(Value const& other) const
+static Result invalid_type_for_numeric_operator(AST::BinaryOperator op)
+{
+    return { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, BinaryOperator_name(op) };
+}
+
+ResultOr<Value> Value::add(Value const& other) const
 {
     if (auto double_maybe = to_double(); double_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value(double_maybe.value() + other_double_maybe.value());
         if (auto int_maybe = other.to_int(); int_maybe.has_value())
             return Value(double_maybe.value() + (double)int_maybe.value());
-        VERIFY_NOT_REACHED();
-    }
-    if (auto int_maybe = to_double(); int_maybe.has_value()) {
+    } else if (auto int_maybe = to_int(); int_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value(other_double_maybe.value() + (double)int_maybe.value());
         if (auto other_int_maybe = other.to_int(); other_int_maybe.has_value())
             return Value(int_maybe.value() + other_int_maybe.value());
-        VERIFY_NOT_REACHED();
     }
-    VERIFY_NOT_REACHED();
+    return invalid_type_for_numeric_operator(AST::BinaryOperator::Plus);
 }
 
-Value Value::subtract(Value const& other) const
+ResultOr<Value> Value::subtract(Value const& other) const
 {
     if (auto double_maybe = to_double(); double_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value(double_maybe.value() - other_double_maybe.value());
         if (auto int_maybe = other.to_int(); int_maybe.has_value())
             return Value(double_maybe.value() - (double)int_maybe.value());
-        VERIFY_NOT_REACHED();
-    }
-    if (auto int_maybe = to_double(); int_maybe.has_value()) {
+    } else if (auto int_maybe = to_int(); int_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value((double)int_maybe.value() - other_double_maybe.value());
         if (auto other_int_maybe = other.to_int(); other_int_maybe.has_value())
             return Value(int_maybe.value() - other_int_maybe.value());
-        VERIFY_NOT_REACHED();
     }
-    VERIFY_NOT_REACHED();
+    return invalid_type_for_numeric_operator(AST::BinaryOperator::Minus);
 }
 
-Value Value::multiply(Value const& other) const
+ResultOr<Value> Value::multiply(Value const& other) const
 {
     if (auto double_maybe = to_double(); double_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value(double_maybe.value() * other_double_maybe.value());
         if (auto int_maybe = other.to_int(); int_maybe.has_value())
             return Value(double_maybe.value() * (double)int_maybe.value());
-        VERIFY_NOT_REACHED();
-    }
-    if (auto int_maybe = to_double(); int_maybe.has_value()) {
+    } else if (auto int_maybe = to_int(); int_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value((double)int_maybe.value() * other_double_maybe.value());
         if (auto other_int_maybe = other.to_int(); other_int_maybe.has_value())
             return Value(int_maybe.value() * other_int_maybe.value());
-        VERIFY_NOT_REACHED();
     }
-    VERIFY_NOT_REACHED();
+    return invalid_type_for_numeric_operator(AST::BinaryOperator::Multiplication);
 }
 
-Value Value::divide(Value const& other) const
+ResultOr<Value> Value::divide(Value const& other) const
 {
     if (auto double_maybe = to_double(); double_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value(double_maybe.value() / other_double_maybe.value());
         if (auto int_maybe = other.to_int(); int_maybe.has_value())
             return Value(double_maybe.value() / (double)int_maybe.value());
-        VERIFY_NOT_REACHED();
-    }
-
-    if (auto int_maybe = to_double(); int_maybe.has_value()) {
+    } else if (auto int_maybe = to_int(); int_maybe.has_value()) {
         if (auto other_double_maybe = other.to_double(); other_double_maybe.has_value())
             return Value((double)int_maybe.value() / other_double_maybe.value());
         if (auto other_int_maybe = other.to_int(); other_int_maybe.has_value())
             return Value(int_maybe.value() / other_int_maybe.value());
-        VERIFY_NOT_REACHED();
     }
-    VERIFY_NOT_REACHED();
+    return invalid_type_for_numeric_operator(AST::BinaryOperator::Division);
 }
 
-Value Value::modulo(Value const& other) const
+ResultOr<Value> Value::modulo(Value const& other) const
 {
     auto int_maybe_1 = to_int();
     auto int_maybe_2 = other.to_int();
-    if (!int_maybe_1.has_value() || !int_maybe_2.has_value()) {
-        // TODO Error handling
-        VERIFY_NOT_REACHED();
-    }
+    if (!int_maybe_1.has_value() || !int_maybe_2.has_value())
+        return invalid_type_for_numeric_operator(AST::BinaryOperator::Modulo);
     return Value(int_maybe_1.value() % int_maybe_2.value());
 }
 
-Value Value::shift_left(Value const& other) const
+ResultOr<Value> Value::shift_left(Value const& other) const
 {
     auto u32_maybe = to_u32();
     auto num_bytes_maybe = other.to_int();
-    if (!u32_maybe.has_value() || !num_bytes_maybe.has_value()) {
-        // TODO Error handling
-        VERIFY_NOT_REACHED();
-    }
+    if (!u32_maybe.has_value() || !num_bytes_maybe.has_value())
+        return invalid_type_for_numeric_operator(AST::BinaryOperator::ShiftLeft);
     return Value(u32_maybe.value() << num_bytes_maybe.value());
 }
 
-Value Value::shift_right(Value const& other) const
+ResultOr<Value> Value::shift_right(Value const& other) const
 {
     auto u32_maybe = to_u32();
     auto num_bytes_maybe = other.to_int();
-    if (!u32_maybe.has_value() || !num_bytes_maybe.has_value()) {
-        // TODO Error handling
-        VERIFY_NOT_REACHED();
-    }
+    if (!u32_maybe.has_value() || !num_bytes_maybe.has_value())
+        return invalid_type_for_numeric_operator(AST::BinaryOperator::ShiftRight);
     return Value(u32_maybe.value() >> num_bytes_maybe.value());
 }
 
-Value Value::bitwise_or(Value const& other) const
+ResultOr<Value> Value::bitwise_or(Value const& other) const
 {
     auto u32_maybe_1 = to_u32();
     auto u32_maybe_2 = other.to_u32();
-    if (!u32_maybe_1.has_value() || !u32_maybe_2.has_value()) {
-        // TODO Error handling
-        VERIFY_NOT_REACHED();
-    }
+    if (!u32_maybe_1.has_value() || !u32_maybe_2.has_value())
+        return invalid_type_for_numeric_operator(AST::BinaryOperator::BitwiseOr);
     return Value(u32_maybe_1.value() | u32_maybe_2.value());
 }
 
-Value Value::bitwise_and(Value const& other) const
+ResultOr<Value> Value::bitwise_and(Value const& other) const
 {
     auto u32_maybe_1 = to_u32();
     auto u32_maybe_2 = other.to_u32();
-    if (!u32_maybe_1.has_value() || !u32_maybe_2.has_value()) {
-        // TODO Error handling
-        VERIFY_NOT_REACHED();
-    }
+    if (!u32_maybe_1.has_value() || !u32_maybe_2.has_value())
+        return invalid_type_for_numeric_operator(AST::BinaryOperator::BitwiseAnd);
     return Value(u32_maybe_1.value() & u32_maybe_2.value());
 }
 

+ 10 - 9
Userland/Libraries/LibSQL/Value.h

@@ -12,6 +12,7 @@
 #include <AK/String.h>
 #include <AK/Variant.h>
 #include <LibSQL/Forward.h>
+#include <LibSQL/Result.h>
 #include <LibSQL/TupleDescriptor.h>
 #include <LibSQL/Type.h>
 #include <LibSQL/ValueImpl.h>
@@ -118,15 +119,15 @@ public:
     bool operator>(Value const&) const;
     bool operator>=(Value const&) const;
 
-    Value add(Value const&) const;
-    Value subtract(Value const&) const;
-    Value multiply(Value const&) const;
-    Value divide(Value const&) const;
-    Value modulo(Value const&) const;
-    Value shift_left(Value const&) const;
-    Value shift_right(Value const&) const;
-    Value bitwise_or(Value const&) const;
-    Value bitwise_and(Value const&) const;
+    ResultOr<Value> add(Value const&) const;
+    ResultOr<Value> subtract(Value const&) const;
+    ResultOr<Value> multiply(Value const&) const;
+    ResultOr<Value> divide(Value const&) const;
+    ResultOr<Value> modulo(Value const&) const;
+    ResultOr<Value> shift_left(Value const&) const;
+    ResultOr<Value> shift_right(Value const&) const;
+    ResultOr<Value> bitwise_or(Value const&) const;
+    ResultOr<Value> bitwise_and(Value const&) const;
 
     [[nodiscard]] TupleElementDescriptor descriptor() const
     {