Browse Source

LibSQL: Handle statements with malformed exists expressions correctly

Previously, statements containing malformed exists expressions such as:

`INSERT INTO t(a) VALUES (SELECT 1)`;

could cause the parser to crash. The parser will now return an error
message instead.
Tim Ledbetter 2 years ago
parent
commit
896d1e4f42

+ 40 - 0
Tests/LibSQL/TestSqlStatementParser.cpp

@@ -307,6 +307,14 @@ TEST_CASE(insert)
     EXPECT(parse("INSERT INTO table_name VALUES"sv).is_error());
     EXPECT(parse("INSERT INTO table_name VALUES"sv).is_error());
     EXPECT(parse("INSERT INTO table_name VALUES ();"sv).is_error());
     EXPECT(parse("INSERT INTO table_name VALUES ();"sv).is_error());
     EXPECT(parse("INSERT INTO table_name VALUES (1)"sv).is_error());
     EXPECT(parse("INSERT INTO table_name VALUES (1)"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES SELECT"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES EXISTS"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES NOT"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES EXISTS (SELECT 1)"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES (SELECT)"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES (EXISTS SELECT)"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES ((SELECT))"sv).is_error());
+    EXPECT(parse("INSERT INTO table_name VALUES (EXISTS (SELECT))"sv).is_error());
     EXPECT(parse("INSERT INTO table_name SELECT"sv).is_error());
     EXPECT(parse("INSERT INTO table_name SELECT"sv).is_error());
     EXPECT(parse("INSERT INTO table_name SELECT * from table_name"sv).is_error());
     EXPECT(parse("INSERT INTO table_name SELECT * from table_name"sv).is_error());
     EXPECT(parse("INSERT OR INTO table_name DEFAULT VALUES;"sv).is_error());
     EXPECT(parse("INSERT OR INTO table_name DEFAULT VALUES;"sv).is_error());
@@ -367,6 +375,12 @@ TEST_CASE(insert)
     validate("INSERT INTO table_name VALUES (1, 2);"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2 }, false);
     validate("INSERT INTO table_name VALUES (1, 2);"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2 }, false);
     validate("INSERT INTO table_name VALUES (1, 2), (3, 4, 5);"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2, 3 }, false);
     validate("INSERT INTO table_name VALUES (1, 2), (3, 4, 5);"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2, 3 }, false);
 
 
