Browse Source

LibRegex: Make infinite repetitions short-circuit on empty matches

This makes (addmittedly weird) patterns like `(a*)*` work correctly
without going into an infinite fork loop.
Ali Mohammad Pur 3 years ago
parent
commit
abbe9da255

+ 15 - 0
Tests/LibRegex/Regex.cpp

@@ -852,6 +852,21 @@ TEST_CASE(extremely_long_fork_chain)
     EXPECT_EQ(result.success, true);
     EXPECT_EQ(result.success, true);
 }
 }
 
 
+TEST_CASE(theoretically_infinite_loop)
+{
+    Array patterns {
+        "(a*)*"sv,  // Infinitely matching empty substrings, the outer loop should short-circuit.
+        "(a*?)*"sv, // Infinitely matching empty substrings, the outer loop should short-circuit.
+        "(a*)*?"sv, // Should match exactly nothing.
+        "(?:)*?"sv, // Should not generate an infinite fork loop.
+    };
+    for (auto& pattern : patterns) {
+        Regex<ECMA262> re(pattern);
+        auto result = re.match("");
+        EXPECT_EQ(result.success, true);
+    }
+}
+
 static auto g_lots_of_a_s = String::repeated('a', 10'000'000);
 static auto g_lots_of_a_s = String::repeated('a', 10'000'000);
 
 
 BENCHMARK_CASE(fork_performance)
 BENCHMARK_CASE(fork_performance)

+ 54 - 54
Userland/Libraries/LibRegex/RegexByteCode.cpp

@@ -46,6 +46,22 @@ char const* execution_result_name(ExecutionResult result)
     }
     }
 }
 }
 
 
+char const* opcode_id_name(OpCodeId opcode)
+{
+    switch (opcode) {
+#define __ENUMERATE_OPCODE(x) \
+    case OpCodeId::x:         \
+        return #x;
+
+        ENUMERATE_OPCODES
+
+#undef __ENUMERATE_OPCODE
+    default:
+        VERIFY_NOT_REACHED();
+        return "<Unknown>";
+    }
+}
+
 char const* boundary_check_type_name(BoundaryCheckType ty)
 char const* boundary_check_type_name(BoundaryCheckType ty)
 {
 {
     switch (ty) {
     switch (ty) {
@@ -144,60 +160,14 @@ void ByteCode::ensure_opcodes_initialized()
         return;
         return;
     for (u32 i = (u32)OpCodeId::First; i <= (u32)OpCodeId::Last; ++i) {
     for (u32 i = (u32)OpCodeId::First; i <= (u32)OpCodeId::Last; ++i) {
         switch ((OpCodeId)i) {
         switch ((OpCodeId)i) {
-        case OpCodeId::Exit:
-            s_opcodes[i] = make<OpCode_Exit>();
-            break;
-        case OpCodeId::Jump:
-            s_opcodes[i] = make<OpCode_Jump>();
-            break;
-        case OpCodeId::Compare:
-            s_opcodes[i] = make<OpCode_Compare>();
-            break;
-        case OpCodeId::CheckEnd:
-            s_opcodes[i] = make<OpCode_CheckEnd>();
-            break;
-        case OpCodeId::CheckBoundary:
-            s_opcodes[i] = make<OpCode_CheckBoundary>();
-            break;
-        case OpCodeId::ForkJump:
-            s_opcodes[i] = make<OpCode_ForkJump>();
-            break;
-        case OpCodeId::ForkStay:
-            s_opcodes[i] = make<OpCode_ForkStay>();
-            break;
-        case OpCodeId::FailForks:
-            s_opcodes[i] = make<OpCode_FailForks>();
-            break;
-        case OpCodeId::Save:
-            s_opcodes[i] = make<OpCode_Save>();
-            break;
-        case OpCodeId::Restore:
-            s_opcodes[i] = make<OpCode_Restore>();
-            break;
-        case OpCodeId::GoBack:
-            s_opcodes[i] = make<OpCode_GoBack>();
-            break;
-        case OpCodeId::CheckBegin:
-            s_opcodes[i] = make<OpCode_CheckBegin>();
-            break;
-        case OpCodeId::ClearCaptureGroup:
-            s_opcodes[i] = make<OpCode_ClearCaptureGroup>();
-            break;
-        case OpCodeId::SaveLeftCaptureGroup:
-            s_opcodes[i] = make<OpCode_SaveLeftCaptureGroup>();
-            break;
-        case OpCodeId::SaveRightCaptureGroup:
-            s_opcodes[i] = make<OpCode_SaveRightCaptureGroup>();
-            break;
-        case OpCodeId::SaveRightNamedCaptureGroup:
-            s_opcodes[i] = make<OpCode_SaveRightNamedCaptureGroup>();
-            break;
-        case OpCodeId::Repeat:
-            s_opcodes[i] = make<OpCode_Repeat>();
-            break;
-        case OpCodeId::ResetRepeat:
-            s_opcodes[i] = make<OpCode_ResetRepeat>();
-            break;
+#define __ENUMERATE_OPCODE(OpCode)              \
+    case OpCodeId::OpCode:                      \
+        s_opcodes[i] = make<OpCode_##OpCode>(); \
+        break;
+
+            ENUMERATE_OPCODES
+
+#undef __ENUMERATE_OPCODE
         }
         }
     }
     }
     s_opcodes_initialized = true;
     s_opcodes_initialized = true;
@@ -901,4 +871,34 @@ ALWAYS_INLINE ExecutionResult OpCode_ResetRepeat::execute(MatchInput const&, Mat
     return ExecutionResult::Continue;
     return ExecutionResult::Continue;
 }
 }
 
 
+ALWAYS_INLINE ExecutionResult OpCode_Checkpoint::execute(MatchInput const&, MatchState& state) const
+{
+    state.checkpoints.set(state.instruction_position, state.string_position);
+    return ExecutionResult::Continue;
+}
+
+ALWAYS_INLINE ExecutionResult OpCode_JumpNonEmpty::execute(MatchInput const&, MatchState& state) const
+{
+    auto current_position = state.string_position;
+    auto checkpoint_ip = state.instruction_position + size() + checkpoint();
+    if (state.checkpoints.get(checkpoint_ip).value_or(current_position) != current_position) {
+        auto form = this->form();
+
+        if (form == OpCodeId::Jump) {
+            state.instruction_position += offset();
+            return ExecutionResult::Continue;
+        }
+
+        state.fork_at_position = state.instruction_position + size() + offset();
+
+        if (form == OpCodeId::ForkJump)
+            return ExecutionResult::Fork_PrioHigh;
+
+        if (form == OpCodeId::ForkStay)
+            return ExecutionResult::Fork_PrioLow;
+    }
+
+    return ExecutionResult::Continue;
+}
+
 }
 }

+ 61 - 15
Userland/Libraries/LibRegex/RegexByteCode.h

@@ -27,6 +27,7 @@ using ByteCodeValueType = u64;
 #define ENUMERATE_OPCODES                          \
 #define ENUMERATE_OPCODES                          \
     __ENUMERATE_OPCODE(Compare)                    \
     __ENUMERATE_OPCODE(Compare)                    \
     __ENUMERATE_OPCODE(Jump)                       \
     __ENUMERATE_OPCODE(Jump)                       \
+    __ENUMERATE_OPCODE(JumpNonEmpty)               \
     __ENUMERATE_OPCODE(ForkJump)                   \
     __ENUMERATE_OPCODE(ForkJump)                   \
     __ENUMERATE_OPCODE(ForkStay)                   \
     __ENUMERATE_OPCODE(ForkStay)                   \
     __ENUMERATE_OPCODE(FailForks)                  \
     __ENUMERATE_OPCODE(FailForks)                  \
@@ -42,6 +43,7 @@ using ByteCodeValueType = u64;
     __ENUMERATE_OPCODE(ClearCaptureGroup)          \
     __ENUMERATE_OPCODE(ClearCaptureGroup)          \
     __ENUMERATE_OPCODE(Repeat)                     \
     __ENUMERATE_OPCODE(Repeat)                     \
     __ENUMERATE_OPCODE(ResetRepeat)                \
     __ENUMERATE_OPCODE(ResetRepeat)                \
+    __ENUMERATE_OPCODE(Checkpoint)                 \
     __ENUMERATE_OPCODE(Exit)
     __ENUMERATE_OPCODE(Exit)
 
 
 // clang-format off
 // clang-format off
@@ -319,16 +321,14 @@ public:
         empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
         empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
         empend(right.size() + 2); // Jump to the _ALT label
         empend(right.size() + 2); // Jump to the _ALT label
 
 
-        for (auto& op : right)
-            append(move(op));
+        extend(right);
 
 
         empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
         empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
         empend(left.size()); // Jump to the _END label
         empend(left.size()); // Jump to the _END label
 
 
         // LABEL _ALT = bytecode.size() + 2
         // LABEL _ALT = bytecode.size() + 2
 
 
-        for (auto& op : left)
-            append(move(op));
+        extend(left);
 
 
         // LABEL _END = alterantive_bytecode.size
         // LABEL _END = alterantive_bytecode.size
     }
     }
@@ -376,10 +376,21 @@ public:
                 new_bytecode[pre_loop_fork_jump_index - 1] = (ByteCodeValueType)(fork_jump_address - pre_loop_fork_jump_index);
                 new_bytecode[pre_loop_fork_jump_index - 1] = (ByteCodeValueType)(fork_jump_address - pre_loop_fork_jump_index);
             }
             }
         } else {
         } else {
-            // no maximum value set, repeat finding if possible
+            // no maximum value set, repeat finding if possible:
+            // (REPEAT REGEXP MIN)
+            // LABEL _START
+            // CHECKPOINT _C
+            // REGEXP
+            // JUMP_NONEMPTY _C _START FORK
+
+            // Note: This is only safe because REPEAT will leave one iteration outside (see repetition_n)
+            new_bytecode.insert(new_bytecode.size() - bytecode_to_repeat.size(), (ByteCodeValueType)OpCodeId::Checkpoint);
+
             auto jump_kind = static_cast<ByteCodeValueType>(greedy ? OpCodeId::ForkJump : OpCodeId::ForkStay);
             auto jump_kind = static_cast<ByteCodeValueType>(greedy ? OpCodeId::ForkJump : OpCodeId::ForkStay);
+            new_bytecode.empend((ByteCodeValueType)OpCodeId::JumpNonEmpty);
+            new_bytecode.empend(-bytecode_to_repeat.size() - 4 - 1); // Jump to the last iteration
+            new_bytecode.empend(-bytecode_to_repeat.size() - 4 - 1); // if _C is not empty.
             new_bytecode.empend(jump_kind);
             new_bytecode.empend(jump_kind);
-            new_bytecode.empend(-bytecode_to_repeat.size() - 2); // Jump to the last iteration
         }
         }
 
 
         bytecode_to_repeat = move(new_bytecode);
         bytecode_to_repeat = move(new_bytecode);
@@ -412,23 +423,29 @@ public:
     static void transform_bytecode_repetition_min_one(ByteCode& bytecode_to_repeat, bool greedy)
     static void transform_bytecode_repetition_min_one(ByteCode& bytecode_to_repeat, bool greedy)
     {
     {
         // LABEL _START = -bytecode_to_repeat.size()
         // LABEL _START = -bytecode_to_repeat.size()
+        // CHECKPOINT _C
         // REGEXP
         // REGEXP
-        // FORKSTAY _START  (FORKJUMP -> Greedy)
+        // JUMP_NONEMPTY _C _START FORKSTAY (FORKJUMP -> Greedy)
+
+        bytecode_to_repeat.prepend((ByteCodeValueType)OpCodeId::Checkpoint);
+
+        bytecode_to_repeat.empend((ByteCodeValueType)OpCodeId::JumpNonEmpty);
+        bytecode_to_repeat.empend(-bytecode_to_repeat.size() - 3); // Jump to the _START label...
+        bytecode_to_repeat.empend(-bytecode_to_repeat.size() - 2); // ...if _C is not empty
 
 
         if (greedy)
         if (greedy)
             bytecode_to_repeat.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
             bytecode_to_repeat.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
         else
         else
             bytecode_to_repeat.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkStay));
             bytecode_to_repeat.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkStay));
