浏览代码

LibRegex: Fully interpret the Compare Op when looking for overlaps

We had a really naive and simplistic implementation, which lead to
various issues where the optimiser incorrectly rewrote the regex to use
atomic groups; this commit fixes that.
Ali Mohammad Pur 3 年之前
父节点
当前提交
6e655b7f89

+ 1 - 0
Tests/LibRegex/Regex.cpp

@@ -921,6 +921,7 @@ TEST_CASE(optimizer_atomic_groups)
         Tuple { "a*b"sv, "aaaaa"sv, false },
         Tuple { "a*b"sv, "aaaaa"sv, false },
         Tuple { "a+b"sv, "aaaaa"sv, false },
         Tuple { "a+b"sv, "aaaaa"sv, false },
         Tuple { "\\\\(\\d+)"sv, "\\\\"sv, false }, // Rewrite bug turning a+ to a*, see #10952.
         Tuple { "\\\\(\\d+)"sv, "\\\\"sv, false }, // Rewrite bug turning a+ to a*, see #10952.
+        Tuple { "[a-z.]+\\."sv, "..."sv, true },   // Rewrite bug, incorrect interpretation of Compare.
         // Alternative fuse
         // Alternative fuse
         Tuple { "(abcfoo|abcbar|abcbaz).*x"sv, "abcbarx"sv, true },
         Tuple { "(abcfoo|abcbar|abcbaz).*x"sv, "abcbarx"sv, true },
         Tuple { "(a|a)"sv, "a"sv, true },
         Tuple { "(a|a)"sv, "a"sv, true },

+ 26 - 88
Userland/Libraries/LibRegex/RegexByteCode.cpp

