Parcourir la source

LibSQL: Add better error handling to `evaluate` and `execute` methods

There was a lot of `VERIFY_NOT_REACHED` error handling going on. Fixed
most of those.

A bit of a caveat is that after every `evaluate` call for expressions
that are part of a statement the error status of the `SQLResult` return
value must be called.
Jan de Visser il y a 3 ans
Parent
commit
9022cf99ff

+ 61 - 23
Userland/Libraries/LibSQL/AST/AST.h

@@ -437,13 +437,32 @@ private:
     String m_column_name;
 };
 
+#define __enum_UnaryOperator(S) \
+    S(Minus, "-")               \
+    S(Plus, "+")                \
+    S(BitwiseNot, "~")          \
+    S(Not, "NOT")
+
 enum class UnaryOperator {
-    Minus,
-    Plus,
-    BitwiseNot,
-    Not,
+#undef __UnaryOperator
+#define __UnaryOperator(code, name) code,
+    __enum_UnaryOperator(__UnaryOperator)
+#undef __UnaryOperator
 };
 
+constexpr char const* UnaryOperator_name(UnaryOperator op)
+{
+    switch (op) {
+#undef __UnaryOperator
+#define __UnaryOperator(code, name) \
+    case UnaryOperator::code:       \
+        return name;
+        __enum_UnaryOperator(__UnaryOperator)
+#undef __UnaryOperator
+            default : VERIFY_NOT_REACHED();
+    }
+}
+
 class UnaryOperatorExpression : public NestedExpression {
 public:
     UnaryOperatorExpression(UnaryOperator type, NonnullRefPtr<Expression> expression)
@@ -459,28 +478,47 @@ private:
     UnaryOperator m_type;
 };
 
+// Note: These are in order of highest-to-lowest operator precedence.
+#define __enum_BinaryOperator(S) \
+    S(Concatenate, "||")         \
+    S(Multiplication, "*")       \
+    S(Division, "/")             \
+    S(Modulo, "%")               \
+    S(Plus, "+")                 \
+    S(Minus, "-")                \
+    S(ShiftLeft, "<<")           \
+    S(ShiftRight, ">>")          \
+    S(BitwiseAnd, "&")           \
+    S(BitwiseOr, "|")            \
+    S(LessThan, "<")             \
+    S(LessThanEquals, "<=")      \
+    S(GreaterThan, ">")          \
+    S(GreaterThanEquals, ">=")   \
+    S(Equals, "=")               \
+    S(NotEquals, "!=")           \
+    S(And, "and")                \
+    S(Or, "or")
+
 enum class BinaryOperator {
-    // Note: These are in order of highest-to-lowest operator precedence.
-    Concatenate,
-    Multiplication,
-    Division,
-    Modulo,
-    Plus,
-    Minus,
-    ShiftLeft,
-    ShiftRight,
-    BitwiseAnd,
-    BitwiseOr,
-    LessThan,
-    LessThanEquals,
-    GreaterThan,
-    GreaterThanEquals,
-    Equals,
-    NotEquals,
-    And,
-    Or,
+#undef __BinaryOperator
+#define __BinaryOperator(code, name) code,
+    __enum_BinaryOperator(__BinaryOperator)
+#undef __BinaryOperator
 };
 
+constexpr char const* BinaryOperator_name(BinaryOperator op)
+{
+    switch (op) {
+#undef __BinaryOperator
+#define __BinaryOperator(code, name) \
+    case BinaryOperator::code:       \
+        return name;
+        __enum_BinaryOperator(__BinaryOperator)
+#undef __BinaryOperator
+            default : VERIFY_NOT_REACHED();
+    }
+}
+
 class BinaryOperatorExpression : public NestedDoubleExpression {
 public:
     BinaryOperatorExpression(BinaryOperator type, NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs)

+ 30 - 17
Userland/Libraries/LibSQL/AST/Expression.cpp

@@ -14,15 +14,19 @@ Value Expression::evaluate(ExecutionContext&) const
     return Value::null();
 }
 
-Value NumericLiteral::evaluate(ExecutionContext&) const
+Value NumericLiteral::evaluate(ExecutionContext& context) const
 {
+    if (context.result->has_error())
+        return Value::null();
     Value ret(SQLType::Float);
     ret = value();
     return ret;
 }
 
