Explorar o código

LibSQL: Convert SQL expression evaluation to use ResultOr

Instead of setting an error in the execution context, we can directly
return that error or the successful value. This lets all callers, who
were already TRY-capable, simply TRY the expression evaluation.
Timothy Flynn %!s(int64=3) %!d(string=hai) anos
pai
achega
f3c6cb40d7

+ 10 - 10
Userland/Libraries/LibSQL/AST/AST.h

@@ -306,7 +306,7 @@ struct ExecutionContext {
 
 class Expression : public ASTNode {
 public:
-    virtual Value evaluate(ExecutionContext&) const;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const;
 };
 
 class ErrorExpression final : public Expression {
@@ -320,7 +320,7 @@ public:
     }
 
     double value() const { return m_value; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     double m_value;
@@ -334,7 +334,7 @@ public:
     }
 
     const String& value() const { return m_value; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     String m_value;
@@ -355,13 +355,13 @@ private:
 
 class NullLiteral : public Expression {
 public:
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 };
 
 class NestedExpression : public Expression {
 public:
     const NonnullRefPtr<Expression>& expression() const { return m_expression; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 protected:
     explicit NestedExpression(NonnullRefPtr<Expression> expression)
@@ -432,7 +432,7 @@ public:
     const String& schema_name() const { return m_schema_name; }
     const String& table_name() const { return m_table_name; }
     const String& column_name() const { return m_column_name; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     String m_schema_name;
@@ -475,7 +475,7 @@ public:
     }
 
     UnaryOperator type() const { return m_type; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     UnaryOperator m_type;
@@ -531,7 +531,7 @@ public:
     }
 
     BinaryOperator type() const { return m_type; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     BinaryOperator m_type;
@@ -545,7 +545,7 @@ public:
     }
 
     const NonnullRefPtrVector<Expression>& expressions() const { return m_expressions; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     NonnullRefPtrVector<Expression> m_expressions;
@@ -638,7 +638,7 @@ public:
 
     MatchOperator type() const { return m_type; }
     const RefPtr<Expression>& escape() const { return m_escape; }
-    virtual Value evaluate(ExecutionContext&) const override;
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 
 private:
     MatchOperator m_type;

+ 50 - 72
Userland/Libraries/LibSQL/AST/Expression.cpp

@@ -12,66 +12,55 @@ namespace SQL::AST {
 
 static const String s_posix_basic_metacharacters = ".^$*[]+\\";
 
-Value Expression::evaluate(ExecutionContext&) const
+ResultOr<Value> Expression::evaluate(ExecutionContext&) const
 {
     return Value::null();
 }
 
-Value NumericLiteral::evaluate(ExecutionContext& context) const
+ResultOr<Value> NumericLiteral::evaluate(ExecutionContext&) const
 {
-    if (context.result->is_error())
-        return Value::null();
     Value ret(SQLType::Float);
     ret = value();
     return ret;
 }
 
-Value StringLiteral::evaluate(ExecutionContext& context) const
+ResultOr<Value> StringLiteral::evaluate(ExecutionContext&) const
 {
-    if (context.result->is_error())
-        return Value::null();
     Value ret(SQLType::Text);
     ret = value();
     return ret;
 }
 
-Value NullLiteral::evaluate(ExecutionContext&) const
+ResultOr<Value> NullLiteral::evaluate(ExecutionContext&) const
 {
     return Value::null();
 }
 
-Value NestedExpression::evaluate(ExecutionContext& context) const
+ResultOr<Value> NestedExpression::evaluate(ExecutionContext& context) const
 {
-    if (context.result->is_error())
-        return Value::null();
     return expression()->evaluate(context);
 }
 
-Value ChainedExpression::evaluate(ExecutionContext& context) const
+ResultOr<Value> ChainedExpression::evaluate(ExecutionContext& context) const
 {
-    if (context.result->is_error())
-        return Value::null();
     Value ret(SQLType::Tuple);
     Vector<Value> values;
-    for (auto& expression : expressions()) {
-        values.append(expression.evaluate(context));
-    }
+    for (auto& expression : expressions())
+        values.append(TRY(expression.evaluate(context)));
     ret = values;
     return ret;
 }
 
-Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
+ResultOr<Value> BinaryOperatorExpression::evaluate(ExecutionContext& context) const
 {
-    if (context.result->is_error())
-        return Value::null();
-    Value lhs_value = lhs()->evaluate(context);
-    Value rhs_value = rhs()->evaluate(context);
+    Value lhs_value = TRY(lhs()->evaluate(context));
+    Value rhs_value = TRY(rhs()->evaluate(context));
+
     switch (type()) {
     case BinaryOperator::Concatenate: {
-        if (lhs_value.type() != SQLType::Text) {
-            context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
-            return Value::null();
-        }
+        if (lhs_value.type() != SQLType::Text)
+            return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
+
         AK::StringBuilder builder;
         builder.append(lhs_value.to_string());
         builder.append(rhs_value.to_string());
@@ -110,19 +99,17 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
     case BinaryOperator::And: {
         auto lhs_bool_maybe = lhs_value.to_bool();
         auto rhs_bool_maybe = rhs_value.to_bool();
-        if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) {
-            context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
-            return Value::null();
-        }
+        if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value())
+            return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
+
         return Value(lhs_bool_maybe.release_value() && rhs_bool_maybe.release_value());
     }
     case BinaryOperator::Or: {
         auto lhs_bool_maybe = lhs_value.to_bool();
         auto rhs_bool_maybe = rhs_value.to_bool();
-        if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) {
-            context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
-            return Value::null();
-        }
+        if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value())
+            return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
+
         return Value(lhs_bool_maybe.release_value() || rhs_bool_maybe.release_value());
     }
     default:
@@ -130,17 +117,15 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
     }
 }
 
-Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
+ResultOr<Value> UnaryOperatorExpression::evaluate(ExecutionContext& context) const
 {
-    if (context.result->is_error())
-        return Value::null();
-    Value expression_value = NestedExpression::evaluate(context);
+    Value expression_value = TRY(NestedExpression::evaluate(context));
+
     switch (type()) {
     case UnaryOperator::Plus:
         if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float)
             return expression_value;
-        context.result = Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) };
-        return Value::null();
+        return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) };
     case UnaryOperator::Minus:
         if (expression_value.type() == SQLType::Integer) {
             expression_value = -int(expression_value);
@@ -150,32 +135,29 @@ Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
             expression_value = -double(expression_value);
             return expression_value;
         }
