LibRegex: Pull out the first compare to avoid unnecessary execution

This adds a fast-path to drop view indices we know will not match
immediately without going through the regex VM.
This commit is contained in:
Ali Mohammad Pur 2025-04-15 21:32:26 +02:00 committed by Andreas Kling
parent 76f5dce3db
commit 446a453719
Notes: github-actions[bot] 2025-04-18 15:10:31 +00:00
4 changed files with 250 additions and 117 deletions

View file

@ -4,6 +4,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/BinarySearch.h>
#include <AK/BumpAllocator.h>
#include <AK/ByteString.h>
#include <AK/Debug.h>
@ -208,6 +209,23 @@ RegexResult Matcher<Parser>::match(Vector<RegexStringView> const& views, Optiona
auto single_match_only = input.regex_options.has_flag_set(AllFlags::SingleMatch);
auto only_start_of_line = m_pattern->parser_result.optimization_data.only_start_of_line && !input.regex_options.has_flag_set(AllFlags::Multiline);
auto compare_range = [insensitive = input.regex_options & AllFlags::Insensitive](auto needle, CharRange range) {
auto upper_case_needle = needle;
auto lower_case_needle = needle;
if (insensitive) {
upper_case_needle = to_ascii_uppercase(needle);
lower_case_needle = to_ascii_lowercase(needle);
}
if (lower_case_needle >= range.from && lower_case_needle <= range.to)
return 0;
if (upper_case_needle >= range.from && upper_case_needle <= range.to)
return 0;
if (lower_case_needle > range.to || upper_case_needle > range.to)
return 1;
return -1;
};
for (auto const& view : views) {
if (lines_to_skip != 0) {
++input.line;
@ -253,19 +271,26 @@ RegexResult Matcher<Parser>::match(Vector<RegexStringView> const& views, Optiona
}
for (; view_index <= view_length; ++view_index) {
if (view_index == view_length && input.regex_options.has_flag_set(AllFlags::Multiline))
break;
if (view_index == view_length) {
if (input.regex_options.has_flag_set(AllFlags::Multiline))
break;
}
auto& match_length_minimum = m_pattern->parser_result.match_length_minimum;
// FIXME: More performant would be to know the remaining minimum string
// length needed to match from the current position onwards within
// the vm. Add new OpCode for MinMatchLengthFromSp with the value of
// the remaining string length from the current path. The value though
// has to be filled in reverse. That implies a second run over bytecode
// after generation has finished.
auto const match_length_minimum = m_pattern->parser_result.match_length_minimum;
if (match_length_minimum && match_length_minimum > view_length - view_index)
break;
if (auto& starting_ranges = m_pattern->parser_result.optimization_data.starting_ranges; !starting_ranges.is_empty()) {
if (!binary_search(starting_ranges, input.view.code_unit_at(view_index), nullptr, compare_range))
goto done_matching;
}
input.column = match_count;
input.match_index = match_count;
@ -274,8 +299,7 @@ RegexResult Matcher<Parser>::match(Vector<RegexStringView> const& views, Optiona
state.instruction_position = 0;
state.repetition_marks.clear();
auto success = execute(input, state, operations);
if (success) {
if (execute(input, state, operations)) {
succeeded = true;
if (input.regex_options.has_flag_set(AllFlags::MatchNotEndOfLine) && state.string_position == input.view.length()) {
@ -315,6 +339,7 @@ RegexResult Matcher<Parser>::match(Vector<RegexStringView> const& views, Optiona
break;
}
done_matching:
if (!continue_search || only_start_of_line)
break;
}

View file

@ -230,6 +230,7 @@ private:
void run_optimization_passes();
void attempt_rewrite_loops_as_atomic_groups(BasicBlockList const&);
bool attempt_rewrite_entire_match_as_substring_search(BasicBlockList const&);
void fill_optimization_data(BasicBlockList const&);
};
// free standing functions for match, search and has_match

View file

@ -36,15 +36,211 @@ void Regex<Parser>::run_optimization_passes()
// e.g. a*b -> (ATOMIC a*)b
attempt_rewrite_loops_as_atomic_groups(blocks);
// FIXME: "There are a few more conditions this can be true in (e.g. within an arbitrarily nested capture group)"
auto state = MatchState::only_for_enumeration();
auto& opcode = parser_result.bytecode.get_opcode(state);
if (opcode.opcode_id() == OpCodeId::CheckBegin)
parser_result.optimization_data.only_start_of_line = true;
fill_optimization_data(split_basic_blocks(parser_result.bytecode));
parser_result.bytecode.flatten();
}
struct StaticallyInterpretedCompares {
RedBlackTree<u32, u32> lhs_ranges;
RedBlackTree<u32, u32> lhs_negated_ranges;
HashTable<CharClass> lhs_char_classes;
HashTable<CharClass> lhs_negated_char_classes;
bool has_any_unicode_property = false;
HashTable<Unicode::GeneralCategory> lhs_unicode_general_categories;
HashTable<Unicode::Property> lhs_unicode_properties;
HashTable<Unicode::Script> lhs_unicode_scripts;
HashTable<Unicode::Script> lhs_unicode_script_extensions;
HashTable<Unicode::GeneralCategory> lhs_negated_unicode_general_categories;
HashTable<Unicode::Property> lhs_negated_unicode_properties;
HashTable<Unicode::Script> lhs_negated_unicode_scripts;
HashTable<Unicode::Script> lhs_negated_unicode_script_extensions;
};
static bool interpret_compares(Vector<CompareTypeAndValuePair> const& lhs, StaticallyInterpretedCompares& compares)
{
bool inverse { false };
bool temporary_inverse { false };
bool reset_temporary_inverse { false };
auto current_lhs_inversion_state = [&]() -> bool { return temporary_inverse ^ inverse; };
auto& lhs_ranges = compares.lhs_ranges;
auto& lhs_negated_ranges = compares.lhs_negated_ranges;
auto& lhs_char_classes = compares.lhs_char_classes;
auto& lhs_negated_char_classes = compares.lhs_negated_char_classes;
auto& has_any_unicode_property = compares.has_any_unicode_property;
auto& lhs_unicode_general_categories = compares.lhs_unicode_general_categories;
auto& lhs_unicode_properties = compares.lhs_unicode_properties;
auto& lhs_unicode_scripts = compares.lhs_unicode_scripts;
auto& lhs_unicode_script_extensions = compares.lhs_unicode_script_extensions;
auto& lhs_negated_unicode_general_categories = compares.lhs_negated_unicode_general_categories;
auto& lhs_negated_unicode_properties = compares.lhs_negated_unicode_properties;
auto& lhs_negated_unicode_scripts = compares.lhs_negated_unicode_scripts;
auto& lhs_negated_unicode_script_extensions = compares.lhs_negated_unicode_script_extensions;
for (auto const& pair : lhs) {
if (reset_temporary_inverse) {
reset_temporary_inverse = false;
temporary_inverse = false;
} else {
reset_temporary_inverse = true;
}
switch (pair.type) {
case CharacterCompareType::Inverse:
inverse = !inverse;
break;
case CharacterCompareType::TemporaryInverse:
temporary_inverse = true;
reset_temporary_inverse = false;
break;
case CharacterCompareType::AnyChar:
// Special case: if not inverted, AnyChar is always in the range.
if (!current_lhs_inversion_state())
return false;
break;
case CharacterCompareType::Char:
if (!current_lhs_inversion_state())
lhs_ranges.insert(pair.value, pair.value);
else
lhs_negated_ranges.insert(pair.value, pair.value);
break;
case CharacterCompareType::String:
// FIXME: We just need to look at the last character of this string, but we only have the first character here.
// Just bail out to avoid false positives.
return false;
case CharacterCompareType::CharClass:
if (!current_lhs_inversion_state())
lhs_char_classes.set(static_cast<CharClass>(pair.value));
else
lhs_negated_char_classes.set(static_cast<CharClass>(pair.value));
break;
case CharacterCompareType::CharRange: {
auto range = CharRange(pair.value);
if (!current_lhs_inversion_state())
lhs_ranges.insert(range.from, range.to);
else
lhs_negated_ranges.insert(range.from, range.to);
break;
}
case CharacterCompareType::LookupTable:
// We've transformed this into a series of ranges in flat_compares(), so bail out if we see it.
return false;
case CharacterCompareType::Reference:
// We've handled this before coming here.
break;
case CharacterCompareType::Property:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_properties.set(static_cast<Unicode::Property>(pair.value));
else
lhs_negated_unicode_properties.set(static_cast<Unicode::Property>(pair.value));
break;
case CharacterCompareType::GeneralCategory:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_general_categories.set(static_cast<Unicode::GeneralCategory>(pair.value));
else
lhs_negated_unicode_general_categories.set(static_cast<Unicode::GeneralCategory>(pair.value));
break;
case CharacterCompareType::Script:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_scripts.set(static_cast<Unicode::Script>(pair.value));
else
lhs_negated_unicode_scripts.set(static_cast<Unicode::Script>(pair.value));
break;
case CharacterCompareType::ScriptExtension:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_script_extensions.set(static_cast<Unicode::Script>(pair.value));
else
lhs_negated_unicode_script_extensions.set(static_cast<Unicode::Script>(pair.value));
break;
case CharacterCompareType::Or:
case CharacterCompareType::EndAndOr:
// These are the default behaviour for [...], so we don't need to do anything (unless we add support for 'And' below).
break;
case CharacterCompareType::And:
// FIXME: These are too difficult to handle, so bail out.
return false;
case CharacterCompareType::Undefined:
case CharacterCompareType::RangeExpressionDummy:
// These do not occur in valid bytecode.
VERIFY_NOT_REACHED();
}
}
return true;
}
template<class Parser>
void Regex<Parser>::fill_optimization_data(BasicBlockList const& blocks)
{
if (blocks.is_empty())
return;
if constexpr (REGEX_DEBUG) {
dbgln("Pulling out optimization data from bytecode:");
RegexDebug dbg;
dbg.print_bytecode(*this);
for (auto const& block : blocks)
dbgln("block from {} to {} (comment: {})", block.start, block.end, block.comment);
}
ScopeGuard print = [&] {
if constexpr (REGEX_DEBUG) {
dbgln("Optimization data:");
if (parser_result.optimization_data.starting_ranges.is_empty())
dbgln("; - no starting ranges");
for (auto const& range : parser_result.optimization_data.starting_ranges)
dbgln(" - starting range: {}-{}", range.from, range.to);
dbgln("; - only start of line: {}", parser_result.optimization_data.only_start_of_line);
}
};
auto& bytecode = parser_result.bytecode;
auto state = MatchState::only_for_enumeration();
auto block = blocks.first();
for (state.instruction_position = block.start; state.instruction_position < block.end;) {
auto& opcode = bytecode.get_opcode(state);
switch (opcode.opcode_id()) {
case OpCodeId::Compare: {
auto flat_compares = static_cast<OpCode_Compare const&>(opcode).flat_compares();
StaticallyInterpretedCompares compares;
if (!interpret_compares(flat_compares, compares))
return; // No idea, the bytecode is too complex.
if (compares.has_any_unicode_property)
return; // Faster to just run the bytecode.
// FIXME: We should be able to handle these cases (jump ahead while...)
if (!compares.lhs_char_classes.is_empty() || !compares.lhs_negated_char_classes.is_empty() || !compares.lhs_negated_ranges.is_empty())
return;
for (auto it = compares.lhs_ranges.begin(); it != compares.lhs_ranges.end(); ++it)
parser_result.optimization_data.starting_ranges.append({ it.key(), *it });
return;
}
case OpCodeId::CheckBegin:
parser_result.optimization_data.only_start_of_line = true;
return;
case OpCodeId::Checkpoint:
case OpCodeId::Save:
case OpCodeId::ClearCaptureGroup:
case OpCodeId::SaveLeftCaptureGroup:
// These do not 'match' anything, so look through them.
state.instruction_position += opcode.size();
continue;
default:
return;
}
}
}
template<typename Parser>
typename Regex<Parser>::BasicBlockList Regex<Parser>::split_basic_blocks(ByteCode const& bytecode)
{
@ -126,7 +322,6 @@ typename Regex<Parser>::BasicBlockList Regex<Parser>::split_basic_blocks(ByteCod
static bool has_overlap(Vector<CompareTypeAndValuePair> const& lhs, Vector<CompareTypeAndValuePair> const& rhs)
{
// We have to fully interpret the two sequences to determine if they overlap (that is, keep track of inversion state and what ranges they cover).
bool inverse { false };
bool temporary_inverse { false };
@ -134,20 +329,20 @@ static bool has_overlap(Vector<CompareTypeAndValuePair> const& lhs, Vector<Compa
auto current_lhs_inversion_state = [&]() -> bool { return temporary_inverse ^ inverse; };
RedBlackTree<u32, u32> lhs_ranges;
RedBlackTree<u32, u32> lhs_negated_ranges;
HashTable<CharClass> lhs_char_classes;
HashTable<CharClass> lhs_negated_char_classes;
auto has_any_unicode_property = false;
HashTable<Unicode::GeneralCategory> lhs_unicode_general_categories;
HashTable<Unicode::Property> lhs_unicode_properties;
HashTable<Unicode::Script> lhs_unicode_scripts;
HashTable<Unicode::Script> lhs_unicode_script_extensions;
HashTable<Unicode::GeneralCategory> lhs_negated_unicode_general_categories;
HashTable<Unicode::Property> lhs_negated_unicode_properties;
HashTable<Unicode::Script> lhs_negated_unicode_scripts;
HashTable<Unicode::Script> lhs_negated_unicode_script_extensions;
StaticallyInterpretedCompares compares;
auto& lhs_ranges = compares.lhs_ranges;
auto& lhs_negated_ranges = compares.lhs_negated_ranges;
auto& lhs_char_classes = compares.lhs_char_classes;
auto& lhs_negated_char_classes = compares.lhs_negated_char_classes;
auto& has_any_unicode_property = compares.has_any_unicode_property;
auto& lhs_unicode_general_categories = compares.lhs_unicode_general_categories;
auto& lhs_unicode_properties = compares.lhs_unicode_properties;
auto& lhs_unicode_scripts = compares.lhs_unicode_scripts;
auto& lhs_unicode_script_extensions = compares.lhs_unicode_script_extensions;
auto& lhs_negated_unicode_general_categories = compares.lhs_negated_unicode_general_categories;
auto& lhs_negated_unicode_properties = compares.lhs_negated_unicode_properties;
auto& lhs_negated_unicode_scripts = compares.lhs_negated_unicode_scripts;
auto& lhs_negated_unicode_script_extensions = compares.lhs_negated_unicode_script_extensions;
auto any_unicode_property_matches = [&](u32 code_point) {
if (any_of(lhs_negated_unicode_general_categories, [code_point](auto category) { return Unicode::code_point_has_general_category(code_point, category); }))
@ -214,98 +409,8 @@ static bool has_overlap(Vector<CompareTypeAndValuePair> const& lhs, Vector<Compa
return false;
};
for (auto const& pair : lhs) {
if (reset_temporary_inverse) {
reset_temporary_inverse = false;
temporary_inverse = false;
} else {
reset_temporary_inverse = true;
}
switch (pair.type) {
case CharacterCompareType::Inverse:
inverse = !inverse;
break;
case CharacterCompareType::TemporaryInverse:
temporary_inverse = true;
reset_temporary_inverse = false;
break;
case CharacterCompareType::AnyChar:
// Special case: if not inverted, AnyChar is always in the range.
if (!current_lhs_inversion_state())
return true;
break;
case CharacterCompareType::Char:
if (!current_lhs_inversion_state())
lhs_ranges.insert(pair.value, pair.value);
else
lhs_negated_ranges.insert(pair.value, pair.value);
break;
case CharacterCompareType::String:
// FIXME: We just need to look at the last character of this string, but we only have the first character here.
// Just bail out to avoid false positives.
return true;
case CharacterCompareType::CharClass:
if (!current_lhs_inversion_state())
lhs_char_classes.set(static_cast<CharClass>(pair.value));
else
lhs_negated_char_classes.set(static_cast<CharClass>(pair.value));
break;
case CharacterCompareType::CharRange: {
auto range = CharRange(pair.value);
if (!current_lhs_inversion_state())
lhs_ranges.insert(range.from, range.to);
else
lhs_negated_ranges.insert(range.from, range.to);
break;
}
case CharacterCompareType::LookupTable:
// We've transformed this into a series of ranges in flat_compares(), so bail out if we see it.
return true;
case CharacterCompareType::Reference:
// We've handled this before coming here.
break;
case CharacterCompareType::Property:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_properties.set(static_cast<Unicode::Property>(pair.value));
else
lhs_negated_unicode_properties.set(static_cast<Unicode::Property>(pair.value));
break;
case CharacterCompareType::GeneralCategory:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_general_categories.set(static_cast<Unicode::GeneralCategory>(pair.value));
else
lhs_negated_unicode_general_categories.set(static_cast<Unicode::GeneralCategory>(pair.value));
break;
case CharacterCompareType::Script:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_scripts.set(static_cast<Unicode::Script>(pair.value));
else
lhs_negated_unicode_scripts.set(static_cast<Unicode::Script>(pair.value));
break;
case CharacterCompareType::ScriptExtension:
has_any_unicode_property = true;
if (!current_lhs_inversion_state())
lhs_unicode_script_extensions.set(static_cast<Unicode::Script>(pair.value));
else
lhs_negated_unicode_script_extensions.set(static_cast<Unicode::Script>(pair.value));
break;
case CharacterCompareType::Or:
case CharacterCompareType::EndAndOr:
// These are the default behaviour for [...], so we don't need to do anything (unless we add support for 'And' below).
break;
case CharacterCompareType::And:
// FIXME: These are too difficult to handle, so bail out.
return true;
case CharacterCompareType::Undefined:
case CharacterCompareType::RangeExpressionDummy:
// These do not occur in valid bytecode.
VERIFY_NOT_REACHED();
}
}
if (!interpret_compares(lhs, compares))
return true; // We can't interpret this, so we can't optimize it.
if constexpr (REGEX_DEBUG) {
dbgln("lhs ranges:");

View file

@ -64,6 +64,8 @@ public:
struct {
Optional<ByteString> pure_substring_search;
// If populated, the pattern only accepts strings that start with a character in these ranges.
Vector<CharRange> starting_ranges;
bool only_start_of_line = false;
} optimization_data {};
};