LibRegex: Use an interned string table for capture group names

This avoids messing around with unsafe string pointers and removes the
only non-FlyString-able user of DeprecatedFlyString.
This commit is contained in:
Ali Mohammad Pur 2025-04-01 16:49:07 +02:00
parent 6bb0d585e3
commit a35ebd60c5
6 changed files with 103 additions and 24 deletions

View file

@ -161,6 +161,7 @@ static bool restore_string_position(MatchInput const& input, MatchState& state)
OwnPtr<OpCode> ByteCode::s_opcodes[(size_t)OpCodeId::Last + 1];
bool ByteCode::s_opcodes_initialized { false };
size_t ByteCode::s_next_checkpoint_serial_id { 0 };
u32 StringTable::next_serial { 0 };
void ByteCode::ensure_opcodes_initialized()
{

View file

@ -12,6 +12,7 @@
#include <AK/Concepts.h>
#include <AK/DisjointChunks.h>
#include <AK/Forward.h>
#include <AK/HashMap.h>
#include <AK/OwnPtr.h>
#include <AK/TypeCasts.h>
#include <AK/Types.h>
@ -141,6 +142,50 @@ struct CompareTypeAndValuePair {
class OpCode;
struct StringTable {
StringTable()
: m_serial(next_serial++)
{
}
StringTable(StringTable const&) = default;
StringTable(StringTable&&) = default;
~StringTable()
{
if (m_serial == next_serial - 1 && m_table.is_empty())
--next_serial; // We didn't use this serial, put it back.
}
StringTable& operator=(StringTable const&) = default;
StringTable& operator=(StringTable&&) = default;
ByteCodeValueType set(FlyString string)
{
u32 local_index = m_table.size() + 0x4242;
ByteCodeValueType global_index;
if (auto maybe_local_index = m_table.get(string); maybe_local_index.has_value()) {
local_index = maybe_local_index.value();
global_index = static_cast<ByteCodeValueType>(m_serial) << 32 | static_cast<ByteCodeValueType>(local_index);
} else {
global_index = static_cast<ByteCodeValueType>(m_serial) << 32 | static_cast<ByteCodeValueType>(local_index);
m_table.set(string, global_index);
m_inverse_table.set(global_index, string);
}
return global_index;
}
FlyString get(ByteCodeValueType index) const
{
return m_inverse_table.get(index).value();
}
static u32 next_serial;
u32 m_serial { 0 };
HashMap<FlyString, ByteCodeValueType> m_table;
HashMap<ByteCodeValueType, FlyString> m_inverse_table;
};
class ByteCode : public DisjointChunks<ByteCodeValueType> {
using Base = DisjointChunks<ByteCodeValueType>;
@ -153,14 +198,27 @@ public:
ByteCode(ByteCode const&) = default;
ByteCode(ByteCode&&) = default;
ByteCode(Base&&) = delete;
ByteCode(Base const&) = delete;
virtual ~ByteCode() = default;
ByteCode& operator=(ByteCode const&) = default;
ByteCode& operator=(ByteCode&&) = default;
ByteCode& operator=(Base&& value)
ByteCode& operator=(Base&& value) = delete;
ByteCode& operator=(Base const& value) = delete;
void extend(ByteCode&& other)
{
static_cast<Base&>(*this) = move(value);
return *this;
merge_string_tables_from({ &other, 1 });
Base::extend(move(other));
}
void extend(ByteCode const& other)
{
merge_string_tables_from({ &other, 1 });
Base::extend(other);
}
template<typename... Args>
@ -202,9 +260,28 @@ public:
Base::last_chunk().ensure_capacity(capacity);
}
FlyString get_string(size_t index) const { return m_string_table.get(index); }
void last_chunk() const = delete;
void first_chunk() const = delete;
void merge_string_tables_from(Span<ByteCode const> others)
{
for (auto const& other : others) {
for (auto const& entry : other.m_string_table.m_table) {
auto const result = m_string_table.m_inverse_table.set(entry.value, entry.key);
if (result != HashSetResult::InsertedNewEntry) {
if (m_string_table.m_inverse_table.get(entry.value) == entry.key) // Already in inverse table.
continue;
dbgln("StringTable: Detected ID clash in string tables! ID {} seems to be reused", entry.value);
dbgln("Old: {}, New: {}", m_string_table.m_inverse_table.get(entry.value), entry.key);
VERIFY_NOT_REACHED();
}
m_string_table.m_table.set(entry.key, entry.value);
}
}
}
void insert_bytecode_compare_values(Vector<CompareTypeAndValuePair>&& pairs)
{
Optimizer::append_character_class(*this, move(pairs));
@ -246,11 +323,10 @@ public:
empend(capture_groups_count);
}
void insert_bytecode_group_capture_right(size_t capture_groups_count, StringView name)
void insert_bytecode_group_capture_right(size_t capture_groups_count, FlyString name)
{
empend(static_cast<ByteCodeValueType>(OpCodeId::SaveRightNamedCaptureGroup));
empend(reinterpret_cast<ByteCodeValueType>(name.characters_without_null_termination()));
empend(name.length());
empend(m_string_table.set(move(name)));
empend(capture_groups_count);
}
@ -541,6 +617,7 @@ private:
static OwnPtr<OpCode> s_opcodes[(size_t)OpCodeId::Last + 1];
static bool s_opcodes_initialized;
static size_t s_next_checkpoint_serial_id;
StringTable m_string_table;
};
#define ENUMERATE_EXECUTION_RESULTS \
@ -758,13 +835,13 @@ class OpCode_SaveRightNamedCaptureGroup final : public OpCode {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::SaveRightNamedCaptureGroup; }
ALWAYS_INLINE size_t size() const override { return 4; }
ALWAYS_INLINE StringView name() const { return { reinterpret_cast<char*>(argument(0)), length() }; }
ALWAYS_INLINE size_t length() const { return argument(1); }
ALWAYS_INLINE size_t id() const { return argument(2); }
ALWAYS_INLINE size_t size() const override { return 3; }
ALWAYS_INLINE FlyString name() const { return m_bytecode->get_string(argument(0)); }
ALWAYS_INLINE size_t length() const { return name().bytes_as_string_view().length(); }
ALWAYS_INLINE size_t id() const { return argument(1); }
ByteString arguments_string() const override
{
return ByteString::formatted("name={}, length={}", name(), length());
return ByteString::formatted("name_id={}, id={}", argument(0), id());
}
};

