LibRegex: Use an interned string table for capture group names

This avoids messing around with unsafe string pointers and removes the
only non-FlyString-able user of DeprecatedFlyString.
This commit is contained in:
Ali Mohammad Pur 2025-04-01 16:49:07 +02:00 committed by Andreas Kling
commit 4136d8d13e
Notes: github-actions[bot] 2025-04-02 09:44:16 +00:00
6 changed files with 103 additions and 24 deletions

View file

@ -12,6 +12,7 @@
#include <AK/Concepts.h>
#include <AK/DisjointChunks.h>
#include <AK/Forward.h>
#include <AK/HashMap.h>
#include <AK/OwnPtr.h>
#include <AK/TypeCasts.h>
#include <AK/Types.h>
@ -141,6 +142,50 @@ struct CompareTypeAndValuePair {
class OpCode;
struct StringTable {
StringTable()
: m_serial(next_serial++)
{
}
StringTable(StringTable const&) = default;
StringTable(StringTable&&) = default;
~StringTable()
{
if (m_serial == next_serial - 1 && m_table.is_empty())
--next_serial; // We didn't use this serial, put it back.
}
StringTable& operator=(StringTable const&) = default;
StringTable& operator=(StringTable&&) = default;
ByteCodeValueType set(FlyString string)
{
u32 local_index = m_table.size() + 0x4242;
ByteCodeValueType global_index;
if (auto maybe_local_index = m_table.get(string); maybe_local_index.has_value()) {
local_index = maybe_local_index.value();
global_index = static_cast<ByteCodeValueType>(m_serial) << 32 | static_cast<ByteCodeValueType>(local_index);
} else {
global_index = static_cast<ByteCodeValueType>(m_serial) << 32 | static_cast<ByteCodeValueType>(local_index);
m_table.set(string, global_index);
m_inverse_table.set(global_index, string);
}
return global_index;
}
FlyString get(ByteCodeValueType index) const
{
return m_inverse_table.get(index).value();
}
static u32 next_serial;
u32 m_serial { 0 };
HashMap<FlyString, ByteCodeValueType> m_table;
HashMap<ByteCodeValueType, FlyString> m_inverse_table;
};
class ByteCode : public DisjointChunks<ByteCodeValueType> {
using Base = DisjointChunks<ByteCodeValueType>;
@ -153,14 +198,27 @@ public:
ByteCode(ByteCode const&) = default;
ByteCode(ByteCode&&) = default;
ByteCode(Base&&) = delete;
ByteCode(Base const&) = delete;
virtual ~ByteCode() = default;
ByteCode& operator=(ByteCode const&) = default;
ByteCode& operator=(ByteCode&&) = default;
ByteCode& operator=(Base&& value)
ByteCode& operator=(Base&& value) = delete;
ByteCode& operator=(Base const& value) = delete;
void extend(ByteCode&& other)
{
static_cast<Base&>(*this) = move(value);
return *this;
merge_string_tables_from({ &other, 1 });
Base::extend(move(other));
}
void extend(ByteCode const& other)
{
merge_string_tables_from({ &other, 1 });
Base::extend(other);
}
template<typename... Args>
@ -202,9 +260,28 @@ public:
Base::last_chunk().ensure_capacity(capacity);
}
FlyString get_string(size_t index) const { return m_string_table.get(index); }
void last_chunk() const = delete;
void first_chunk() const = delete;
void merge_string_tables_from(Span<ByteCode const> others)
{
for (auto const& other : others) {
for (auto const& entry : other.m_string_table.m_table) {
auto const result = m_string_table.m_inverse_table.set(entry.value, entry.key);
if (result != HashSetResult::InsertedNewEntry) {
if (m_string_table.m_inverse_table.get(entry.value) == entry.key) // Already in inverse table.
continue;
dbgln("StringTable: Detected ID clash in string tables! ID {} seems to be reused", entry.value);
dbgln("Old: {}, New: {}", m_string_table.m_inverse_table.get(entry.value), entry.key);
VERIFY_NOT_REACHED();
}
m_string_table.m_table.set(entry.key, entry.value);
}
}
}
void insert_bytecode_compare_values(Vector<CompareTypeAndValuePair>&& pairs)
{
Optimizer::append_character_class(*this, move(pairs));
@ -246,11 +323,10 @@ public:
empend(capture_groups_count);
}
void insert_bytecode_group_capture_right(size_t capture_groups_count, StringView name)
void insert_bytecode_group_capture_right(size_t capture_groups_count, FlyString name)
{
empend(static_cast<ByteCodeValueType>(OpCodeId::SaveRightNamedCaptureGroup));
empend(reinterpret_cast<ByteCodeValueType>(name.characters_without_null_termination()));
empend(name.length());
empend(m_string_table.set(move(name)));
empend(capture_groups_count);
}
@ -541,6 +617,7 @@ private:
static OwnPtr<OpCode> s_opcodes[(size_t)OpCodeId::Last + 1];
static bool s_opcodes_initialized;
static size_t s_next_checkpoint_serial_id;
StringTable m_string_table;
};
#define ENUMERATE_EXECUTION_RESULTS \
@ -758,13 +835,13 @@ class OpCode_SaveRightNamedCaptureGroup final : public OpCode {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::SaveRightNamedCaptureGroup; }
ALWAYS_INLINE size_t size() const override { return 4; }
ALWAYS_INLINE StringView name() const { return { reinterpret_cast<char*>(argument(0)), length() }; }
ALWAYS_INLINE size_t length() const { return argument(1); }
ALWAYS_INLINE size_t id() const { return argument(2); }
ALWAYS_INLINE size_t size() const override { return 3; }
ALWAYS_INLINE FlyString name() const { return m_bytecode->get_string(argument(0)); }
ALWAYS_INLINE size_t length() const { return name().bytes_as_string_view().length(); }
ALWAYS_INLINE size_t id() const { return argument(1); }
ByteString arguments_string() const override
{
return ByteString::formatted("name={}, length={}", name(), length());
return ByteString::formatted("name_id={}, id={}", argument(0), id());
}
};