فهرست منبع

LibSQL: Parse and execute sequential placeholder values

This partially implements SQLite's bind-parameter expression to support
indicating placeholder values in a SQL statement. For example:

    INSERT INTO table VALUES (42, ?);

In the above statement, the '?' identifier is a placeholder. This will
allow clients to compile statements a single time while running those
statements any number of times with different placeholder values.

Further, this will help mitigate SQL injection attacks.
Timothy Flynn 2 سال پیش
والد
کامیت
b2b9ae27fd

+ 13 - 0
Tests/LibSQL/TestSqlExpressionParser.cpp

@@ -131,6 +131,19 @@ TEST_CASE(null_literal)
     validate("NULL"sv);
 }
 
+TEST_CASE(bind_parameter)
+{
+    auto validate = [](StringView sql) {
+        auto result = parse(sql);
+        EXPECT(!result.is_error());
+
+        auto expression = result.release_value();
+        EXPECT(is<SQL::AST::Placeholder>(*expression));
+    };
+
+    validate("?"sv);
+}
+
 TEST_CASE(column_name)
 {
     EXPECT(parse(".column_name"sv).is_error());

+ 63 - 4
Tests/LibSQL/TestSqlStatementExecution.cpp

@@ -21,19 +21,19 @@ namespace {
 
 constexpr char const* db_name = "/tmp/test.db";
 
-SQL::ResultOr<SQL::ResultSet> try_execute(NonnullRefPtr<SQL::Database> database, DeprecatedString const& sql)
+SQL::ResultOr<SQL::ResultSet> try_execute(NonnullRefPtr<SQL::Database> database, DeprecatedString const& sql, Vector<SQL::Value> placeholder_values = {})
 {
     auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql));
     auto statement = parser.next_statement();
     EXPECT(!parser.has_errors());
     if (parser.has_errors())
         outln("{}", parser.errors()[0].to_deprecated_string());
-    return statement->execute(move(database));
+    return statement->execute(move(database), placeholder_values);
 }
 
-SQL::ResultSet execute(NonnullRefPtr<SQL::Database> database, DeprecatedString const& sql)
+SQL::ResultSet execute(NonnullRefPtr<SQL::Database> database, DeprecatedString const& sql, Vector<SQL::Value> placeholder_values = {})
 {
-    auto result = try_execute(move(database), sql);
+    auto result = try_execute(move(database), sql, move(placeholder_values));
     if (result.is_error()) {
         outln("{}", result.release_error().error_string());
         VERIFY_NOT_REACHED();
@@ -41,6 +41,12 @@ SQL::ResultSet execute(NonnullRefPtr<SQL::Database> database, DeprecatedString c
     return result.release_value();
 }
 
+template<typename... Args>
+Vector<SQL::Value> placeholders(Args&&... args)
+{
+    return { SQL::Value(forward<Args>(args))... };
+}
+
 void create_schema(NonnullRefPtr<SQL::Database> database)
 {
     auto result = execute(database, "CREATE SCHEMA TestSchema;");
@@ -175,6 +181,59 @@ TEST_CASE(insert_without_column_names)
     EXPECT_EQ(rows_or_error.value().size(), 2u);
 }
 
+TEST_CASE(insert_with_placeholders)
+{
+    ScopeGuard guard([]() { unlink(db_name); });
+
+    auto database = SQL::Database::construct(db_name);
+    EXPECT(!database->open().is_error());
+    create_table(database);
+
+    {
+        auto result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);");
+        EXPECT(result.is_error());
+        EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidNumberOfPlaceholderValues);
+
+        result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders("Test_1"sv));
+        EXPECT(result.is_error());
+        EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidNumberOfPlaceholderValues);
+
+        result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders(42, 42));
+        EXPECT(result.is_error());
+        EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidValueType);
+
+        result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders("Test_1"sv, "Test_2"sv));
+        EXPECT(result.is_error());
+        EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidValueType);
+    }
+    {
+        auto result = execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders("Test_1"sv, 42));
+        EXPECT_EQ(result.size(), 1u);
+
+        result = execute(database, "SELECT TextColumn, IntColumn FROM TestSchema.TestTable ORDER BY TextColumn;");
+        EXPECT_EQ(result.size(), 1u);
+
+        EXPECT_EQ(result[0].row[0], "Test_1"sv);
+        EXPECT_EQ(result[0].row[1], 42);
+    }
+    {
+        auto result = execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?), (?, ?);", placeholders("Test_2"sv, 43, "Test_3"sv, 44));
+        EXPECT_EQ(result.size(), 2u);
+
+        result = execute(database, "SELECT TextColumn, IntColumn FROM TestSchema.TestTable ORDER BY TextColumn;");
+        EXPECT_EQ(result.size(), 3u);
+
+        EXPECT_EQ(result[0].row[0], "Test_1"sv);
+        EXPECT_EQ(result[0].row[1], 42);
+
+        EXPECT_EQ(result[1].row[0], "Test_2"sv);
+        EXPECT_EQ(result[1].row[1], 43);
+
+        EXPECT_EQ(result[2].row[0], "Test_3"sv);
+        EXPECT_EQ(result[2].row[1], 44);
+    }
+}
+
 TEST_CASE(select_from_empty_table)
 {
     ScopeGuard guard([]() { unlink(db_name); });

+ 7 - 0
Tests/LibSQL/TestSqlStatementParser.cpp

@@ -752,6 +752,13 @@ TEST_CASE(nested_subquery_limit)
     EXPECT(parse(DeprecatedString::formatted("SELECT * FROM ({});"sv, subquery)).is_error());
 }
 