-        context.result = Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) };
-        return Value::null();
+        return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) };
     case UnaryOperator::Not:
         if (expression_value.type() == SQLType::Boolean) {
             expression_value = !bool(expression_value);
             return expression_value;
         }
-        context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) };
-        return Value::null();
+        return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) };
     case UnaryOperator::BitwiseNot:
         if (expression_value.type() == SQLType::Integer) {
             expression_value = ~u32(expression_value);
             return expression_value;
         }
-        context.result = Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) };
-        return Value::null();
+        return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) };
+    default:
+        VERIFY_NOT_REACHED();
     }
-    VERIFY_NOT_REACHED();
 }
 
-Value ColumnNameExpression::evaluate(ExecutionContext& context) const
+ResultOr<Value> ColumnNameExpression::evaluate(ExecutionContext& context) const
 {
-    if (!context.current_row) {
-        context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() };
-        return Value::null();
-    }
+    if (!context.current_row)
+        return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() };
+
     auto& descriptor = *context.current_row->descriptor();
     VERIFY(context.current_row->size() == descriptor.size());
     Optional<size_t> index_in_row;
@@ -184,34 +166,30 @@ Value ColumnNameExpression::evaluate(ExecutionContext& context) const
         if (!table_name().is_empty() && column_descriptor.table != table_name())
             continue;
         if (column_descriptor.name == column_name()) {
-            if (index_in_row.has_value()) {
-                context.result = Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() };
-                return Value::null();
-            }
+            if (index_in_row.has_value())
+                return Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() };
+
             index_in_row = ix;
         }
     }
     if (index_in_row.has_value())
         return (*context.current_row)[index_in_row.value()];
-    context.result = Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() };
-    return Value::null();
+
+    return Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() };
 }
 
