Parcourir la source

LibRegex: Make codegen+optimisation for alternatives much faster

Just a little thinking outside the box, and we can now parse and
optimise a million copies of "a|" chained together in just a second :^)
Ali Mohammad Pur il y a 3 ans
Parent
commit
97a333608e

+ 1 - 1
Tests/LibRegex/Regex.cpp

@@ -498,7 +498,7 @@ TEST_CASE(posix_extended_nested_capture_group)
     EXPECT_EQ(result.capture_group_matches[0][2].view, "llo"sv);
 }
 
-auto parse_test_case_long_disjunction_chain = String::repeated("a|"sv, 10000);
+auto parse_test_case_long_disjunction_chain = String::repeated("a|"sv, 100000); // Pattern text made of 100000 repetitions of "a|"; presumably fed to the parse tests below - confirm usage.
 
 TEST_CASE(ECMA262_parse)
 {

+ 1 - 0
Userland/Libraries/LibRegex/RegexBytecodeStreamOptimizer.h

@@ -14,6 +14,7 @@ namespace regex {
 class Optimizer {
 public:
     static void append_alternation(ByteCode& target, ByteCode&& left, ByteCode&& right);
+    static void append_alternation(ByteCode& target, Span<ByteCode> alternatives); // N-way form: appends the bytecode for an alternation of all `alternatives` to `target`; the definition flattens the entries in place and moves from them.
     static void append_character_class(ByteCode& target, Vector<CompareTypeAndValuePair>&& pairs);
 };
 

+ 5 - 2
Userland/Libraries/LibRegex/RegexDebug.h

@@ -35,9 +35,12 @@ public:
     template<typename T>
     void print_bytecode(Regex<T> const& regex) const
     {
-        MatchState state;
-        auto& bytecode = regex.parser_result.bytecode;
+        print_bytecode(regex.parser_result.bytecode);
+    }
 
+    void print_bytecode(ByteCode const& bytecode) const
+    {
+        MatchState state;
         for (;;) {
             auto& opcode = bytecode.get_opcode(state);
             print_opcode("PrintBytecode", opcode, state);

+ 140 - 48
Userland/Libraries/LibRegex/RegexOptimizer.cpp

@@ -9,6 +9,10 @@
 #include <AK/Stack.h>
 #include <LibRegex/Regex.h>
 #include <LibRegex/RegexBytecodeStreamOptimizer.h>
+#if REGEX_DEBUG
+#    include <AK/ScopeGuard.h>
+#    include <AK/ScopeLogger.h>
+#endif
 
 namespace regex {
 
@@ -444,78 +448,166 @@ void Regex<Parser>::attempt_rewrite_loops_as_atomic_groups(BasicBlockList const&
 
 void Optimizer::append_alternation(ByteCode& target, ByteCode&& left, ByteCode&& right)
 {
-    auto left_is_empty = left.is_empty();
-    auto right_is_empty = right.is_empty();
-    if (left_is_empty || right_is_empty) {
-        if (left_is_empty && right_is_empty)
-            return;
-
-        // ForkJump left (+ 2 + right.size())
-        // (right)
-        // Jump end (+ left.size())
-        // (left)
-        // LABEL end
-        target.append(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
-        target.append(2 + right.size());
-        target.extend(move(right));
-        target.append(static_cast<ByteCodeValueType>(OpCodeId::Jump));
-        target.append(left.size());
-        target.extend(move(left));
+    Array<ByteCode, 2> alternatives; // Two-way form: pack both sides into a fixed-size array, `left` at index 0 and `right` at index 1.
+    alternatives[0] = move(left);
+    alternatives[1] = move(right);
+
+    append_alternation(target, alternatives); // All emission is delegated to the Span<ByteCode> overload.
+}
+
+void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives) // For two or more alternatives: appends one placeholder ForkJump per alternative after the first, then the alternatives from last to first, each followed by a Jump. Entries are flattened in place and moved from.
+{
+    if (alternatives.size() == 0)
+        return;
+
+    if (alternatives.size() == 1)
+        return target.extend(move(alternatives[0])); // A single alternative is appended unchanged; no ForkJump/Jump is added.
+
+    if (all_of(alternatives, [](auto& x) { return x.is_empty(); })) // Nothing is appended when every alternative is empty.
         return;
+
+    for (auto& entry : alternatives)
+        entry.flatten();
+
+#if REGEX_DEBUG
+    ScopeLogger<true> log;
+    warnln("Alternations:"); // REGEX_DEBUG only: every alternative is dumped here; print_at_end presumably dumps `target` when the scope ends.
+    RegexDebug dbg;
+    for (auto& entry : alternatives) {
+        warnln("----------");
+        dbg.print_bytecode(entry);
     }
+    ScopeGuard print_at_end {
+        [&] {
+            warnln("======================");
+            RegexDebug dbg;
+            dbg.print_bytecode(target);
+        }
+    };
+#endif
 
-    left.flatten();
-    right.flatten();
+    Vector<Vector<Detail::Block>> basic_blocks;
+    basic_blocks.ensure_capacity(alternatives.size());
 
-    auto left_blocks = Regex<PosixBasicParser>::split_basic_blocks(left);
-    auto right_blocks = Regex<PosixBasicParser>::split_basic_blocks(right);
+    for (auto& entry : alternatives)
+        basic_blocks.append(Regex<PosixBasicParser>::split_basic_blocks(entry)); // One block list per alternative, stored at the same index as in `alternatives`.
 
     size_t left_skip = 0;
-    MatchState state;
-    for (size_t block_index = 0; block_index < left_blocks.size() && block_index < right_blocks.size(); block_index++) {
-        auto& left_block = left_blocks[block_index];
-        auto& right_block = right_blocks[block_index];
-        auto left_end = block_index + 1 == left_blocks.size() ? left_block.end : left_blocks[block_index + 1].start;
-        auto right_end = block_index + 1 == right_blocks.size() ? right_block.end : right_blocks[block_index + 1].start;
+    size_t shared_block_count = basic_blocks.first().size();
+    for (auto& entry : basic_blocks)
+        shared_block_count = min(shared_block_count, entry.size()); // Number of leading block indices that exist in every alternative.
 
-        if (left_end - left_block.start != right_end - right_block.start)
-            break;
+    MatchState state;
+    for (size_t block_index = 0; block_index < shared_block_count; block_index++) { // Walk blocks in lockstep for as long as every alternative's block equals alternative 0's.
+        auto& left_block = basic_blocks.first()[block_index];
+        auto left_end = block_index + 1 == basic_blocks.first().size() ? left_block.end : basic_blocks.first()[block_index + 1].start; // End of this block: the next block's start, or the block's own end for the last block.
+        auto can_continue = true;
+        for (size_t i = 1; i < alternatives.size(); ++i) {
+            auto& right_blocks = basic_blocks[i];
+            auto& right_block = right_blocks[block_index];
+            auto right_end = block_index + 1 == right_blocks.size() ? right_block.end : right_blocks[block_index + 1].start;
+
+            if (left_end - left_block.start != right_end - right_block.start) { // A block of different length ends the comparison.
+                can_continue = false;
+                break;
+            }
 
-        if (left.spans().slice(left_block.start, left_end - left_block.start) != right.spans().slice(right_block.start, right_end - right_block.start))
+            if (alternatives[0].spans().slice(left_block.start, left_end - left_block.start) != alternatives[i].spans().slice(right_block.start, right_end - right_block.start)) { // So does a block whose contents differ from alternative 0's.
+                can_continue = false;
+                break;
+            }
+        }
+        if (!can_continue)
             break;
 
-        state.instruction_position = 0;
-        while (state.instruction_position < left_end) {
-            auto& opcode = left.get_opcode(state);
-            left_skip = state.instruction_position;
-            state.instruction_position += opcode.size();
+        size_t i = 0; // NOTE(review): `i` is never incremented in the loop below, so basic_blocks[0] is used for every entry - confirm this is intended.
+        for (auto& entry : alternatives) {
+            auto& blocks = basic_blocks[i];
+            auto& block = blocks[block_index];
+            auto end = block_index + 1 == blocks.size() ? block.end : blocks[block_index + 1].start;
+            state.instruction_position = block.start;
+            size_t skip = 0;
+            while (state.instruction_position < end) {
+                auto& opcode = entry.get_opcode(state);
+                state.instruction_position += opcode.size();
+                skip = state.instruction_position; // Position just past the last opcode visited in this block.
+            }
+            left_skip = min(skip, left_skip); // NOTE(review): left_skip starts at 0 and is only ever min()-ed here, so it appears to stay 0 and the `left_skip > 0` branch below looks unreachable - confirm.
         }
     }
 
-    dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}/{}", left_skip, 0, left.size(), right.size());
+    dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}", left_skip, 0, alternatives[0].size());
 
     if (left_skip > 0) {
-        target.extend(left.release_slice(left_blocks.first().start, left_skip));
-        right = right.release_slice(left_skip);
+        target.extend(alternatives[0].release_slice(basic_blocks.first().first().start, left_skip)); // Presumably moves the shared leading bytecode out of alternative 0 into target - confirm release_slice semantics.
+        auto first = true;
+        for (auto& entry : alternatives) {
+            if (first) {
+                first = false;
+                continue;
+            }
+            entry = entry.release_slice(left_skip); // Presumably leaves every other alternative with only what follows the shared part - confirm.
+        }
     }
 
-    auto left_size = left.size();
+    if (all_of(alternatives, [](auto& entry) { return entry.is_empty(); })) // Re-checked because the slicing above may have changed the alternatives.
+        return;
 
-    target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
-    target.empend(right.size() + (left_size > 0 ? 2 : 0)); // Jump to the _ALT label
+    size_t patch_start = target.size(); // Position in target of the first placeholder ForkJump.
+    for (size_t i = 1; i < alternatives.size(); ++i) {
+        target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
+        target.empend(0u); // To be filled later.
+    }
 
-    target.extend(move(right));
+    size_t size_to_jump = 0;
+    bool seen_one_empty = false;
+    for (size_t i = alternatives.size(); i > 0; --i) { // First pass, last alternative to first: accumulate sizes and fill in the ForkJump offsets.
+        auto& entry = alternatives[i - 1];
+        if (entry.is_empty()) { // Of the empty alternatives, only the first one met (scanning from the end) is processed; later ones are skipped.
+            if (seen_one_empty)
+                continue;
+            seen_one_empty = true;
+        }
 
-    if (left_size != 0) {
-        target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
-        target.empend(left.size()); // Jump to the _END label
+        auto is_first = i == 1;
+        auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
+        size_to_jump += instruction_size;
+
+        if (!is_first)
+            target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2; // Fills the offset slot of placeholder ForkJump number (i - 2); the second term matches the size of the placeholder ForkJumps that follow it.
+
+        dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
     }
 
-    // LABEL _ALT = bytecode.size() + 2
+    seen_one_empty = false;
+    for (size_t i = alternatives.size(); i > 0; --i) { // Second pass, same order and same empty-skipping rule: append each kept alternative followed by a Jump.
+        auto& chunk = alternatives[i - 1];
+        if (chunk.is_empty()) {
+            if (seen_one_empty)
+                continue;
+            seen_one_empty = true;
+        }
 
-    target.extend(move(left));
+        ByteCode* previous_chunk = nullptr;
+        size_t j = i - 1;
+        auto seen_one_empty_before = chunk.is_empty();
+        while (j >= 1) { // Search the lower indices for a "previous" chunk; empty candidates are skipped only when this chunk is itself empty.
+            --j;
+            auto& candidate_chunk = alternatives[j];
+            if (candidate_chunk.is_empty()) {
+                if (seen_one_empty_before)
+                    continue;
+            }
+            previous_chunk = &candidate_chunk;
+            break;
+        }
+
+        size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0); // Remaining distance shrinks by this chunk, plus 2 (presumably its Jump) when a previous chunk was found.
 
-    // LABEL _END = alterantive_bytecode.size
+        target.extend(move(chunk));
+        target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump)); // NOTE(review): a Jump is appended after every emitted chunk, including the final one (alternatives[0]), although the first pass adds no +2 for it - confirm this is intended.
+        target.empend(size_to_jump); // Jump to the _END label
+    }
 }
 
 enum class LookupTableInsertionOutcome {

+ 2 - 13
Userland/Libraries/LibRegex/RegexParser.cpp

@@ -958,7 +958,7 @@ bool ECMA262Parser::parse_disjunction(ByteCode& stack, size_t& match_length_mini
 {
     size_t total_match_length_minimum = NumericLimits<size_t>::max();
     Vector<ByteCode> alternatives;
-    do {
+    while (true) {
         ByteCode alternative_stack;
         size_t alternative_minimum_length = 0;
         auto alt_ok = parse_alternative(alternative_stack, alternative_minimum_length, unicode, named);
@@ -971,20 +971,9 @@ bool ECMA262Parser::parse_disjunction(ByteCode& stack, size_t& match_length_mini
         if (!match(TokenType::Pipe))
             break;
         consume();
-    } while (true);
-
-    Optional<ByteCode> alternative_stack {};
-    for (auto& alternative : alternatives) {
-        if (alternative_stack.has_value()) {
-            ByteCode target_stack;
-            target_stack.insert_bytecode_alternation(alternative_stack.release_value(), move(alternative));
-            alternative_stack = move(target_stack);
-        } else {
-            alternative_stack = move(alternative);
-        }
     }
 
-    stack.extend(alternative_stack.release_value());
+    Optimizer::append_alternation(stack, alternatives.span());
     match_length_minimum = total_match_length_minimum;
     return true;
 }