-Value StringLiteral::evaluate(ExecutionContext&) const
+Value StringLiteral::evaluate(ExecutionContext& context) const
 {
+    if (context.result->has_error())
+        return Value::null();
     Value ret(SQLType::Text);
     ret = value();
     return ret;
@@ -35,11 +39,15 @@ Value NullLiteral::evaluate(ExecutionContext&) const
 
 Value NestedExpression::evaluate(ExecutionContext& context) const
 {
+    if (context.result->has_error())
+        return Value::null();
     return expression()->evaluate(context);
 }
 
 Value ChainedExpression::evaluate(ExecutionContext& context) const
 {
+    if (context.result->has_error())
+        return Value::null();
     Value ret(SQLType::Tuple);
     Vector<Value> values;
     for (auto& expression : expressions()) {
@@ -51,6 +59,8 @@ Value ChainedExpression::evaluate(ExecutionContext& context) const
 
 Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
 {
+    if (context.result->has_error())
+        return Value::null();
     Value lhs_value = lhs()->evaluate(context);
     Value rhs_value = rhs()->evaluate(context);
     switch (type()) {
@@ -97,8 +107,8 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
         auto lhs_bool_maybe = lhs_value.to_bool();
         auto rhs_bool_maybe = rhs_value.to_bool();
         if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) {
-            // TODO Error handling
-            VERIFY_NOT_REACHED();
+            context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()));
+            return Value::null();
         }
         return Value(lhs_bool_maybe.value() && rhs_bool_maybe.value());
     }
@@ -106,24 +116,27 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
         auto lhs_bool_maybe = lhs_value.to_bool();
         auto rhs_bool_maybe = rhs_value.to_bool();
         if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) {
-            // TODO Error handling
-            VERIFY_NOT_REACHED();
+            context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()));
+            return Value::null();
         }
         return Value(lhs_bool_maybe.value() || rhs_bool_maybe.value());
     }
+    default:
+        VERIFY_NOT_REACHED();
     }
-    VERIFY_NOT_REACHED();
 }
 
 Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
 {
+    if (context.result->has_error())
+        return Value::null();
     Value expression_value = NestedExpression::evaluate(context);
     switch (type()) {
     case UnaryOperator::Plus:
         if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float)
             return expression_value;
-        // TODO: Error handling.
-        VERIFY_NOT_REACHED();
+        context.result->set_error(SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()));
+        return Value::null();
     case UnaryOperator::Minus:
         if (expression_value.type() == SQLType::Integer) {
             expression_value = -int(expression_value);
@@ -133,22 +146,22 @@ Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
             expression_value = -double(expression_value);
             return expression_value;
         }
-        // TODO: Error handling.
-        VERIFY_NOT_REACHED();
+        context.result->set_error(SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()));
+        return Value::null();
     case UnaryOperator::Not:
         if (expression_value.type() == SQLType::Boolean) {
             expression_value = !bool(expression_value);
             return expression_value;
         }
-        // TODO: Error handling.
-        VERIFY_NOT_REACHED();
+        context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()));
+        return Value::null();
     case UnaryOperator::BitwiseNot:
         if (expression_value.type() == SQLType::Integer) {
             expression_value = ~u32(expression_value);
             return expression_value;
         }
-        // TODO: Error handling.
-        VERIFY_NOT_REACHED();
+        context.result->set_error(SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()));
+        return Value::null();
     }
     VERIFY_NOT_REACHED();
 }
@@ -162,8 +175,8 @@ Value ColumnNameExpression::evaluate(ExecutionContext& context) const
         if (column_descriptor.name == column_name())
             return { (*context.current_row)[ix] };
     }
-    // TODO: Error handling.
-    VERIFY_NOT_REACHED();
+    context.result->set_error(SQLErrorCode::ColumnDoesNotExist, column_name());
+    return Value::null();
 }
 
 }

+ 4 - 0
Userland/Libraries/LibSQL/AST/Insert.cpp

@@ -44,6 +44,7 @@ RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const
 
     Vector<Row> inserted_rows;
     inserted_rows.ensure_capacity(m_chained_expressions.size());
