From 4136d8d13ebee164e5fdcf2489245f14d63674a9 Mon Sep 17 00:00:00 2001 From: Ali Mohammad Pur Date: Tue, 1 Apr 2025 16:49:07 +0200 Subject: [PATCH] LibRegex: Use an interned string table for capture group names This avoids messing around with unsafe string pointers and removes the only non-FlyString-able user of DeprecatedFlyString. --- Libraries/LibRegex/RegexByteCode.cpp | 1 + Libraries/LibRegex/RegexByteCode.h | 99 ++++++++++++++++++++++++--- Libraries/LibRegex/RegexDebug.h | 2 +- Libraries/LibRegex/RegexOptimizer.cpp | 2 + Libraries/LibRegex/RegexParser.cpp | 13 ++-- Libraries/LibRegex/RegexParser.h | 10 +-- 6 files changed, 103 insertions(+), 24 deletions(-) diff --git a/Libraries/LibRegex/RegexByteCode.cpp b/Libraries/LibRegex/RegexByteCode.cpp index 6970ca40eb6..a8ae99b735b 100644 --- a/Libraries/LibRegex/RegexByteCode.cpp +++ b/Libraries/LibRegex/RegexByteCode.cpp @@ -161,6 +161,7 @@ static bool restore_string_position(MatchInput const& input, MatchState& state) OwnPtr ByteCode::s_opcodes[(size_t)OpCodeId::Last + 1]; bool ByteCode::s_opcodes_initialized { false }; size_t ByteCode::s_next_checkpoint_serial_id { 0 }; +u32 StringTable::next_serial { 0 }; void ByteCode::ensure_opcodes_initialized() { diff --git a/Libraries/LibRegex/RegexByteCode.h b/Libraries/LibRegex/RegexByteCode.h index b9a8cdb57b2..322c4c210a5 100644 --- a/Libraries/LibRegex/RegexByteCode.h +++ b/Libraries/LibRegex/RegexByteCode.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -141,6 +142,50 @@ struct CompareTypeAndValuePair { class OpCode; +struct StringTable { + StringTable() + : m_serial(next_serial++) + { + } + StringTable(StringTable const&) = default; + StringTable(StringTable&&) = default; + + ~StringTable() + { + if (m_serial == next_serial - 1 && m_table.is_empty()) + --next_serial; // We didn't use this serial, put it back. + } + + StringTable& operator=(StringTable const&) = default; + StringTable& operator=(StringTable&&) = default; + + ByteCodeValueType set(FlyString string) + { + u32 local_index = m_table.size() + 0x4242; + ByteCodeValueType global_index; + if (auto maybe_local_index = m_table.get(string); maybe_local_index.has_value()) { + local_index = maybe_local_index.value(); + global_index = static_cast(m_serial) << 32 | static_cast(local_index); + } else { + global_index = static_cast(m_serial) << 32 | static_cast(local_index); + m_table.set(string, global_index); + m_inverse_table.set(global_index, string); + } + + return global_index; + } + + FlyString get(ByteCodeValueType index) const + { + return m_inverse_table.get(index).value(); + } + + static u32 next_serial; + u32 m_serial { 0 }; + HashMap m_table; + HashMap m_inverse_table; +}; + class ByteCode : public DisjointChunks { using Base = DisjointChunks; @@ -153,14 +198,27 @@ public: ByteCode(ByteCode const&) = default; ByteCode(ByteCode&&) = default; + ByteCode(Base&&) = delete; + ByteCode(Base const&) = delete; + virtual ~ByteCode() = default; ByteCode& operator=(ByteCode const&) = default; ByteCode& operator=(ByteCode&&) = default; - ByteCode& operator=(Base&& value) + + ByteCode& operator=(Base&& value) = delete; + ByteCode& operator=(Base const& value) = delete; + + void extend(ByteCode&& other) { - static_cast(*this) = move(value); - return *this; + merge_string_tables_from({ &other, 1 }); + Base::extend(move(other)); + } + + void extend(ByteCode const& other) + { + merge_string_tables_from({ &other, 1 }); + Base::extend(other); } template @@ -202,9 +260,28 @@ public: Base::last_chunk().ensure_capacity(capacity); } + FlyString get_string(size_t index) const { return m_string_table.get(index); } + void last_chunk() const = delete; void first_chunk() const = delete; + void merge_string_tables_from(Span others) + { + for (auto const& other : others) { + for (auto const& entry : other.m_string_table.m_table) { + auto const result = m_string_table.m_inverse_table.set(entry.value, entry.key); + if (result != HashSetResult::InsertedNewEntry) { + if (m_string_table.m_inverse_table.get(entry.value) == entry.key) // Already in inverse table. + continue; + dbgln("StringTable: Detected ID clash in string tables! ID {} seems to be reused", entry.value); + dbgln("Old: {}, New: {}", m_string_table.m_inverse_table.get(entry.value), entry.key); + VERIFY_NOT_REACHED(); + } + m_string_table.m_table.set(entry.key, entry.value); + } + } + } + void insert_bytecode_compare_values(Vector&& pairs) { Optimizer::append_character_class(*this, move(pairs)); @@ -246,11 +323,10 @@ public: empend(capture_groups_count); } - void insert_bytecode_group_capture_right(size_t capture_groups_count, StringView name) + void insert_bytecode_group_capture_right(size_t capture_groups_count, FlyString name) { empend(static_cast(OpCodeId::SaveRightNamedCaptureGroup)); - empend(reinterpret_cast(name.characters_without_null_termination())); - empend(name.length()); + empend(m_string_table.set(move(name))); empend(capture_groups_count); } @@ -541,6 +617,7 @@ private: static OwnPtr s_opcodes[(size_t)OpCodeId::Last + 1]; static bool s_opcodes_initialized; static size_t s_next_checkpoint_serial_id; + StringTable m_string_table; }; #define ENUMERATE_EXECUTION_RESULTS \ @@ -758,13 +835,13 @@ class OpCode_SaveRightNamedCaptureGroup final : public OpCode { public: ExecutionResult execute(MatchInput const& input, MatchState& state) const override; ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::SaveRightNamedCaptureGroup; } - ALWAYS_INLINE size_t size() const override { return 4; } - ALWAYS_INLINE StringView name() const { return { reinterpret_cast(argument(0)), length() }; } - ALWAYS_INLINE size_t length() const { return argument(1); } - ALWAYS_INLINE size_t id() const { return argument(2); } + ALWAYS_INLINE size_t size() const override { return 3; } + ALWAYS_INLINE FlyString name() const { return m_bytecode->get_string(argument(0)); } + ALWAYS_INLINE size_t length() const { return name().bytes_as_string_view().length(); } + ALWAYS_INLINE size_t id() const { return argument(1); } ByteString arguments_string() const override { - return ByteString::formatted("name={}, length={}", name(), length()); + return ByteString::formatted("name_id={}, id={}", argument(0), id()); } }; diff --git a/Libraries/LibRegex/RegexDebug.h b/Libraries/LibRegex/RegexDebug.h index 165c0ca3ec2..e5a8be10bbd 100644 --- a/Libraries/LibRegex/RegexDebug.h +++ b/Libraries/LibRegex/RegexDebug.h @@ -26,7 +26,7 @@ public: auto& bytecode = regex.parser_result.bytecode; size_t index { 0 }; for (auto& value : bytecode) { - outln(m_file, "OpCode i={:3} [{:#02X}]", index, (u32)value); + outln(m_file, "OpCode i={:3} [{:#02X}]", index, value); ++index; } } diff --git a/Libraries/LibRegex/RegexOptimizer.cpp b/Libraries/LibRegex/RegexOptimizer.cpp index b48b16b6857..5a8d083fea8 100644 --- a/Libraries/LibRegex/RegexOptimizer.cpp +++ b/Libraries/LibRegex/RegexOptimizer.cpp @@ -999,6 +999,8 @@ void Optimizer::append_alternation(ByteCode& target, Span alternatives if (alternatives.size() == 0) return; + target.merge_string_tables_from(alternatives); + if (alternatives.size() == 1) return target.extend(move(alternatives[0])); diff --git a/Libraries/LibRegex/RegexParser.cpp b/Libraries/LibRegex/RegexParser.cpp index e45b9c95a2e..fc6156fcc34 100644 --- a/Libraries/LibRegex/RegexParser.cpp +++ b/Libraries/LibRegex/RegexParser.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #include @@ -813,7 +812,7 @@ ALWAYS_INLINE bool PosixExtendedParser::parse_sub_expression(ByteCode& stack, si NegativeLookbehind, } group_mode { Normal }; consume(); - Optional capture_group_name; + Optional capture_group_name; bool prevent_capture_group = false; if (match(TokenType::Questionmark)) { consume(); @@ -836,7 +835,7 @@ ALWAYS_INLINE bool PosixExtendedParser::parse_sub_expression(ByteCode& stack, si ++capture_group_name_length; last_token = consume(); } - capture_group_name = StringView(start_token.value().characters_without_null_termination(), capture_group_name_length); + capture_group_name = MUST(FlyString::from_utf8(m_parser_state.lexer.input().substring_view_starting_from_substring(start_token.value()).substring_view(0, capture_group_name_length))); ++m_parser_state.named_capture_groups_count; } else if (match(TokenType::EqualSign)) { // positive lookahead @@ -982,7 +981,7 @@ bool ECMA262Parser::parse_pattern(ByteCode& stack, size_t& match_length_minimum, return parse_disjunction(stack, match_length_minimum, flags); } -bool ECMA262Parser::has_duplicate_in_current_alternative(DeprecatedFlyString const& name) +bool ECMA262Parser::has_duplicate_in_current_alternative(FlyString const& name) { auto it = m_parser_state.named_capture_groups.find(name); if (it == m_parser_state.named_capture_groups.end()) @@ -2503,7 +2502,7 @@ bool ECMA262Parser::parse_unicode_property_escape(PropertyEscape& property, bool [](Empty&) -> bool { VERIFY_NOT_REACHED(); }); } -DeprecatedFlyString ECMA262Parser::read_capture_group_specifier(bool take_starting_angle_bracket) +FlyString ECMA262Parser::read_capture_group_specifier(bool take_starting_angle_bracket) { static constexpr u32 const REPLACEMENT_CHARACTER = 0xFFFD; constexpr u32 const ZERO_WIDTH_NON_JOINER { 0x200C }; @@ -2604,7 +2603,7 @@ DeprecatedFlyString ECMA262Parser::read_capture_group_specifier(bool take_starti builder.append_code_point(code_point); } - DeprecatedFlyString name = builder.to_byte_string(); + auto name = MUST(builder.to_fly_string()); if (!hit_end || name.is_empty()) set_error(Error::InvalidNameForCaptureGroup); @@ -2720,7 +2719,7 @@ bool ECMA262Parser::parse_capture_group(ByteCode& stack, size_t& match_length_mi stack.insert_bytecode_group_capture_left(group_index); stack.extend(move(capture_group_bytecode)); - stack.insert_bytecode_group_capture_right(group_index, name.view()); + stack.insert_bytecode_group_capture_right(group_index, name); match_length_minimum += length; diff --git a/Libraries/LibRegex/RegexParser.h b/Libraries/LibRegex/RegexParser.h index eb8a8bce8be..ae9b97875bb 100644 --- a/Libraries/LibRegex/RegexParser.h +++ b/Libraries/LibRegex/RegexParser.h @@ -11,7 +11,7 @@ #include "RegexLexer.h" #include "RegexOptions.h" -#include +#include #include #include #include @@ -59,7 +59,7 @@ public: size_t match_length_minimum; Error error; Token error_token; - Vector capture_groups; + Vector capture_groups; AllOptions options; struct { @@ -117,7 +117,7 @@ protected: size_t repetition_mark_count { 0 }; AllOptions regex_options; HashMap capture_group_minimum_lengths; - HashMap> named_capture_groups; + HashMap> named_capture_groups; explicit ParserState(Lexer& lexer) : lexer(lexer) @@ -240,7 +240,7 @@ private: }; StringView read_digits_as_string(ReadDigitsInitialZeroState initial_zero = ReadDigitsInitialZeroState::Allow, bool hex = false, int max_count = -1, int min_count = -1); Optional read_digits(ReadDigitsInitialZeroState initial_zero = ReadDigitsInitialZeroState::Allow, bool hex = false, int max_count = -1, int min_count = -1); - DeprecatedFlyString read_capture_group_specifier(bool take_starting_angle_bracket = false); + FlyString read_capture_group_specifier(bool take_starting_angle_bracket = false); struct Script { Unicode::Script script {}; @@ -282,7 +282,7 @@ private: bool parse_invalid_braced_quantifier(); // Note: This function either parses and *fails*, or doesn't parse anything and returns false. Optional parse_legacy_octal_escape(); - bool has_duplicate_in_current_alternative(DeprecatedFlyString const& name); + bool has_duplicate_in_current_alternative(FlyString const& name); size_t ensure_total_number_of_capturing_parenthesis();