-Value MatchExpression::evaluate(ExecutionContext& context) const
+ResultOr<Value> MatchExpression::evaluate(ExecutionContext& context) const
 {
-    if (context.result->is_error())
-        return Value::null();
     switch (type()) {
     case MatchOperator::Like: {
-        Value lhs_value = lhs()->evaluate(context);
-        Value rhs_value = rhs()->evaluate(context);
+        Value lhs_value = TRY(lhs()->evaluate(context));
+        Value rhs_value = TRY(rhs()->evaluate(context));
+
         char escape_char = '\0';
         if (escape()) {
-            auto escape_str = escape()->evaluate(context).to_string();
-            if (escape_str.length() != 1) {
-                context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" };
-                return Value::null();
-            }
+            auto escape_str = TRY(escape()->evaluate(context)).to_string();
+            if (escape_str.length() != 1)
+                return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" };
             escape_char = escape_str[0];
         }
 
@@ -237,14 +215,15 @@ Value MatchExpression::evaluate(ExecutionContext& context) const
             }
         }
         builder.append('$');
+
         // FIXME: We should probably cache this regex.
         auto regex = Regex<PosixBasic>(builder.build());
         auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode);
         return Value(invert_expression() ? !result.success : result.success);
     }
     case MatchOperator::Regexp: {
-        Value lhs_value = lhs()->evaluate(context);
-        Value rhs_value = rhs()->evaluate(context);
+        Value lhs_value = TRY(lhs()->evaluate(context));
+        Value rhs_value = TRY(rhs()->evaluate(context));
 
         auto regex = Regex<PosixExtended>(rhs_value.to_string());
         auto err = regex.parser_result.error;
@@ -253,8 +232,7 @@ Value MatchExpression::evaluate(ExecutionContext& context) const
             builder.append("Regular expression: ");
             builder.append(get_error_string(err));
 
-            context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() };
-            return Value(false);
+            return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() };
         }
 
         auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode);

+ 1 - 4
Userland/Libraries/LibSQL/AST/Insert.cpp

@@ -47,10 +47,7 @@ ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
                 row[column_def.name()] = column_def.default_value();
         }
 
-        auto row_value = row_expr.evaluate(context);
-        if (context.result->is_error())
-            return context.result.release_value();
-
+        auto row_value = TRY(row_expr.evaluate(context));
         VERIFY(row_value.type() == SQLType::Tuple);
         auto values = row_value.to_vector().value();
 

+ 5 - 11
Userland/Libraries/LibSQL/AST/Select.cpp

@@ -93,9 +93,7 @@ ResultOr<ResultSet> Select::execute(ExecutionContext& context) const
         context.current_row = &row;
 
         if (where_clause()) {
-            auto where_result = where_clause()->evaluate(context);
-            if (context.result->is_error())
-                return context.result.release_value();
+            auto where_result = TRY(where_clause()->evaluate(context));
             if (!where_result)
                 continue;
         }
@@ -103,18 +101,14 @@ ResultOr<ResultSet> Select::execute(ExecutionContext& context) const
         tuple.clear();
 
         for (auto& col : columns) {
-            auto value = col.expression()->evaluate(context);
-            if (context.result->is_error())
-                return context.result.release_value();
+            auto value = TRY(col.expression()->evaluate(context));
             tuple.append(value);
         }
 
         if (has_ordering) {
             sort_key.clear();
             for (auto& term : m_ordering_term_list) {
-                auto value = term.expression()->evaluate(context);
-                if (context.result->is_error())
-                    return context.result.release_value();
+                auto value = TRY(term.expression()->evaluate(context));
                 sort_key.append(value);
             }
         }
@@ -126,7 +120,7 @@ ResultOr<ResultSet> Select::execute(ExecutionContext& context) const
         size_t limit_value = NumericLimits<size_t>::max();
         size_t offset_value = 0;
 
-        auto limit = m_limit_clause->limit_expression()->evaluate(context);
+        auto limit = TRY(m_limit_clause->limit_expression()->evaluate(context));
         if (!limit.is_null()) {
             auto limit_value_maybe = limit.to_u32();
             if (!limit_value_maybe.has_value())
@@ -136,7 +130,7 @@ ResultOr<ResultSet> Select::execute(ExecutionContext& context) const
         }
 
         if (m_limit_clause->offset_expression() != nullptr) {
-            auto offset = m_limit_clause->offset_expression()->evaluate(context);
+            auto offset = TRY(m_limit_clause->offset_expression()->evaluate(context));
             if (!offset.is_null()) {
                 auto offset_value_maybe = offset.to_u32();
                 if (!offset_value_maybe.has_value())