+    context.result = SQLResult::construct();
     for (auto& row_expr : m_chained_expressions) {
         for (auto& column_def : table_def->columns()) {
             if (!m_column_names.contains_slow(column_def.name())) {
@@ -51,6 +52,8 @@ RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const
             }
         }
         auto row_value = row_expr.evaluate(context);
+        if (context.result->has_error())
+            return context.result;
         VERIFY(row_value.type() == SQLType::Tuple);
         auto values = row_value.to_vector().value();
 
@@ -76,6 +79,7 @@ RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const
 
     for (auto& inserted_row : inserted_rows) {
         context.database->insert(inserted_row);
+        // FIXME Error handling
     }
 
     return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0);

+ 4 - 0
Userland/Libraries/LibSQL/AST/Select.cpp

@@ -39,12 +39,16 @@ RefPtr<SQLResult> Select::execute(ExecutionContext& context) const
             context.current_row = &row;
             if (where_clause()) {
                 auto where_result = where_clause()->evaluate(context);
+                if (context.result->has_error())
+                    return context.result;
                 if (!where_result)
                     continue;
             }
             tuple.clear();
             for (auto& col : columns) {
                 auto value = col.expression()->evaluate(context);
+                if (context.result->has_error())
+                    return context.result;
                 tuple.append(value);
             }
             context.result->append(tuple);

+ 26 - 15
Userland/Libraries/LibSQL/SQLResult.h

@@ -41,21 +41,25 @@ constexpr char const* command_tag(SQLCommand command)
     }
 }
 
-#define ENUMERATE_SQL_ERRORS(S)                                   \
-    S(NoError, "No error")                                        \
-    S(DatabaseUnavailable, "Database Unavailable")                \
-    S(StatementUnavailable, "Statement with id '{}' Unavailable") \
-    S(SyntaxError, "Syntax Error")                                \
-    S(DatabaseDoesNotExist, "Database '{}' does not exist")       \
-    S(SchemaDoesNotExist, "Schema '{}' does not exist")           \
-    S(SchemaExists, "Schema '{}' already exist")                  \
-    S(TableDoesNotExist, "Table '{}' does not exist")             \
-    S(ColumnDoesNotExist, "Column '{}' does not exist")           \
-    S(TableExists, "Table '{}' already exist")                    \
-    S(InvalidType, "Invalid type '{}'")                           \
-    S(InvalidDatabaseName, "Invalid database name '{}'")          \
-    S(InvalidValueType, "Invalid type for attribute '{}'")        \
-    S(InvalidNumberOfValues, "Number of values does not match number of columns")
+#define ENUMERATE_SQL_ERRORS(S)                                                          \
+    S(NoError, "No error")                                                               \
+    S(DatabaseUnavailable, "Database Unavailable")                                       \
+    S(StatementUnavailable, "Statement with id '{}' Unavailable")                        \
+    S(SyntaxError, "Syntax Error")                                                       \
+    S(DatabaseDoesNotExist, "Database '{}' does not exist")                              \
+    S(SchemaDoesNotExist, "Schema '{}' does not exist")                                  \
+    S(SchemaExists, "Schema '{}' already exist")                                         \
+    S(TableDoesNotExist, "Table '{}' does not exist")                                    \
+    S(ColumnDoesNotExist, "Column '{}' does not exist")                                  \
+    S(TableExists, "Table '{}' already exist")                                           \
+    S(InvalidType, "Invalid type '{}'")                                                  \
+    S(InvalidDatabaseName, "Invalid database name '{}'")                                 \
+    S(InvalidValueType, "Invalid type for attribute '{}'")                               \
+    S(InvalidNumberOfValues, "Number of values does not match number of columns")        \
+    S(BooleanOperatorTypeMismatch, "Cannot apply '{}' operator to non-boolean operands") \
+    S(NumericOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \
+    S(IntegerOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \
+    S(InvalidOperator, "Invalid operator '{}'")
 
 enum class SQLErrorCode {
 #undef __ENUMERATE_SQL_ERROR
@@ -113,6 +117,13 @@ public:
     int updated() const { return m_update_count; }
     int inserted() const { return m_insert_count; }
     int deleted() const { return m_delete_count; }
+    void set_error(SQLErrorCode code, String argument = {})
+    {
+        m_error.code = code;
+        m_error.error_argument = argument;
+    }
+
+    bool has_error() const { return m_error.code != SQLErrorCode::NoError; }
     SQLError const& error() const { return m_error; }
     bool has_results() const { return m_has_results; }
     Vector<Tuple> const& results() const { return m_result_set; }