@@ -666,7 +666,17 @@ ALWAYS_INLINE bool OpCode_Compare::compare_string(MatchInput const& input, Match
 
 
 ALWAYS_INLINE void OpCode_Compare::compare_character_class(MatchInput const& input, MatchState& state, CharClass character_class, u32 ch, bool inverse, bool& inverse_matched)
 ALWAYS_INLINE void OpCode_Compare::compare_character_class(MatchInput const& input, MatchState& state, CharClass character_class, u32 ch, bool inverse, bool& inverse_matched)
 {
 {
-    auto is_space_or_line_terminator = [](u32 code_point) {
+    if (matches_character_class(character_class, ch, input.regex_options & AllFlags::Insensitive)) {
+        if (inverse)
+            inverse_matched = true;
+        else
+            advance_string_position(state, input.view, ch);
+    }
+}
+
+bool OpCode_Compare::matches_character_class(CharClass character_class, u32 ch, bool insensitive)
+{
+    constexpr auto is_space_or_line_terminator = [](u32 code_point) {
         static auto space_separator = Unicode::general_category_from_string("Space_Separator"sv);
         static auto space_separator = Unicode::general_category_from_string("Space_Separator"sv);
         if (!space_separator.has_value())
         if (!space_separator.has_value())
             return is_ascii_space(code_point);
             return is_ascii_space(code_point);
@@ -680,106 +690,34 @@ ALWAYS_INLINE void OpCode_Compare::compare_character_class(MatchInput const& inp
 
 
     switch (character_class) {
     switch (character_class) {
     case CharClass::Alnum:
     case CharClass::Alnum:
-        if (is_ascii_alphanumeric(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_alphanumeric(ch);
     case CharClass::Alpha:
     case CharClass::Alpha:
-        if (is_ascii_alpha(ch))
-            advance_string_position(state, input.view, ch);
-        break;
+        return is_ascii_alpha(ch);
     case CharClass::Blank:
     case CharClass::Blank:
-        if (is_ascii_blank(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_blank(ch);
     case CharClass::Cntrl:
     case CharClass::Cntrl:
-        if (is_ascii_control(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_control(ch);
     case CharClass::Digit:
     case CharClass::Digit:
-        if (is_ascii_digit(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_digit(ch);
     case CharClass::Graph:
     case CharClass::Graph:
-        if (is_ascii_graphical(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_graphical(ch);
     case CharClass::Lower:
     case CharClass::Lower:
-        if (is_ascii_lower_alpha(ch) || ((input.regex_options & AllFlags::Insensitive) && is_ascii_upper_alpha(ch))) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_lower_alpha(ch) || (insensitive && is_ascii_upper_alpha(ch));
     case CharClass::Print:
     case CharClass::Print:
-        if (is_ascii_printable(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_printable(ch);
     case CharClass::Punct:
     case CharClass::Punct:
-        if (is_ascii_punctuation(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_punctuation(ch);
     case CharClass::Space:
     case CharClass::Space:
-        if (is_space_or_line_terminator(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_space_or_line_terminator(ch);
     case CharClass::Upper:
     case CharClass::Upper:
-        if (is_ascii_upper_alpha(ch) || ((input.regex_options & AllFlags::Insensitive) && is_ascii_lower_alpha(ch))) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_upper_alpha(ch) || (insensitive && is_ascii_lower_alpha(ch));
     case CharClass::Word:
     case CharClass::Word:
-        if (is_ascii_alphanumeric(ch) || ch == '_') {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_alphanumeric(ch) || ch == '_';
     case CharClass::Xdigit:
     case CharClass::Xdigit:
-        if (is_ascii_hex_digit(ch)) {
-            if (inverse)
-                inverse_matched = true;
-            else
-                advance_string_position(state, input.view, ch);
-        }
-        break;
+        return is_ascii_hex_digit(ch);
     }
     }
+
+    VERIFY_NOT_REACHED();
 }
 }
 
 
 ALWAYS_INLINE void OpCode_Compare::compare_character_range(MatchInput const& input, MatchState& state, u32 from, u32 to, u32 ch, bool inverse, bool& inverse_matched)
 ALWAYS_INLINE void OpCode_Compare::compare_character_range(MatchInput const& input, MatchState& state, u32 from, u32 to, u32 ch, bool inverse, bool& inverse_matched)

+ 1 - 0
Userland/Libraries/LibRegex/RegexByteCode.h

@@ -745,6 +745,7 @@ public:
     String arguments_string() const override;
     String arguments_string() const override;
     Vector<String> variable_arguments_to_string(Optional<MatchInput> input = {}) const;
     Vector<String> variable_arguments_to_string(Optional<MatchInput> input = {}) const;
     Vector<CompareTypeAndValuePair> flat_compares() const;
     Vector<CompareTypeAndValuePair> flat_compares() const;
+    static bool matches_character_class(CharClass, u32, bool insensitive);
 
 
 private:
 private:
     ALWAYS_INLINE static void compare_char(MatchInput const& input, MatchState& state, u32 ch1, bool inverse, bool& inverse_matched);
     ALWAYS_INLINE static void compare_char(MatchInput const& input, MatchState& state, u32 ch1, bool inverse, bool& inverse_matched);

+ 178 - 10
Userland/Libraries/LibRegex/RegexOptimizer.cpp

@@ -108,6 +108,182 @@ typename Regex<Parser>::BasicBlockList Regex<Parser>::split_basic_blocks(ByteCod
     return block_boundaries;
     return block_boundaries;
 }
 }
 
 
+static bool has_overlap(Vector<CompareTypeAndValuePair> const& lhs, Vector<CompareTypeAndValuePair> const& rhs)
+{
+
+    // We have to fully interpret the two sequences to determine if they overlap (that is, keep track of inversion state and what ranges they cover).
+    bool inverse { false };
+    bool temporary_inverse { false };
+    bool reset_temporary_inverse { false };
+
+    auto current_lhs_inversion_state = [&]() -> bool { return temporary_inverse ^ inverse; };
+
+    RedBlackTree<u32, u32> lhs_ranges;
+    RedBlackTree<u32, u32> lhs_negated_ranges;
+    HashTable<CharClass> lhs_char_classes;
+    HashTable<CharClass> lhs_negated_char_classes;
+
+    auto range_contains = [&]<typename T>(T& value) -> bool {
+        u32 start;
+        u32 end;
+
+        if constexpr (IsSame<T, CharRange>) {
+            start = value.from;
+            end = value.to;
+        } else {
+            start = value;
+            end = value;
+        }
+
+        auto* max = lhs_ranges.find_smallest_not_below(start);
+        return max && *max <= end;
+    };
+
+    auto char_class_contains = [&](CharClass const& value) -> bool {
+        if (lhs_char_classes.contains(value))
+            return true;
+
+        if (lhs_negated_char_classes.contains(value))
+            return false;
+
+        // This char class might match something in the ranges we have, and checking that is far too expensive, so just bail out.
+        return true;
+    };
+
+    for (auto const& pair : lhs) {
+        if (reset_temporary_inverse) {
+            reset_temporary_inverse = false;
+            temporary_inverse = false;
+        } else {
+            reset_temporary_inverse = true;
+        }
+
+        switch (pair.type) {
+        case CharacterCompareType::Inverse:
+            inverse = !inverse;
+            break;
+        case CharacterCompareType::TemporaryInverse:
+            temporary_inverse = !temporary_inverse;
+            break;
+        case CharacterCompareType::AnyChar:
+            // Special case: if not inverted, AnyChar is always in the range.
+            if (!current_lhs_inversion_state())
+                return true;
+            break;
+        case CharacterCompareType::Char:
+            if (!current_lhs_inversion_state())
+                lhs_ranges.insert(pair.value, pair.value);
+            else
+                lhs_negated_ranges.insert(pair.value, pair.value);
+            break;
+        case CharacterCompareType::String:
+            // FIXME: We just need to look at the last character of this string, but we only have the first character here.
+            //        Just bail out to avoid false positives.
+            return true;
+        case CharacterCompareType::CharClass:
+            if (!current_lhs_inversion_state())
+                lhs_char_classes.set(static_cast<CharClass>(pair.value));
+            else
+                lhs_negated_char_classes.set(static_cast<CharClass>(pair.value));
+            break;
+        case CharacterCompareType::CharRange: {
+            auto range = bit_cast<CharRange>(pair.value);
+            if (!current_lhs_inversion_state())
+                lhs_ranges.insert(range.from, range.to);
+            else
+                lhs_negated_ranges.insert(range.from, range.to);
+            break;
+        }
+        case CharacterCompareType::LookupTable:
+            // We've transformed this into a series of ranges in flat_compares(), so bail out if we see it.
+            return true;
+        case CharacterCompareType::Reference:
+            // We've handled this before coming here.
+            break;
+        case CharacterCompareType::Property:
+        case CharacterCompareType::GeneralCategory:
+        case CharacterCompareType::Script:
+        case CharacterCompareType::ScriptExtension:
+            // FIXME: These are too difficult to handle, so bail out.
+            return true;
+        case CharacterCompareType::Undefined:
+        case CharacterCompareType::RangeExpressionDummy:
+            // These do not occur in valid bytecode.
+            VERIFY_NOT_REACHED();
+        }
+    }
+
+    if constexpr (REGEX_DEBUG) {
+        dbgln("lhs ranges:");
+        for (auto it = lhs_ranges.begin(); it != lhs_ranges.end(); ++it)
+            dbgln("  {}..{}", it.key(), *it);
+        dbgln("lhs negated ranges:");
+        for (auto it = lhs_negated_ranges.begin(); it != lhs_negated_ranges.end(); ++it)
+            dbgln("  {}..{}", it.key(), *it);
+    }
+
+    for (auto const& pair : rhs) {
+        if (reset_temporary_inverse) {
+            reset_temporary_inverse = false;
+            temporary_inverse = false;
+        } else {
+            reset_temporary_inverse = true;
+        }
+
+        dbgln_if(REGEX_DEBUG, "check {} ({})...", character_compare_type_name(pair.type), pair.value);
+
+        switch (pair.type) {
+        case CharacterCompareType::Inverse:
+            inverse = !inverse;
+            break;
+        case CharacterCompareType::TemporaryInverse:
+            temporary_inverse = !temporary_inverse;
+            break;
+        case CharacterCompareType::AnyChar:
+            // Special case: if not inverted, AnyChar is always in the range.
+            if (!current_lhs_inversion_state())
+                return true;
+            break;
+        case CharacterCompareType::Char:
+            if (!current_lhs_inversion_state() && range_contains(pair.value))
+                return true;
+            break;
+        case CharacterCompareType::String:
+            // FIXME: We just need to look at the last character of this string, but we only have the first character here.
+            //        Just bail out to avoid false positives.
+            return true;
+        case CharacterCompareType::CharClass:
+            if (!current_lhs_inversion_state() && char_class_contains(static_cast<CharClass>(pair.value)))
+                return true;
+            break;
+        case CharacterCompareType::CharRange: {
+            auto range = bit_cast<CharRange>(pair.value);
+            if (!current_lhs_inversion_state() && range_contains(range))
+                return true;
+            break;
+        }
+        case CharacterCompareType::LookupTable:
+            // We've transformed this into a series of ranges in flat_compares(), so bail out if we see it.
+            return true;
+        case CharacterCompareType::Reference:
+            // We've handled this before coming here.
+            break;
+        case CharacterCompareType::Property:
+        case CharacterCompareType::GeneralCategory:
+        case CharacterCompareType::Script:
+        case CharacterCompareType::ScriptExtension:
+            // FIXME: These are too difficult to handle, so bail out.
+            return true;
+        case CharacterCompareType::Undefined:
+        case CharacterCompareType::RangeExpressionDummy:
+            // These do not occur in valid bytecode.
+            VERIFY_NOT_REACHED();
+        }
+    }
+
+    return false;
+}
+
 enum class AtomicRewritePreconditionResult {
 enum class AtomicRewritePreconditionResult {
     SatisfiedWithProperHeader,
     SatisfiedWithProperHeader,
     SatisfiedWithEmptyHeader,
     SatisfiedWithEmptyHeader,
@@ -179,17 +355,9 @@ static AtomicRewritePreconditionResult block_satisfies_atomic_rewrite_preconditi
                 }))
                 }))
                 return AtomicRewritePreconditionResult::NotSatisfied;
                 return AtomicRewritePreconditionResult::NotSatisfied;
 
 
-            for (auto& repeated_value : repeated_values) {
-                // FIXME: This is too naive!
-                if (any_of(repeated_value, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; }))
-                    return AtomicRewritePreconditionResult::NotSatisfied;
+            if (any_of(repeated_values, [&](auto& repeated_value) { return has_overlap(compares, repeated_value); }))
+                return AtomicRewritePreconditionResult::NotSatisfied;
 
 
-                for (auto& repeated_compare : repeated_value) {
-                    // FIXME: This is too naive! it will miss _tons_ of cases since it doesn't check ranges!
-                    if (any_of(compares, [&](auto& compare) { return compare.type == repeated_compare.type && compare.value == repeated_compare.value; }))
-                        return AtomicRewritePreconditionResult::NotSatisfied;
-                }
-            }
             return AtomicRewritePreconditionResult::SatisfiedWithProperHeader;
             return AtomicRewritePreconditionResult::SatisfiedWithProperHeader;
         }
         }
         case OpCodeId::CheckBegin:
         case OpCodeId::CheckBegin: