/*
 * Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org>
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */

#include <AK/QuickSort.h>
#include <AK/RedBlackTree.h>
#include <AK/Stack.h>
#include <LibRegex/Regex.h>
#include <LibRegex/RegexBytecodeStreamOptimizer.h>

namespace regex {

using Detail::Block;

template<typename Parser>
void Regex<Parser>::run_optimization_passes()
{
    // Rewrite fork loops as atomic groups
    // e.g. a*b -> (ATOMIC a*)b
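    // (An atomic group never gives back what it has already consumed: once a*
    // has matched greedily, a later failure won't retry shorter runs of 'a'.
    // This rewrite is only safe when whatever follows the loop can never match
    // a character the loop body matched; block_satisfies_atomic_rewrite_precondition()
    // below checks exactly that.)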
    attempt_rewrite_loops_as_atomic_groups(split_basic_blocks());
    parser_result.bytecode.flatten();
}
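
// Splits the bytecode into (roughly) basic blocks: spans of instructions that
// run straight through, with control flow (Jump/JumpNonEmpty/Fork*/FailForks/Repeat)
// only entering at a block's start and leaving at its end. A backwards jump into
// the middle of an already-delimited block splits that block in two.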
template<typename Parser>
typename Regex<Parser>::BasicBlockList Regex<Parser>::split_basic_blocks()
{
    BasicBlockList block_boundaries;
    auto& bytecode = parser_result.bytecode;
    size_t end_of_last_block = 0;

    MatchState state;
    state.instruction_position = 0;

    auto check_jump = [&]<typename T>(OpCode const& opcode) {
        auto& op = static_cast<T const&>(opcode);
        ssize_t jump_offset = op.size() + op.offset();
        if (jump_offset >= 0) {
            block_boundaries.append({ end_of_last_block, state.instruction_position });
            end_of_last_block = state.instruction_position + opcode.size();
        } else {
            // This op jumps back, see if that's within this "block".
            if (jump_offset + state.instruction_position > end_of_last_block) {
                // Split the block!
                block_boundaries.append({ end_of_last_block, jump_offset + state.instruction_position });
                block_boundaries.append({ jump_offset + state.instruction_position, state.instruction_position });
                end_of_last_block = state.instruction_position + opcode.size();
            } else {
                // Nope, it's just a jump to another block
                block_boundaries.append({ end_of_last_block, state.instruction_position });
                end_of_last_block = state.instruction_position + opcode.size();
            }
        }
    };

    for (;;) {
        auto& opcode = bytecode.get_opcode(state);

        switch (opcode.opcode_id()) {
        case OpCodeId::Jump:
            check_jump.template operator()<OpCode_Jump>(opcode);
            break;
        case OpCodeId::JumpNonEmpty:
            check_jump.template operator()<OpCode_JumpNonEmpty>(opcode);
            break;
        case OpCodeId::ForkJump:
            check_jump.template operator()<OpCode_ForkJump>(opcode);
            break;
        case OpCodeId::ForkStay:
            check_jump.template operator()<OpCode_ForkStay>(opcode);
            break;
        case OpCodeId::FailForks:
            block_boundaries.append({ end_of_last_block, state.instruction_position });
            end_of_last_block = state.instruction_position + opcode.size();
            break;
        case OpCodeId::Repeat: {
            // Repeat produces two blocks, one containing its repeated expr, and one after that.
            auto repeat_start = state.instruction_position - static_cast<OpCode_Repeat const&>(opcode).offset();
            if (repeat_start > end_of_last_block)
                block_boundaries.append({ end_of_last_block, repeat_start });
            block_boundaries.append({ repeat_start, state.instruction_position });
            end_of_last_block = state.instruction_position + opcode.size();
            break;
        }
        default:
            break;
        }

        auto next_ip = state.instruction_position + opcode.size();
        if (next_ip < bytecode.size())
            state.instruction_position = next_ip;
        else
            break;
    }

    if (end_of_last_block < bytecode.size())
        block_boundaries.append({ end_of_last_block, bytecode.size() });

    quick_sort(block_boundaries, [](auto& a, auto& b) { return a.start < b.start; });

    return block_boundaries;
}
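
// Check the atomic-group rewrite precondition for a repeated block and the
// block that follows it: first(RE1) must not be in end(RE0), i.e. nothing the
// following block can start with may also be matched by the repeated block.
// As an illustrative example (assuming the usual compilation of these
// patterns): /a+b/ passes, since 'b' can never match where another 'a' would
// have, while /a+ab/ fails, since the following block also starts by matching 'a'.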
static bool block_satisfies_atomic_rewrite_precondition(ByteCode const& bytecode, Block const& repeated_block, Block const& following_block)
{
    Vector<Vector<CompareTypeAndValuePair>> repeated_values;
    HashTable<size_t> active_capture_groups;
    MatchState state;
    for (state.instruction_position = repeated_block.start; state.instruction_position < repeated_block.end;) {
        auto& opcode = bytecode.get_opcode(state);
        switch (opcode.opcode_id()) {
        case OpCodeId::Compare: {
            auto compares = static_cast<OpCode_Compare const&>(opcode).flat_compares();
            if (repeated_values.is_empty() && any_of(compares, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; }))
                return false;
            repeated_values.append(move(compares));
            break;
        }
        case OpCodeId::CheckBegin:
        case OpCodeId::CheckEnd:
            if (repeated_values.is_empty())
                return true;
            break;
        case OpCodeId::CheckBoundary:
            // FIXME: What should we do with these? For now, let's fail.
            return false;
        case OpCodeId::Restore:
        case OpCodeId::GoBack:
            return false;
        case OpCodeId::SaveRightCaptureGroup:
            active_capture_groups.set(static_cast<OpCode_SaveRightCaptureGroup const&>(opcode).id());
            break;
        case OpCodeId::SaveLeftCaptureGroup:
            active_capture_groups.set(static_cast<OpCode_SaveLeftCaptureGroup const&>(opcode).id());
            break;
        default:
            break;
        }

        state.instruction_position += opcode.size();
    }

    dbgln_if(REGEX_DEBUG, "Found {} entries in reference", repeated_values.size());
    dbgln_if(REGEX_DEBUG, "Found {} active capture groups", active_capture_groups.size());

    // Find the first compare in the following block; it must NOT match any of the values in `repeated_values'.
    for (state.instruction_position = following_block.start; state.instruction_position < following_block.end;) {
        auto& opcode = bytecode.get_opcode(state);
        switch (opcode.opcode_id()) {
        // Note: These have to exist since we're effectively repeating the following block as well
        case OpCodeId::SaveRightCaptureGroup:
            active_capture_groups.set(static_cast<OpCode_SaveRightCaptureGroup const&>(opcode).id());
            break;
        case OpCodeId::SaveLeftCaptureGroup:
            active_capture_groups.set(static_cast<OpCode_SaveLeftCaptureGroup const&>(opcode).id());
            break;
        case OpCodeId::Compare: {
            // We found a compare, let's see what it has.
            auto compares = static_cast<OpCode_Compare const&>(opcode).flat_compares();
            if (compares.is_empty())
                break;

            if (any_of(compares, [&](auto& compare) {
                    return compare.type == CharacterCompareType::AnyChar
                        || (compare.type == CharacterCompareType::Reference && active_capture_groups.contains(compare.value));
                }))
                return false;

            for (auto& repeated_value : repeated_values) {
                // FIXME: This is too naive!
                if (any_of(repeated_value, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; }))
                    return false;

                for (auto& repeated_compare : repeated_value) {
                    // FIXME: This is too naive! It will miss _tons_ of cases since it doesn't check ranges!
                    if (any_of(compares, [&](auto& compare) { return compare.type == repeated_compare.type && compare.value == repeated_compare.value; }))
                        return false;
                }
            }

            return true;
        }
        case OpCodeId::CheckBegin:
        case OpCodeId::CheckEnd:
            return true; // Nothing can match the end!
        case OpCodeId::CheckBoundary:
            // FIXME: What should we do with these? For now, consider them a failure.
            return false;
        default:
            break;
        }

        state.instruction_position += opcode.size();
    }

    return true;
}

template<typename Parser>
void Regex<Parser>::attempt_rewrite_loops_as_atomic_groups(BasicBlockList const& basic_blocks)
{
    auto& bytecode = parser_result.bytecode;
    if constexpr (REGEX_DEBUG) {
        RegexDebug dbg;
        dbg.print_bytecode(*this);
        for (auto& block : basic_blocks)
            dbgln("block from {} to {}", block.start, block.end);
    }

    // A pattern such as:
    //     bb0 | RE0
    //         | ForkX bb0
    //     -------------------------
    //     bb1 | RE1
    // can be rewritten as:
    //     loop.hdr | ForkStay bb1
    //     -------------------------
    //     bb0 | RE0
    //         | ForkReplaceX bb0
    //     -------------------------
    //     bb1 | RE1
    // provided that first(RE1) not-in end(RE0), which is to say
    // that RE1 cannot start with whatever RE0 has matched (ever).
    //
    // Alternatively, a second form of this pattern can also occur:
    //     bb0 | *
    //         | ForkX bb2
    //     ------------------------
    //     bb1 | RE0
    //         | Jump bb0
    //     ------------------------
    //     bb2 | RE1
    // which can be transformed (with the same preconditions) to:
    //     bb0 | *
    //         | ForkReplaceX bb2
    //     ------------------------
    //     bb1 | RE0
    //         | Jump bb0
    //     ------------------------
    //     bb2 | RE1
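    //
    // As a concrete example, /a+b/ roughly takes the first form:
    //     bb0 | Compare 'a'
    //         | ForkX bb0
    //     -------------------------
    //     bb1 | Compare 'b'
    // Since first(RE1) = {'b'} shares nothing with end(RE0) = {'a'}, dropping
    // the fork (i.e. never retrying a shorter run of 'a') can't change whether
    // the overall match succeeds.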

    enum class AlternateForm {
        DirectLoopWithoutHeader, // A loop without a proper header: a block forking to itself, i.e. the first form.
        DirectLoopWithHeader,    // A loop with a proper header, i.e. the second form.
    };
    struct CandidateBlock {
        Block forking_block;
        Optional<Block> new_target_block;
        AlternateForm form;
    };
    Vector<CandidateBlock> candidate_blocks;
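
    // Decides whether `opcode`, sitting at `ip`, is a jump or fork whose target
    // is `block_start`, in the shape required by the given alternate form.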
    auto is_an_eligible_jump = [](OpCode const& opcode, size_t ip, size_t block_start, AlternateForm alternate_form) {
        switch (opcode.opcode_id()) {
        case OpCodeId::JumpNonEmpty: {
            auto& op = static_cast<OpCode_JumpNonEmpty const&>(opcode);
            auto form = op.form();
            if (form != OpCodeId::Jump && alternate_form == AlternateForm::DirectLoopWithHeader)
                return false;
            if (form != OpCodeId::ForkJump && form != OpCodeId::ForkStay && alternate_form == AlternateForm::DirectLoopWithoutHeader)
                return false;
            return op.offset() + ip + opcode.size() == block_start;
        }
        case OpCodeId::ForkJump:
            if (alternate_form == AlternateForm::DirectLoopWithHeader)
                return false;
            return static_cast<OpCode_ForkJump const&>(opcode).offset() + ip + opcode.size() == block_start;
        case OpCodeId::ForkStay:
            if (alternate_form == AlternateForm::DirectLoopWithHeader)
                return false;
            return static_cast<OpCode_ForkStay const&>(opcode).offset() + ip + opcode.size() == block_start;
        case OpCodeId::Jump:
            // Infinite loop does *not* produce forks.
            if (alternate_form == AlternateForm::DirectLoopWithoutHeader)
                return false;
            if (alternate_form == AlternateForm::DirectLoopWithHeader)
                return static_cast<OpCode_Jump const&>(opcode).offset() + ip + opcode.size() == block_start;
            VERIFY_NOT_REACHED();
        default:
            return false;
        }
    };

    for (size_t i = 0; i < basic_blocks.size(); ++i) {
        auto forking_block = basic_blocks[i];
        Optional<Block> fork_fallback_block;
        if (i + 1 < basic_blocks.size())
            fork_fallback_block = basic_blocks[i + 1];
        MatchState state;
        // Check if the last instruction in this block is a jump to the block itself:
        {
            state.instruction_position = forking_block.end;
            auto& opcode = bytecode.get_opcode(state);
            if (is_an_eligible_jump(opcode, state.instruction_position, forking_block.start, AlternateForm::DirectLoopWithoutHeader)) {
                // We've found RE0 (and RE1 is just the following block, if any), let's see if the precondition applies.
                // If RE1 is empty, there's no first(RE1), so this is an automatic pass.
                if (!fork_fallback_block.has_value() || fork_fallback_block->end == fork_fallback_block->start) {
                    candidate_blocks.append({ forking_block, fork_fallback_block, AlternateForm::DirectLoopWithoutHeader });
                    break;
                }

                if (block_satisfies_atomic_rewrite_precondition(bytecode, forking_block, *fork_fallback_block)) {
                    candidate_blocks.append({ forking_block, fork_fallback_block, AlternateForm::DirectLoopWithoutHeader });
                    break;
                }
            }
        }
        // Check if the last instruction in the following block is a direct jump back to this block:
        if (fork_fallback_block.has_value()) {
            state.instruction_position = fork_fallback_block->end;
            auto& opcode = bytecode.get_opcode(state);
            if (is_an_eligible_jump(opcode, state.instruction_position, forking_block.start, AlternateForm::DirectLoopWithHeader)) {
                // We've found bb1 and bb0, let's just make sure that bb0 forks to bb2.
                state.instruction_position = forking_block.end;
                auto& opcode = bytecode.get_opcode(state);
                if (opcode.opcode_id() == OpCodeId::ForkJump || opcode.opcode_id() == OpCodeId::ForkStay) {
                    Optional<Block> block_following_fork_fallback;
                    if (i + 2 < basic_blocks.size())
                        block_following_fork_fallback = basic_blocks[i + 2];
                    if (!block_following_fork_fallback.has_value() || block_satisfies_atomic_rewrite_precondition(bytecode, *fork_fallback_block, *block_following_fork_fallback)) {
                        candidate_blocks.append({ forking_block, {}, AlternateForm::DirectLoopWithHeader });
                        break;
                    }
                }
            }
        }
    }

    dbgln_if(REGEX_DEBUG, "Found {} candidate blocks", candidate_blocks.size());
    if (candidate_blocks.is_empty()) {
        dbgln_if(REGEX_DEBUG, "Failed to find anything for {}", pattern_value);
        return;
    }

    RedBlackTree<size_t, size_t> needed_patches;

    // Order the candidate blocks, so we can patch the bytecode without messing with the later patches.
    quick_sort(candidate_blocks, [](auto& a, auto& b) { return b.forking_block.start > a.forking_block.start; });
    for (auto& candidate : candidate_blocks) {
        // Note that both forms share a ForkReplace patch in forking_block.
        // Patch the ForkX in forking_block to be a ForkReplaceX instead.
        auto& opcode_id = bytecode[candidate.forking_block.end];
        if (opcode_id == (ByteCodeValueType)OpCodeId::ForkStay) {
            opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceStay;
        } else if (opcode_id == (ByteCodeValueType)OpCodeId::ForkJump) {
            opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceJump;
        } else if (opcode_id == (ByteCodeValueType)OpCodeId::JumpNonEmpty) {
            auto& jump_opcode_id = bytecode[candidate.forking_block.end + 3];
            if (jump_opcode_id == (ByteCodeValueType)OpCodeId::ForkStay)
                jump_opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceStay;
            else if (jump_opcode_id == (ByteCodeValueType)OpCodeId::ForkJump)
                jump_opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceJump;
            else
                VERIFY_NOT_REACHED();
        } else {
            VERIFY_NOT_REACHED();
        }

        if (candidate.form == AlternateForm::DirectLoopWithoutHeader) {
            if (candidate.new_target_block.has_value()) {
                // Insert a fork-stay targeted at the second block.
                bytecode.insert(candidate.forking_block.start, (ByteCodeValueType)OpCodeId::ForkStay);
                bytecode.insert(candidate.forking_block.start + 1, candidate.new_target_block->start - candidate.forking_block.start);
                needed_patches.insert(candidate.forking_block.start, 2u);
            }
        }
    }
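
    // The insertions above shifted everything at or after each insertion point,
    // so any jump whose source and target now straddle an inserted region needs
    // its offset adjusted by the inserted size; `needed_patches` maps insertion
    // position -> number of inserted words.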
    if (!needed_patches.is_empty()) {
        MatchState state;
        state.instruction_position = 0;
        struct Patch {
            ssize_t value;
            size_t offset;
            bool should_negate { false };
        };
        for (;;) {
            if (state.instruction_position >= bytecode.size())
                break;

            auto& opcode = bytecode.get_opcode(state);
            Stack<Patch, 2> patch_points;

            switch (opcode.opcode_id()) {
            case OpCodeId::Jump:
                patch_points.push({ static_cast<OpCode_Jump const&>(opcode).offset(), state.instruction_position + 1 });
                break;
            case OpCodeId::JumpNonEmpty:
                patch_points.push({ static_cast<OpCode_JumpNonEmpty const&>(opcode).offset(), state.instruction_position + 1 });
                patch_points.push({ static_cast<OpCode_JumpNonEmpty const&>(opcode).checkpoint(), state.instruction_position + 2 });
                break;
            case OpCodeId::ForkJump:
                patch_points.push({ static_cast<OpCode_ForkJump const&>(opcode).offset(), state.instruction_position + 1 });
                break;
            case OpCodeId::ForkStay:
                patch_points.push({ static_cast<OpCode_ForkStay const&>(opcode).offset(), state.instruction_position + 1 });
                break;
            case OpCodeId::Repeat:
                patch_points.push({ -(ssize_t) static_cast<OpCode_Repeat const&>(opcode).offset(), state.instruction_position + 1, true });
                break;
            default:
                break;
            }

            while (!patch_points.is_empty()) {
                auto& patch_point = patch_points.top();
                auto target_offset = patch_point.value + state.instruction_position + opcode.size();

                constexpr auto do_patch = [](auto& patch_it, auto& patch_point, auto& target_offset, auto& bytecode, auto ip) {
                    if (patch_it.key() == ip)
                        return;

                    if (patch_point.value < 0 && target_offset < patch_it.key() && ip > patch_it.key())
                        bytecode[patch_point.offset] += (patch_point.should_negate ? 1 : -1) * (*patch_it);
                    else if (patch_point.value > 0 && target_offset > patch_it.key() && ip < patch_it.key())
                        bytecode[patch_point.offset] += (patch_point.should_negate ? -1 : 1) * (*patch_it);
                };

                if (auto patch_it = needed_patches.find_largest_not_above_iterator(target_offset); !patch_it.is_end())
                    do_patch(patch_it, patch_point, target_offset, bytecode, state.instruction_position);
                else if (auto patch_it = needed_patches.find_largest_not_above_iterator(state.instruction_position); !patch_it.is_end())
                    do_patch(patch_it, patch_point, target_offset, bytecode, state.instruction_position);

                patch_points.pop();
            }

            state.instruction_position += opcode.size();
        }
    }

    if constexpr (REGEX_DEBUG) {
        warnln("Transformed to:");
        RegexDebug dbg;
        dbg.print_bytecode(*this);
    }
}
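
// Emits bytecode matching `left | right` into `target`, hoisting any common
// prefix the two alternatives share, then lowering the rest to the usual fork
// pattern:
//     ForkJump _ALT
//     <right>
//     Jump _END      (omitted when <left> is empty)
// _ALT:
//     <left>
// _END: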
void Optimizer::append_alternation(ByteCode& target, ByteCode&& left, ByteCode&& right)
{
    if (left.is_empty()) {
        target.extend(right);
        return;
    }

    if (right.is_empty()) {
        target.extend(left);
        return;
    }

    size_t left_skip = 0;
    MatchState state;
    for (state.instruction_position = 0; state.instruction_position < left.size() && state.instruction_position < right.size();) {
        auto left_size = left.get_opcode(state).size();
        auto right_size = right.get_opcode(state).size();
        if (left_size != right_size)
            break;

        if (left.spans().slice(state.instruction_position, left_size) == right.spans().slice(state.instruction_position, right_size))
            left_skip = state.instruction_position + left_size;
        else
            break;

        state.instruction_position += left_size;
    }

    dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}/{}", left_skip, 0, left.size(), right.size());

    if (left_skip) {
        target.extend(left.release_slice(0, left_skip));
        right = right.release_slice(left_skip);
    }

    auto left_size = left.size();

    target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
    target.empend(right.size() + (left_size ? 2 : 0)); // Jump to the _ALT label

    target.extend(move(right));

    if (left_size != 0) {
        target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
        target.empend(left.size()); // Jump to the _END label
    }

    // LABEL _ALT = bytecode.size() + 2
    target.extend(move(left));
    // LABEL _END = alternative_bytecode.size
}

template void Regex<PosixBasicParser>::run_optimization_passes();
template void Regex<PosixExtendedParser>::run_optimization_passes();
template void Regex<ECMA262Parser>::run_optimization_passes();

}