View file

@ -26,7 +26,7 @@ public:
auto& bytecode = regex.parser_result.bytecode;
size_t index { 0 };
for (auto& value : bytecode) {
outln(m_file, "OpCode i={:3} [{:#02X}]", index, (u32)value);
outln(m_file, "OpCode i={:3} [{:#02X}]", index, value);
++index;
}
}

View file

@ -999,6 +999,8 @@ void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives
if (alternatives.size() == 0)
return;
target.merge_string_tables_from(alternatives);
if (alternatives.size() == 1)
return target.extend(move(alternatives[0]));

View file

@ -11,7 +11,6 @@
#include <AK/ByteString.h>
#include <AK/CharacterTypes.h>
#include <AK/Debug.h>
#include <AK/DeprecatedFlyString.h>
#include <AK/GenericLexer.h>
#include <AK/ScopeGuard.h>
#include <AK/StringBuilder.h>
@ -813,7 +812,7 @@ ALWAYS_INLINE bool PosixExtendedParser::parse_sub_expression(ByteCode& stack, si
NegativeLookbehind,
} group_mode { Normal };
consume();
Optional<StringView> capture_group_name;
Optional<FlyString> capture_group_name;
bool prevent_capture_group = false;
if (match(TokenType::Questionmark)) {
consume();
@ -836,7 +835,7 @@ ALWAYS_INLINE bool PosixExtendedParser::parse_sub_expression(ByteCode& stack, si
++capture_group_name_length;
last_token = consume();
}
capture_group_name = StringView(start_token.value().characters_without_null_termination(), capture_group_name_length);
capture_group_name = MUST(FlyString::from_utf8(m_parser_state.lexer.input().substring_view_starting_from_substring(start_token.value()).substring_view(0, capture_group_name_length)));
++m_parser_state.named_capture_groups_count;
} else if (match(TokenType::EqualSign)) { // positive lookahead
@ -982,7 +981,7 @@ bool ECMA262Parser::parse_pattern(ByteCode& stack, size_t& match_length_minimum,
return parse_disjunction(stack, match_length_minimum, flags);
}
bool ECMA262Parser::has_duplicate_in_current_alternative(DeprecatedFlyString const& name)
bool ECMA262Parser::has_duplicate_in_current_alternative(FlyString const& name)
{
auto it = m_parser_state.named_capture_groups.find(name);
if (it == m_parser_state.named_capture_groups.end())
@ -2503,7 +2502,7 @@ bool ECMA262Parser::parse_unicode_property_escape(PropertyEscape& property, bool
[](Empty&) -> bool { VERIFY_NOT_REACHED(); });
}
DeprecatedFlyString ECMA262Parser::read_capture_group_specifier(bool take_starting_angle_bracket)
FlyString ECMA262Parser::read_capture_group_specifier(bool take_starting_angle_bracket)
{
static constexpr u32 const REPLACEMENT_CHARACTER = 0xFFFD;
constexpr u32 const ZERO_WIDTH_NON_JOINER { 0x200C };
@ -2604,7 +2603,7 @@ DeprecatedFlyString ECMA262Parser::read_capture_group_specifier(bool take_starti
builder.append_code_point(code_point);
}
DeprecatedFlyString name = builder.to_byte_string();
auto name = MUST(builder.to_fly_string());
if (!hit_end || name.is_empty())
set_error(Error::InvalidNameForCaptureGroup);
@ -2720,7 +2719,7 @@ bool ECMA262Parser::parse_capture_group(ByteCode& stack, size_t& match_length_mi
stack.insert_bytecode_group_capture_left(group_index);
stack.extend(move(capture_group_bytecode));
stack.insert_bytecode_group_capture_right(group_index, name.view());
stack.insert_bytecode_group_capture_right(group_index, name);
match_length_minimum += length;

View file

@ -11,7 +11,7 @@
#include "RegexLexer.h"
#include "RegexOptions.h"
#include <AK/DeprecatedFlyString.h>
#include <AK/FlyString.h>
#include <AK/Forward.h>
#include <AK/HashMap.h>
#include <AK/Types.h>
@ -59,7 +59,7 @@ public:
size_t match_length_minimum;
Error error;
Token error_token;
Vector<DeprecatedFlyString> capture_groups;
Vector<FlyString> capture_groups;
AllOptions options;
struct {
@ -117,7 +117,7 @@ protected:
size_t repetition_mark_count { 0 };
AllOptions regex_options;
HashMap<size_t, size_t> capture_group_minimum_lengths;
HashMap<DeprecatedFlyString, Vector<NamedCaptureGroup>> named_capture_groups;
HashMap<FlyString, Vector<NamedCaptureGroup>> named_capture_groups;
explicit ParserState(Lexer& lexer)
: lexer(lexer)
@ -240,7 +240,7 @@ private:
};
StringView read_digits_as_string(ReadDigitsInitialZeroState initial_zero = ReadDigitsInitialZeroState::Allow, bool hex = false, int max_count = -1, int min_count = -1);
Optional<unsigned> read_digits(ReadDigitsInitialZeroState initial_zero = ReadDigitsInitialZeroState::Allow, bool hex = false, int max_count = -1, int min_count = -1);
DeprecatedFlyString read_capture_group_specifier(bool take_starting_angle_bracket = false);
FlyString read_capture_group_specifier(bool take_starting_angle_bracket = false);
struct Script {
Unicode::Script script {};
@ -282,7 +282,7 @@ private:
bool parse_invalid_braced_quantifier(); // Note: This function either parses and *fails*, or doesn't parse anything and returns false.
Optional<u8> parse_legacy_octal_escape();
bool has_duplicate_in_current_alternative(DeprecatedFlyString const& name);
bool has_duplicate_in_current_alternative(FlyString const& name);
size_t ensure_total_number_of_capturing_parenthesis();