+TEST_CASE(bound_parameter_limit)
+{
+    auto subquery = DeprecatedString::repeated("?, "sv, SQL::AST::Limits::maximum_bound_parameters);
+    EXPECT(!parse(DeprecatedString::formatted("INSERT INTO table_name VALUES ({}42);"sv, subquery)).is_error());
+    EXPECT(parse(DeprecatedString::formatted("INSERT INTO table_name VALUES ({}?);"sv, subquery)).is_error());
+}
+
 TEST_CASE(describe_table)
 {
     EXPECT(parse("DESCRIBE"sv).is_error());

+ 18 - 2
Userland/Libraries/LibSQL/AST/AST.h

@@ -300,7 +300,8 @@ private:
 
 struct ExecutionContext {
     NonnullRefPtr<Database> database;
-    class Statement const* statement;
+    Statement const* statement { nullptr };
+    Span<Value const> placeholder_values {};
     Tuple* current_row { nullptr };
 };
 
@@ -361,6 +362,21 @@ public:
     virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
 };
 
+class Placeholder : public Expression {
+public:
+    explicit Placeholder(size_t parameter_index)
+        : m_parameter_index(parameter_index)
+    {
+    }
+
+    size_t parameter_index() const { return m_parameter_index; }
+
+    virtual ResultOr<Value> evaluate(ExecutionContext&) const override;
+
+private:
+    size_t m_parameter_index { 0 };
+};
+
 class NestedExpression : public Expression {
 public:
     NonnullRefPtr<Expression> const& expression() const { return m_expression; }
@@ -729,7 +745,7 @@ private:
 
 class Statement : public ASTNode {
 public:
-    ResultOr<ResultSet> execute(AK::NonnullRefPtr<Database> database) const;
+    ResultOr<ResultSet> execute(AK::NonnullRefPtr<Database> database, Span<Value const> placeholder_values = {}) const;
 
     virtual ResultOr<ResultSet> execute(ExecutionContext&) const
     {

+ 7 - 0
Userland/Libraries/LibSQL/AST/Expression.cpp

@@ -29,6 +29,13 @@ ResultOr<Value> NullLiteral::evaluate(ExecutionContext&) const
     return Value {};
 }
 
+ResultOr<Value> Placeholder::evaluate(ExecutionContext& context) const
+{
+    if (parameter_index() >= context.placeholder_values.size())
+        return Result { SQLCommand::Unknown, SQLErrorCode::InvalidNumberOfPlaceholderValues };
+    return context.placeholder_values[parameter_index()];
+}
+
 ResultOr<Value> NestedExpression::evaluate(ExecutionContext& context) const
 {
     return expression()->evaluate(context);

+ 18 - 1
Userland/Libraries/LibSQL/AST/Parser.cpp

@@ -401,7 +401,6 @@ NonnullRefPtr<Expression> Parser::parse_expression()
     if (match_secondary_expression())
         expression = parse_secondary_expression(move(expression));
 
-    // FIXME: Parse 'bind-parameter'.
     // FIXME: Parse 'function-name'.
     // FIXME: Parse 'raise-function'.
 
@@ -414,6 +413,9 @@ NonnullRefPtr<Expression> Parser::parse_primary_expression()
     if (auto expression = parse_literal_value_expression())
         return expression.release_nonnull();
 
+    if (auto expression = parse_bind_parameter_expression())
+        return expression.release_nonnull();
+
     if (auto expression = parse_column_name_expression())
         return expression.release_nonnull();
 
@@ -528,6 +530,21 @@ RefPtr<Expression> Parser::parse_literal_value_expression()
     return {};
 }
 
+// https://sqlite.org/lang_expr.html#varparam
+RefPtr<Expression> Parser::parse_bind_parameter_expression()
+{
+    // FIXME: Support ?NNN, :AAAA, @AAAA, and $AAAA forms.
+    if (consume_if(TokenType::Placeholder)) {
+        auto parameter = m_parser_state.m_bound_parameters;
+        if (++m_parser_state.m_bound_parameters > Limits::maximum_bound_parameters)
+            syntax_error(DeprecatedString::formatted("Exceeded maximum number of bound parameters {}", Limits::maximum_bound_parameters));
+
+        return create_ast_node<Placeholder>(parameter);
+    }
+
+    return {};
+}
+
 RefPtr<Expression> Parser::parse_column_name_expression(DeprecatedString with_parsed_identifier, bool with_parsed_period)
 {
     if (with_parsed_identifier.is_null() && !match(TokenType::Identifier))

+ 3 - 0
Userland/Libraries/LibSQL/AST/Parser.h

@@ -19,6 +19,7 @@ namespace Limits {
 // https://www.sqlite.org/limits.html
 constexpr size_t maximum_expression_tree_depth = 1000;
 constexpr size_t maximum_subquery_depth = 100;
+constexpr size_t maximum_bound_parameters = 1000;
 }
 
 class Parser {
@@ -52,6 +53,7 @@ private:
         Vector<Error> m_errors;
         size_t m_current_expression_depth { 0 };
         size_t m_current_subquery_depth { 0 };
+        size_t m_bound_parameters { 0 };
     };
 
     NonnullRefPtr<Statement> parse_statement();
@@ -71,6 +73,7 @@ private:
     NonnullRefPtr<Expression> parse_secondary_expression(NonnullRefPtr<Expression> primary);
     bool match_secondary_expression() const;
     RefPtr<Expression> parse_literal_value_expression();
+    RefPtr<Expression> parse_bind_parameter_expression();
     RefPtr<Expression> parse_column_name_expression(DeprecatedString with_parsed_identifier = {}, bool with_parsed_period = false);
     RefPtr<Expression> parse_unary_operator_expression();
     RefPtr<Expression> parse_binary_operator_expression(NonnullRefPtr<Expression> lhs);

+ 2 - 2
Userland/Libraries/LibSQL/AST/Statement.cpp

@@ -11,9 +11,9 @@
 
 namespace SQL::AST {
 
-ResultOr<ResultSet> Statement::execute(AK::NonnullRefPtr<Database> database) const
+ResultOr<ResultSet> Statement::execute(AK::NonnullRefPtr<Database> database, Span<Value const> placeholder_values) const
 {
-    ExecutionContext context { move(database), this, nullptr };
+    ExecutionContext context { move(database), this, placeholder_values, nullptr };
     auto result = TRY(execute(context));
 
     // FIXME: When transactional sessions are supported, don't auto-commit modifications.

+ 1 - 0
Userland/Libraries/LibSQL/AST/Token.h

@@ -171,6 +171,7 @@ namespace SQL::AST {
     __ENUMERATE_SQL_TOKEN("_blob_", BlobLiteral, Blob)                    \
     __ENUMERATE_SQL_TOKEN("_eof_", Eof, Invalid)                          \
     __ENUMERATE_SQL_TOKEN("_invalid_", Invalid, Invalid)                  \
+    __ENUMERATE_SQL_TOKEN("?", Placeholder, Operator)                     \
     __ENUMERATE_SQL_TOKEN("&", Ampersand, Operator)                       \
     __ENUMERATE_SQL_TOKEN("*", Asterisk, Operator)                        \
     __ENUMERATE_SQL_TOKEN(",", Comma, Punctuation)                        \

+ 22 - 21
Userland/Libraries/LibSQL/Result.h

@@ -41,27 +41,28 @@ constexpr char const* command_tag(SQLCommand command)
     }
 }
 
-#define ENUMERATE_SQL_ERRORS(S)                                                          \
-    S(NoError, "No error")                                                               \
-    S(InternalError, "{}")                                                               \
-    S(NotYetImplemented, "{}")                                                           \
-    S(DatabaseUnavailable, "Database Unavailable")                                       \
-    S(StatementUnavailable, "Statement with id '{}' Unavailable")                        \
-    S(SyntaxError, "Syntax Error")                                                       \
-    S(DatabaseDoesNotExist, "Database '{}' does not exist")                              \
-    S(SchemaDoesNotExist, "Schema '{}' does not exist")                                  \
-    S(SchemaExists, "Schema '{}' already exist")                                         \
-    S(TableDoesNotExist, "Table '{}' does not exist")                                    \
-    S(ColumnDoesNotExist, "Column '{}' does not exist")                                  \
-    S(AmbiguousColumnName, "Column name '{}' is ambiguous")                              \
-    S(TableExists, "Table '{}' already exist")                                           \
-    S(InvalidType, "Invalid type '{}'")                                                  \
-    S(InvalidDatabaseName, "Invalid database name '{}'")                                 \
-    S(InvalidValueType, "Invalid type for attribute '{}'")                               \
-    S(InvalidNumberOfValues, "Number of values does not match number of columns")        \
-    S(BooleanOperatorTypeMismatch, "Cannot apply '{}' operator to non-boolean operands") \
-    S(NumericOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \
-    S(IntegerOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \
+#define ENUMERATE_SQL_ERRORS(S)                                                                   \
+    S(NoError, "No error")                                                                        \
+    S(InternalError, "{}")                                                                        \
+    S(NotYetImplemented, "{}")                                                                    \
+    S(DatabaseUnavailable, "Database Unavailable")                                                \
+    S(StatementUnavailable, "Statement with id '{}' Unavailable")                                 \
+    S(SyntaxError, "Syntax Error")                                                                \
+    S(DatabaseDoesNotExist, "Database '{}' does not exist")                                       \
+    S(SchemaDoesNotExist, "Schema '{}' does not exist")                                           \
+    S(SchemaExists, "Schema '{}' already exist")                                                  \
+    S(TableDoesNotExist, "Table '{}' does not exist")                                             \
+    S(ColumnDoesNotExist, "Column '{}' does not exist")                                           \
+    S(AmbiguousColumnName, "Column name '{}' is ambiguous")                                       \
+    S(TableExists, "Table '{}' already exist")                                                    \
+    S(InvalidType, "Invalid type '{}'")                                                           \
+    S(InvalidDatabaseName, "Invalid database name '{}'")                                          \
+    S(InvalidValueType, "Invalid type for attribute '{}'")                                        \
+    S(InvalidNumberOfPlaceholderValues, "Number of values does not match number of placeholders") \
+    S(InvalidNumberOfValues, "Number of values does not match number of columns")                 \
+    S(BooleanOperatorTypeMismatch, "Cannot apply '{}' operator to non-boolean operands")          \
+    S(NumericOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands")          \
+    S(IntegerOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands")          \
     S(InvalidOperator, "Invalid operator '{}'")
 
 enum class SQLErrorCode {