-
-        bytecode_to_repeat.empend(-(bytecode_to_repeat.size() + 1)); // Jump to the _START label
     }
     }
 
 
     static void transform_bytecode_repetition_any(ByteCode& bytecode_to_repeat, bool greedy)
     static void transform_bytecode_repetition_any(ByteCode& bytecode_to_repeat, bool greedy)
     {
     {
         // LABEL _START
         // LABEL _START
         // FORKJUMP _END  (FORKSTAY -> Greedy)
         // FORKJUMP _END  (FORKSTAY -> Greedy)
+        // CHECKPOINT _C
         // REGEXP
         // REGEXP
-        // JUMP  _START
+        // JUMP_NONEMPTY _C _START JUMP
         // LABEL _END
         // LABEL _END
 
 
         // LABEL _START = m_bytes.size();
         // LABEL _START = m_bytes.size();
@@ -439,13 +456,17 @@ public:
         else
         else
             bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
             bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
 
 
-        bytecode.empend(bytecode_to_repeat.size() + 2); // Jump to the _END label
+        bytecode.empend(bytecode_to_repeat.size() + 1 + 4); // Jump to the _END label
 
 
-        for (auto& op : bytecode_to_repeat)
-            bytecode.append(move(op));
+        auto c_label = bytecode.size();
+        bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::Checkpoint));
+
+        bytecode.extend(bytecode_to_repeat);
 
 
-        bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
-        bytecode.empend(-bytecode.size() - 1); // Jump to the _START label
+        bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::JumpNonEmpty));
+        bytecode.empend(-bytecode.size() - 3);          // Jump(...) to the _START label...
+        bytecode.empend(c_label - bytecode.size() - 2); // ...only if _C passes.
+        bytecode.empend((ByteCodeValueType)OpCodeId::Jump);
         // LABEL _END = bytecode.size()
         // LABEL _END = bytecode.size()
 
 
         bytecode_to_repeat = move(bytecode);
         bytecode_to_repeat = move(bytecode);
@@ -744,6 +765,31 @@ public:
     }
     }
 };
 };
 
 
+class OpCode_Checkpoint final : public OpCode {
+public:
+    ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
+    ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::Checkpoint; }
+    ALWAYS_INLINE size_t size() const override { return 1; }
+    String const arguments_string() const override { return ""; }
+};
+
+class OpCode_JumpNonEmpty final : public OpCode {
+public:
+    ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
+    ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::JumpNonEmpty; }
+    ALWAYS_INLINE size_t size() const override { return 4; }
+    ALWAYS_INLINE ssize_t offset() const { return argument(0); }
+    ALWAYS_INLINE ssize_t checkpoint() const { return argument(1); }
+    ALWAYS_INLINE OpCodeId form() const { return (OpCodeId)argument(2); }
+    String const arguments_string() const override
+    {
+        return String::formatted("{} offset={} [&{}], cp={} [&{}]",
+            opcode_id_name(form()),
+            offset(), state().instruction_position + size() + offset(),
+            checkpoint(), state().instruction_position + size() + checkpoint());
+    }
+};
+
 template<typename T>
 template<typename T>
 bool is(OpCode const&);
 bool is(OpCode const&);
 
 

+ 1 - 0
Userland/Libraries/LibRegex/RegexMatch.h

@@ -524,6 +524,7 @@ struct MatchState {
     Vector<Match> matches;
     Vector<Match> matches;
     Vector<Vector<Match>> capture_group_matches;
     Vector<Vector<Match>> capture_group_matches;
     Vector<u64> repetition_marks;
     Vector<u64> repetition_marks;
+    HashMap<u64, u64> checkpoints;
 };
 };
 
 
 }
 }