Selaa lähdekoodia

LibSQL: Fix parsing of lists of common-table-expression

Misread the graph: In the "WITH [RECURSIVE] common-table-expression"
section, common-table-expression is actually a repeating list. This
changes the parser to correctly parse this section as a list. Create a
new AST node, CommonTableExpressionList, to store both this list and the
boolean RECURSIVE attribute (because every statement that uses this list
also includes the RECURSIVE attribute beforehand).
Timothy Flynn 4 vuotta sitten
vanhempi
commit
6a7d7624a7

+ 21 - 7
Userland/Libraries/LibSQL/AST.h

@@ -97,6 +97,23 @@ private:
     Vector<String> m_column_names;
 };
 
+class CommonTableExpressionList : public ASTNode {
+public:
+    CommonTableExpressionList(bool recursive, NonnullRefPtrVector<CommonTableExpression> common_table_expressions)
+        : m_recursive(recursive)
+        , m_common_table_expressions(move(common_table_expressions))
+    {
+        VERIFY(!m_common_table_expressions.is_empty());
+    }
+
+    bool recursive() const { return m_recursive; }
+    const NonnullRefPtrVector<CommonTableExpression>& common_table_expressions() const { return m_common_table_expressions; }
+
+private:
+    bool m_recursive;
+    NonnullRefPtrVector<CommonTableExpression> m_common_table_expressions;
+};
+
 class QualifiedTableName : public ASTNode {
 public:
     QualifiedTableName(String schema_name, String table_name, String alias)
@@ -533,24 +550,21 @@ private:
 
 class Delete : public Statement {
 public:
-    Delete(bool recursive, RefPtr<CommonTableExpression> common_table_expression, NonnullRefPtr<QualifiedTableName> qualified_table_name, RefPtr<Expression> where_clause, RefPtr<ReturningClause> returning_clause)
-        : m_recursive(recursive)
-        , m_common_table_expression(move(common_table_expression))
+    Delete(RefPtr<CommonTableExpressionList> common_table_expression_list, NonnullRefPtr<QualifiedTableName> qualified_table_name, RefPtr<Expression> where_clause, RefPtr<ReturningClause> returning_clause)
+        : m_common_table_expression_list(move(common_table_expression_list))
         , m_qualified_table_name(move(qualified_table_name))
         , m_where_clause(move(where_clause))
         , m_returning_clause(move(returning_clause))
     {
     }
 
-    bool recursive() const { return m_recursive; }
-    const RefPtr<CommonTableExpression>& common_table_expression() const { return m_common_table_expression; }
+    const RefPtr<CommonTableExpressionList>& common_table_expression_list() const { return m_common_table_expression_list; }
     const NonnullRefPtr<QualifiedTableName>& qualified_table_name() const { return m_qualified_table_name; }
     const RefPtr<Expression>& where_clause() const { return m_where_clause; }
     const RefPtr<ReturningClause>& returning_clause() const { return m_returning_clause; }
 
 private:
-    bool m_recursive;
-    RefPtr<CommonTableExpression> m_common_table_expression;
+    RefPtr<CommonTableExpressionList> m_common_table_expression_list;
     NonnullRefPtr<QualifiedTableName> m_qualified_table_name;
     RefPtr<Expression> m_where_clause;
     RefPtr<ReturningClause> m_returning_clause;

+ 1 - 0
Userland/Libraries/LibSQL/Forward.h

@@ -18,6 +18,7 @@ class CollateExpression;
 class ColumnDefinition;
 class ColumnNameExpression;
 class CommonTableExpression;
+class CommonTableExpressionList;
 class CreateTable;
 class Delete;
 class DropTable;

+ 14 - 5
Userland/Libraries/LibSQL/Parser.cpp

@@ -112,11 +112,20 @@ NonnullRefPtr<Delete> Parser::parse_delete_statement()
 {
     // https://sqlite.org/lang_delete.html
 
-    bool recursive = false;
-    RefPtr<CommonTableExpression> common_table_expression;
+    RefPtr<CommonTableExpressionList> common_table_expression_list;
     if (consume_if(TokenType::With)) {
-        recursive = consume_if(TokenType::Recursive);
-        common_table_expression = parse_common_table_expression();
+        NonnullRefPtrVector<CommonTableExpression> common_table_expression;
+        bool recursive = consume_if(TokenType::Recursive);
+
+        do {
+            common_table_expression.append(parse_common_table_expression());
+            if (!match(TokenType::Comma))
+                break;
+
+            consume(TokenType::Comma);
+        } while (!match(TokenType::Eof));
+
+        common_table_expression_list = create_ast_node<CommonTableExpressionList>(recursive, move(common_table_expression));
     }
 
     consume(TokenType::Delete);
@@ -133,7 +142,7 @@ NonnullRefPtr<Delete> Parser::parse_delete_statement()
 
     consume(TokenType::SemiColon);
 
-    return create_ast_node<Delete>(recursive, move(common_table_expression), move(qualified_table_name), move(where_clause), move(returning_clause));
+    return create_ast_node<Delete>(move(common_table_expression_list), move(qualified_table_name), move(where_clause), move(returning_clause));
 }
 
 NonnullRefPtr<Expression> Parser::parse_expression()

+ 30 - 17
Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp

@@ -153,13 +153,17 @@ TEST_CASE(delete_)
     EXPECT(parse("WITH RECURSIVE table DELETE FROM table;").is_error());
     EXPECT(parse("WITH RECURSIVE table AS DELETE FROM table;").is_error());
 
-    struct SelectedTable {
+    struct SelectedTableList {
+        struct SelectedTable {
+            StringView table_name {};
+            Vector<StringView> column_names {};
+        };
+
         bool recursive { false };
-        StringView table_name {};
-        Vector<StringView> column_names {};
+        Vector<SelectedTable> selected_tables {};
     };
 
-    auto validate = [](StringView sql, SelectedTable expected_selected_table, StringView expected_schema, StringView expected_table, StringView expected_alias, bool expect_where_clause, bool expect_returning_clause, Vector<StringView> expected_returned_column_aliases) {
+    auto validate = [](StringView sql, SelectedTableList expected_selected_tables, StringView expected_schema, StringView expected_table, StringView expected_alias, bool expect_where_clause, bool expect_returning_clause, Vector<StringView> expected_returned_column_aliases) {
         auto result = parse(sql);
         EXPECT(!result.is_error());
 
@@ -167,15 +171,24 @@ TEST_CASE(delete_)
         EXPECT(is<SQL::Delete>(*statement));
 
         const auto& delete_ = static_cast<const SQL::Delete&>(*statement);
-        EXPECT_EQ(delete_.recursive(), expected_selected_table.recursive);
-
-        const auto& common_table_expression = delete_.common_table_expression();
-        EXPECT_EQ(common_table_expression.is_null(), expected_selected_table.table_name.is_empty());
-        if (common_table_expression) {
-            EXPECT_EQ(common_table_expression->table_name(), expected_selected_table.table_name);
-            EXPECT_EQ(common_table_expression->column_names().size(), expected_selected_table.column_names.size());
-            for (size_t i = 0; i < common_table_expression->column_names().size(); ++i)
-                EXPECT_EQ(common_table_expression->column_names()[i], expected_selected_table.column_names[i]);
+
+        const auto& common_table_expression_list = delete_.common_table_expression_list();
+        EXPECT_EQ(common_table_expression_list.is_null(), expected_selected_tables.selected_tables.is_empty());
+        if (common_table_expression_list) {
+            EXPECT_EQ(common_table_expression_list->recursive(), expected_selected_tables.recursive);
+
+            const auto& common_table_expressions = common_table_expression_list->common_table_expressions();
+            EXPECT_EQ(common_table_expressions.size(), expected_selected_tables.selected_tables.size());
+
+            for (size_t i = 0; i < common_table_expressions.size(); ++i) {
+                const auto& common_table_expression = common_table_expressions[i];
+                const auto& expected_common_table_expression = expected_selected_tables.selected_tables[i];
+                EXPECT_EQ(common_table_expression.table_name(), expected_common_table_expression.table_name);
+                EXPECT_EQ(common_table_expression.column_names().size(), expected_common_table_expression.column_names.size());
+
+                for (size_t j = 0; j < common_table_expression.column_names().size(); ++j)
+                    EXPECT_EQ(common_table_expression.column_names()[j], expected_common_table_expression.column_names[j]);
+            }
         }
 
         const auto& qualified_table_name = delete_.qualified_table_name();
@@ -213,10 +226,10 @@ TEST_CASE(delete_)
     validate("DELETE FROM table RETURNING column1 AS alias1, column2 AS alias2;", {}, {}, "table", {}, false, true, { "alias1", "alias2" });
 
     // FIXME: When parsing of SELECT statements are supported, the common-table-expressions below will become invalid due to the empty "AS ()" clause.
-    validate("WITH table AS () DELETE FROM table;", { false, "table", {} }, {}, "table", {}, false, false, {});
-    validate("WITH table (column) AS () DELETE FROM table;", { false, "table", { "column" } }, {}, "table", {}, false, false, {});
-    validate("WITH table (column1, column2) AS () DELETE FROM table;", { false, "table", { "column1", "column2" } }, {}, "table", {}, false, false, {});
-    validate("WITH RECURSIVE table AS () DELETE FROM table;", { true, "table", {} }, {}, "table", {}, false, false, {});
+    validate("WITH table AS () DELETE FROM table;", { false, { { "table" } } }, {}, "table", {}, false, false, {});
+    validate("WITH table (column) AS () DELETE FROM table;", { false, { { "table", { "column" } } } }, {}, "table", {}, false, false, {});
+    validate("WITH table (column1, column2) AS () DELETE FROM table;", { false, { { "table", { "column1", "column2" } } } }, {}, "table", {}, false, false, {});
+    validate("WITH RECURSIVE table AS () DELETE FROM table;", { true, { { "table", {} } } }, {}, "table", {}, false, false, {});
 }
 
 TEST_MAIN(SqlStatementParser)