+    validate("INSERT INTO table_name VALUES ((SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 1 }, false);
+    validate("INSERT INTO table_name VALUES (EXISTS (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 1 }, false);
+    validate("INSERT INTO table_name VALUES (NOT EXISTS (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 1 }, false);
+    validate("INSERT INTO table_name VALUES ((SELECT 1), (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2 }, false);
+    validate("INSERT INTO table_name VALUES ((SELECT 1), (SELECT 1)), ((SELECT 1), (SELECT 1), (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2, 3 }, false);
+
     validate("INSERT INTO table_name SELECT * FROM table_name;"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, {}, true);
     validate("INSERT INTO table_name SELECT * FROM table_name;"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, {}, true);
 }
 }
 
 
@@ -379,11 +393,21 @@ TEST_CASE(update)
     EXPECT(parse("UPDATE table_name SET column_name=4"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4, ;"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4, ;"sv).is_error());
     EXPECT(parse("UPDATE table_name SET (column_name)=4"sv).is_error());
     EXPECT(parse("UPDATE table_name SET (column_name)=4"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET (column_name)=EXISTS"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET (column_name)=SELECT"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET (column_name)=(SELECT)"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET (column_name)=NOT (SELECT 1)"sv).is_error());
     EXPECT(parse("UPDATE table_name SET (column_name)=4, ;"sv).is_error());
     EXPECT(parse("UPDATE table_name SET (column_name)=4, ;"sv).is_error());
     EXPECT(parse("UPDATE table_name SET (column_name, )=4;"sv).is_error());
     EXPECT(parse("UPDATE table_name SET (column_name, )=4;"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 FROM"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 FROM"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 FROM table_name"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 FROM table_name"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 WHERE"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 WHERE"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET column_name=4 WHERE EXISTS"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET column_name=4 WHERE NOT"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET column_name=4 WHERE NOT EXISTS"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET column_name=4 WHERE SELECT"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET column_name=4 WHERE (SELECT)"sv).is_error());
+    EXPECT(parse("UPDATE table_name SET column_name=4 WHERE NOT (SELECT)"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 WHERE 1==1"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 WHERE 1==1"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 RETURNING"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 RETURNING"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 RETURNING *"sv).is_error());
     EXPECT(parse("UPDATE table_name SET column_name=4 RETURNING *"sv).is_error());
@@ -452,11 +476,18 @@ TEST_CASE(update)
     validate("UPDATE table_name AS foo SET column_name=1;"sv, resolution, {}, "TABLE_NAME"sv, "FOO"sv, update_columns, false, false, {});
     validate("UPDATE table_name AS foo SET column_name=1;"sv, resolution, {}, "TABLE_NAME"sv, "FOO"sv, update_columns, false, false, {});
 
 
     validate("UPDATE table_name SET column_name=1;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {});
     validate("UPDATE table_name SET column_name=1;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {});
+    validate("UPDATE table_name SET column_name=(SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {});
+    validate("UPDATE table_name SET column_name=EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {});
+    validate("UPDATE table_name SET column_name=NOT EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {});
     validate("UPDATE table_name SET column1=1, column2=2;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN1"sv }, { "COLUMN2"sv } }, false, false, {});
     validate("UPDATE table_name SET column1=1, column2=2;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN1"sv }, { "COLUMN2"sv } }, false, false, {});
     validate("UPDATE table_name SET (column1, column2)=1, column3=2;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN1"sv, "COLUMN2"sv }, { "COLUMN3"sv } }, false, false, {});
     validate("UPDATE table_name SET (column1, column2)=1, column3=2;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN1"sv, "COLUMN2"sv }, { "COLUMN3"sv } }, false, false, {});
 
 
     validate("UPDATE table_name SET column_name=1 WHERE 1==1;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, true, false, {});
     validate("UPDATE table_name SET column_name=1 WHERE 1==1;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, true, false, {});
 
 
+    validate("UPDATE table_name SET column_name=1 WHERE (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, true, false, {});
+    validate("UPDATE table_name SET column_name=1 WHERE EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, true, false, {});
+    validate("UPDATE table_name SET column_name=1 WHERE NOT EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, true, false, {});
+
     validate("UPDATE table_name SET column_name=1 RETURNING *;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, {});
     validate("UPDATE table_name SET column_name=1 RETURNING *;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, {});
     validate("UPDATE table_name SET column_name=1 RETURNING column_name;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, { {} });
     validate("UPDATE table_name SET column_name=1 RETURNING column_name;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, { {} });
     validate("UPDATE table_name SET column_name=1 RETURNING column_name AS alias;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, { "ALIAS"sv });
     validate("UPDATE table_name SET column_name=1 RETURNING column_name AS alias;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, { "ALIAS"sv });
@@ -469,6 +500,12 @@ TEST_CASE(delete_)
     EXPECT(parse("DELETE FROM"sv).is_error());
     EXPECT(parse("DELETE FROM"sv).is_error());
     EXPECT(parse("DELETE FROM table_name"sv).is_error());
     EXPECT(parse("DELETE FROM table_name"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE"sv).is_error());
+    EXPECT(parse("DELETE FROM table_name WHERE EXISTS"sv).is_error());
+    EXPECT(parse("DELETE FROM table_name WHERE NOT"sv).is_error());
+    EXPECT(parse("DELETE FROM table_name WHERE NOT (SELECT 1)"sv).is_error());
+    EXPECT(parse("DELETE FROM table_name WHERE NOT EXISTS"sv).is_error());
+    EXPECT(parse("DELETE FROM table_name WHERE SELECT"sv).is_error());
+    EXPECT(parse("DELETE FROM table_name WHERE (SELECT)"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE 15"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE 15"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE 15 RETURNING"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE 15 RETURNING"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE 15 RETURNING *"sv).is_error());
     EXPECT(parse("DELETE FROM table_name WHERE 15 RETURNING *"sv).is_error());
@@ -514,6 +551,9 @@ TEST_CASE(delete_)
     validate("DELETE FROM schema_name.table_name;"sv, "SCHEMA_NAME"sv, "TABLE_NAME"sv, {}, false, false, {});
     validate("DELETE FROM schema_name.table_name;"sv, "SCHEMA_NAME"sv, "TABLE_NAME"sv, {}, false, false, {});
     validate("DELETE FROM schema_name.table_name AS alias;"sv, "SCHEMA_NAME"sv, "TABLE_NAME"sv, "ALIAS"sv, false, false, {});
     validate("DELETE FROM schema_name.table_name AS alias;"sv, "SCHEMA_NAME"sv, "TABLE_NAME"sv, "ALIAS"sv, false, false, {});
     validate("DELETE FROM table_name WHERE (1 == 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {});
     validate("DELETE FROM table_name WHERE (1 == 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {});
+    validate("DELETE FROM table_name WHERE EXISTS (SELECT 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {});
+    validate("DELETE FROM table_name WHERE NOT EXISTS (SELECT 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {});
+    validate("DELETE FROM table_name WHERE (SELECT 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {});
     validate("DELETE FROM table_name RETURNING *;"sv, {}, "TABLE_NAME"sv, {}, false, true, {});
     validate("DELETE FROM table_name RETURNING *;"sv, {}, "TABLE_NAME"sv, {}, false, true, {});
     validate("DELETE FROM table_name RETURNING column_name;"sv, {}, "TABLE_NAME"sv, {}, false, true, { {} });
     validate("DELETE FROM table_name RETURNING column_name;"sv, {}, "TABLE_NAME"sv, {}, false, true, { {} });
     validate("DELETE FROM table_name RETURNING column_name AS alias;"sv, {}, "TABLE_NAME"sv, {}, false, true, { "ALIAS"sv });
     validate("DELETE FROM table_name RETURNING column_name AS alias;"sv, {}, "TABLE_NAME"sv, {}, false, true, { "ALIAS"sv });

+ 33 - 18
Userland/Libraries/LibSQL/AST/Parser.cpp

@@ -219,7 +219,7 @@ NonnullRefPtr<Insert> Parser::parse_insert_statement(RefPtr<CommonTableExpressio
     if (consume_if(TokenType::Values)) {
     if (consume_if(TokenType::Values)) {
         parse_comma_separated_list(false, [&]() {
         parse_comma_separated_list(false, [&]() {
             if (auto chained_expression = parse_chained_expression()) {
             if (auto chained_expression = parse_chained_expression()) {
-                auto chained_expr = dynamic_cast<ChainedExpression*>(chained_expression.ptr());
+                auto* chained_expr = verify_cast<ChainedExpression>(chained_expression.ptr());
                 if ((column_names.size() > 0) && (chained_expr->expressions().size() != column_names.size())) {
                 if ((column_names.size() > 0) && (chained_expr->expressions().size() != column_names.size())) {
                     syntax_error("Number of expressions does not match number of columns");
                     syntax_error("Number of expressions does not match number of columns");
                 } else {
                 } else {
@@ -422,17 +422,34 @@ NonnullRefPtr<Expression> Parser::parse_primary_expression()
     if (auto expression = parse_unary_operator_expression())
     if (auto expression = parse_unary_operator_expression())
         return expression.release_nonnull();
         return expression.release_nonnull();
 
 
-    if (auto expression = parse_chained_expression())
-        return expression.release_nonnull();
-
     if (auto expression = parse_cast_expression())
     if (auto expression = parse_cast_expression())
         return expression.release_nonnull();
         return expression.release_nonnull();
 
 
     if (auto expression = parse_case_expression())
     if (auto expression = parse_case_expression())
         return expression.release_nonnull();
         return expression.release_nonnull();
 
 
-    if (auto expression = parse_exists_expression(false))
-        return expression.release_nonnull();
+    if (auto invert_expression = consume_if(TokenType::Not); invert_expression || consume_if(TokenType::Exists)) {
+        if (auto expression = parse_exists_expression(invert_expression))
+            return expression.release_nonnull();
+
+        expected("Exists expression"sv);
+    }
+
+    if (consume_if(TokenType::ParenOpen)) {
+        // Encountering a Select token at this point means this must be an ExistsExpression with no EXISTS keyword.
+        if (match(TokenType::Select)) {
+            auto select_statement = parse_select_statement({});
+            consume(TokenType::ParenClose);
+            return create_ast_node<ExistsExpression>(move(select_statement), false);
+        }
+
+        if (auto expression = parse_chained_expression(false)) {
+            consume(TokenType::ParenClose);
+            return expression.release_nonnull();
+        }
+
+        expected("Chained expression"sv);
+    }
 
 
     expected("Primary Expression"sv);
     expected("Primary Expression"sv);
     consume();
     consume();
@@ -662,17 +679,16 @@ RefPtr<Expression> Parser::parse_binary_operator_expression(NonnullRefPtr<Expres
     return {};
     return {};
 }
 }
 
 
-RefPtr<Expression> Parser::parse_chained_expression()
+RefPtr<Expression> Parser::parse_chained_expression(bool surrounded_by_parentheses)
 {
 {
-    if (!consume_if(TokenType::ParenOpen))
+    if (surrounded_by_parentheses && !consume_if(TokenType::ParenOpen))
         return {};
         return {};
 
 
-    if (match(TokenType::Select))
-        return parse_exists_expression(false, TokenType::Select);
-
     Vector<NonnullRefPtr<Expression>> expressions;
     Vector<NonnullRefPtr<Expression>> expressions;
     parse_comma_separated_list(false, [&]() { expressions.append(parse_expression()); });
     parse_comma_separated_list(false, [&]() { expressions.append(parse_expression()); });
-    consume(TokenType::ParenClose);
+
+    if (surrounded_by_parentheses)
+        consume(TokenType::ParenClose);
 
 
     return create_ast_node<ChainedExpression>(move(expressions));
     return create_ast_node<ChainedExpression>(move(expressions));
 }
 }
@@ -726,15 +742,14 @@ RefPtr<Expression> Parser::parse_case_expression()
     return create_ast_node<CaseExpression>(move(case_expression), move(when_then_clauses), move(else_expression));
     return create_ast_node<CaseExpression>(move(case_expression), move(when_then_clauses), move(else_expression));
 }
 }
 
 
-RefPtr<Expression> Parser::parse_exists_expression(bool invert_expression, TokenType opening_token)
+RefPtr<Expression> Parser::parse_exists_expression(bool invert_expression)
 {
 {
-    VERIFY((opening_token == TokenType::Exists) || (opening_token == TokenType::Select));
-
-    if ((opening_token == TokenType::Exists) && !consume_if(TokenType::Exists))
+    if (!(match(TokenType::Exists) || match(TokenType::ParenOpen)))
         return {};
         return {};
 
 
-    if (opening_token == TokenType::Exists)
-        consume(TokenType::ParenOpen);
+    consume_if(TokenType::Exists);
+    consume(TokenType::ParenOpen);
+
     auto select_statement = parse_select_statement({});
     auto select_statement = parse_select_statement({});
     consume(TokenType::ParenClose);
     consume(TokenType::ParenClose);
 
 

+ 2 - 2
Userland/Libraries/LibSQL/AST/Parser.h

@@ -77,10 +77,10 @@ private:
     RefPtr<Expression> parse_column_name_expression(DeprecatedString with_parsed_identifier = {}, bool with_parsed_period = false);
     RefPtr<Expression> parse_column_name_expression(DeprecatedString with_parsed_identifier = {}, bool with_parsed_period = false);
     RefPtr<Expression> parse_unary_operator_expression();
     RefPtr<Expression> parse_unary_operator_expression();
     RefPtr<Expression> parse_binary_operator_expression(NonnullRefPtr<Expression> lhs);
     RefPtr<Expression> parse_binary_operator_expression(NonnullRefPtr<Expression> lhs);
-    RefPtr<Expression> parse_chained_expression();
+    RefPtr<Expression> parse_chained_expression(bool surrounded_by_parentheses = true);
     RefPtr<Expression> parse_cast_expression();
     RefPtr<Expression> parse_cast_expression();
     RefPtr<Expression> parse_case_expression();
     RefPtr<Expression> parse_case_expression();
-    RefPtr<Expression> parse_exists_expression(bool invert_expression, TokenType opening_token = TokenType::Exists);
+    RefPtr<Expression> parse_exists_expression(bool invert_expression);
     RefPtr<Expression> parse_collate_expression(NonnullRefPtr<Expression> expression);
     RefPtr<Expression> parse_collate_expression(NonnullRefPtr<Expression> expression);
     RefPtr<Expression> parse_is_expression(NonnullRefPtr<Expression> expression);
     RefPtr<Expression> parse_is_expression(NonnullRefPtr<Expression> expression);
     RefPtr<Expression> parse_match_expression(NonnullRefPtr<Expression> lhs, bool invert_expression);
     RefPtr<Expression> parse_match_expression(NonnullRefPtr<Expression> lhs, bool invert_expression);