LibWasm: Ensure correct section ordering when parsing binary modules

There are (currently) no spec-tests ensuring that section ordering is
enforced, but it _is_ a part of the spec. A pull request to add this to
the specification testsuite has been opened at WebAssembly/spec#1775.
This commit is contained in:
Diego Frias 2024-08-09 17:15:45 -07:00 committed by Andreas Kling
commit c58665332e
Notes: github-actions[bot] 2024-08-10 08:40:31 +00:00
2 changed files with 91 additions and 45 deletions

View file

@ -1213,6 +1213,41 @@ ParseResult<DataCountSection> DataCountSection::parse([[maybe_unused]] Stream& s
return DataCountSection { value }; return DataCountSection { value };
} }
ParseResult<SectionId> SectionId::parse(Stream& stream)
{
u8 id = TRY_READ(stream, u8, ParseError::ExpectedIndex);
switch (id) {
case 0x00:
return SectionId(SectionIdKind::Custom);
case 0x01:
return SectionId(SectionIdKind::Type);
case 0x02:
return SectionId(SectionIdKind::Import);
case 0x03:
return SectionId(SectionIdKind::Function);
case 0x04:
return SectionId(SectionIdKind::Table);
case 0x05:
return SectionId(SectionIdKind::Memory);
case 0x06:
return SectionId(SectionIdKind::Global);
case 0x07:
return SectionId(SectionIdKind::Export);
case 0x08:
return SectionId(SectionIdKind::Start);
case 0x09:
return SectionId(SectionIdKind::Element);
case 0x0a:
return SectionId(SectionIdKind::Code);
case 0x0b:
return SectionId(SectionIdKind::Data);
case 0x0c:
return SectionId(SectionIdKind::DataCount);
default:
return ParseError::InvalidIndex;
}
}
ParseResult<Module> Module::parse(Stream& stream) ParseResult<Module> Module::parse(Stream& stream)
{ {
ScopeLogger<WASM_BINPARSER_DEBUG> logger("Module"sv); ScopeLogger<WASM_BINPARSER_DEBUG> logger("Module"sv);
@ -1227,61 +1262,64 @@ ParseResult<Module> Module::parse(Stream& stream)
if (Bytes { buf, 4 } != wasm_version.span()) if (Bytes { buf, 4 } != wasm_version.span())
return with_eof_check(stream, ParseError::InvalidModuleVersion); return with_eof_check(stream, ParseError::InvalidModuleVersion);
auto last_section_id = CustomSection::section_id; auto last_section_id = SectionId::SectionIdKind::Custom;
Module module; Module module;
while (!stream.is_eof()) { while (!stream.is_eof()) {
auto section_id = TRY_READ(stream, u8, ParseError::ExpectedIndex); auto section_id = TRY(SectionId::parse(stream));
size_t section_size = TRY_READ(stream, LEB128<u32>, ParseError::ExpectedSize); size_t section_size = TRY_READ(stream, LEB128<u32>, ParseError::ExpectedSize);
auto section_stream = ConstrainedStream { MaybeOwned<Stream>(stream), section_size }; auto section_stream = ConstrainedStream { MaybeOwned<Stream>(stream), section_size };
if (section_id != CustomSection::section_id && section_id == last_section_id) if (section_id.kind() != SectionId::SectionIdKind::Custom && section_id.kind() == last_section_id)
return ParseError::DuplicateSection; return ParseError::DuplicateSection;
switch (section_id) { switch (section_id.kind()) {
case CustomSection::section_id: case SectionId::SectionIdKind::Custom:
module.custom_sections().append(TRY(CustomSection::parse(section_stream))); module.custom_sections().append(TRY(CustomSection::parse(section_stream)));
break; break;
case TypeSection::section_id: case SectionId::SectionIdKind::Type:
module.type_section() = TRY(TypeSection::parse(section_stream)); module.type_section() = TRY(TypeSection::parse(section_stream));
break; break;
case ImportSection::section_id: case SectionId::SectionIdKind::Import:
module.import_section() = TRY(ImportSection::parse(section_stream)); module.import_section() = TRY(ImportSection::parse(section_stream));
break; break;
case FunctionSection::section_id: case SectionId::SectionIdKind::Function:
module.function_section() = TRY(FunctionSection::parse(section_stream)); module.function_section() = TRY(FunctionSection::parse(section_stream));
break; break;
case TableSection::section_id: case SectionId::SectionIdKind::Table:
module.table_section() = TRY(TableSection::parse(section_stream)); module.table_section() = TRY(TableSection::parse(section_stream));
break; break;
case MemorySection::section_id: case SectionId::SectionIdKind::Memory:
module.memory_section() = TRY(MemorySection::parse(section_stream)); module.memory_section() = TRY(MemorySection::parse(section_stream));
break; break;
case GlobalSection::section_id: case SectionId::SectionIdKind::Global:
module.global_section() = TRY(GlobalSection::parse(section_stream)); module.global_section() = TRY(GlobalSection::parse(section_stream));
break; break;
case ExportSection::section_id: case SectionId::SectionIdKind::Export:
module.export_section() = TRY(ExportSection::parse(section_stream)); module.export_section() = TRY(ExportSection::parse(section_stream));
break; break;
case StartSection::section_id: case SectionId::SectionIdKind::Start:
module.start_section() = TRY(StartSection::parse(section_stream)); module.start_section() = TRY(StartSection::parse(section_stream));
break; break;
case ElementSection::section_id: case SectionId::SectionIdKind::Element:
module.element_section() = TRY(ElementSection::parse(section_stream)); module.element_section() = TRY(ElementSection::parse(section_stream));
break; break;
case CodeSection::section_id: case SectionId::SectionIdKind::Code:
module.code_section() = TRY(CodeSection::parse(section_stream)); module.code_section() = TRY(CodeSection::parse(section_stream));
break; break;
case DataSection::section_id: case SectionId::SectionIdKind::Data:
module.data_section() = TRY(DataSection::parse(section_stream)); module.data_section() = TRY(DataSection::parse(section_stream));
break; break;
case DataCountSection::section_id: case SectionId::SectionIdKind::DataCount:
module.data_count_section() = TRY(DataCountSection::parse(section_stream)); module.data_count_section() = TRY(DataCountSection::parse(section_stream));
break; break;
default: default:
return ParseError::InvalidIndex; return ParseError::InvalidIndex;
} }
if (section_id != CustomSection::section_id) if (section_id.kind() != SectionId::SectionIdKind::Custom) {
last_section_id = section_id; if (section_id.kind() < last_section_id)
return ParseError::SectionOutOfOrder;
last_section_id = section_id.kind();
}
if (section_stream.remaining() != 0) if (section_stream.remaining() != 0)
return ParseError::SectionSizeMismatch; return ParseError::SectionSizeMismatch;
} }
@ -1334,6 +1372,8 @@ ByteString parse_error_to_byte_string(ParseError error)
return "A parsed instruction was not known to this parser"; return "A parsed instruction was not known to this parser";
case ParseError::DuplicateSection: case ParseError::DuplicateSection:
return "Two sections of the same type were encountered"; return "Two sections of the same type were encountered";
case ParseError::SectionOutOfOrder:
return "A section encountered was not in the correct ordering";
} }
return "Unknown error"; return "Unknown error";
} }

View file

@ -57,6 +57,7 @@ enum class ParseError {
SectionSizeMismatch, SectionSizeMismatch,
InvalidUtf8, InvalidUtf8,
DuplicateSection, DuplicateSection,
SectionOutOfOrder,
}; };
ByteString parse_error_to_byte_string(ParseError); ByteString parse_error_to_byte_string(ParseError);
@ -499,10 +500,39 @@ private:
m_arguments; m_arguments;
}; };
struct SectionId {
public:
enum class SectionIdKind : u8 {
Custom,
Type,
Import,
Function,
Table,
Memory,
Global,
Export,
Start,
Element,
DataCount,
Code,
Data,
};
explicit SectionId(SectionIdKind kind)
: m_kind(kind)
{
}
SectionIdKind kind() const { return m_kind; }
static ParseResult<SectionId> parse(Stream& stream);
private:
SectionIdKind m_kind;
};
class CustomSection { class CustomSection {
public: public:
static constexpr u8 section_id = 0;
CustomSection(ByteString name, ByteBuffer contents) CustomSection(ByteString name, ByteBuffer contents)
: m_name(move(name)) : m_name(move(name))
, m_contents(move(contents)) , m_contents(move(contents))
@ -521,8 +551,6 @@ private:
class TypeSection { class TypeSection {
public: public:
static constexpr u8 section_id = 1;
TypeSection() = default; TypeSection() = default;
explicit TypeSection(Vector<FunctionType> types) explicit TypeSection(Vector<FunctionType> types)
@ -570,8 +598,6 @@ public:
}; };
public: public:
static constexpr u8 section_id = 2;
ImportSection() = default; ImportSection() = default;
explicit ImportSection(Vector<Import> imports) explicit ImportSection(Vector<Import> imports)
@ -589,8 +615,6 @@ private:
class FunctionSection { class FunctionSection {
public: public:
static constexpr u8 section_id = 3;
FunctionSection() = default; FunctionSection() = default;
explicit FunctionSection(Vector<TypeIndex> types) explicit FunctionSection(Vector<TypeIndex> types)
@ -624,8 +648,6 @@ public:
}; };
public: public:
static constexpr u8 section_id = 4;
TableSection() = default; TableSection() = default;
explicit TableSection(Vector<Table> tables) explicit TableSection(Vector<Table> tables)
@ -659,8 +681,6 @@ public:
}; };
public: public:
static constexpr u8 section_id = 5;
MemorySection() = default; MemorySection() = default;
explicit MemorySection(Vector<Memory> memories) explicit MemorySection(Vector<Memory> memories)
@ -712,8 +732,6 @@ public:
}; };
public: public:
static constexpr u8 section_id = 6;
GlobalSection() = default; GlobalSection() = default;
explicit GlobalSection(Vector<Global> entries) explicit GlobalSection(Vector<Global> entries)
@ -752,8 +770,6 @@ public:
ExportDesc m_description; ExportDesc m_description;
}; };
static constexpr u8 section_id = 7;
ExportSection() = default; ExportSection() = default;
explicit ExportSection(Vector<Export> entries) explicit ExportSection(Vector<Export> entries)
@ -786,8 +802,6 @@ public:
FunctionIndex m_index; FunctionIndex m_index;
}; };
static constexpr u8 section_id = 8;
StartSection() = default; StartSection() = default;
explicit StartSection(Optional<StartFunction> func) explicit StartSection(Optional<StartFunction> func)
@ -822,8 +836,6 @@ public:
Variant<Active, Passive, Declarative> mode; Variant<Active, Passive, Declarative> mode;
}; };
static constexpr u8 section_id = 9;
ElementSection() = default; ElementSection() = default;
explicit ElementSection(Vector<Element> segs) explicit ElementSection(Vector<Element> segs)
@ -896,8 +908,6 @@ public:
Func m_func; Func m_func;
}; };
static constexpr u8 section_id = 10;
CodeSection() = default; CodeSection() = default;
explicit CodeSection(Vector<Code> funcs) explicit CodeSection(Vector<Code> funcs)
@ -940,8 +950,6 @@ public:
Value m_value; Value m_value;
}; };
static constexpr u8 section_id = 11;
DataSection() = default; DataSection() = default;
explicit DataSection(Vector<Data> data) explicit DataSection(Vector<Data> data)
@ -959,8 +967,6 @@ private:
class DataCountSection { class DataCountSection {
public: public:
static constexpr u8 section_id = 12;
DataCountSection() = default; DataCountSection() = default;
explicit DataCountSection(Optional<u32> count) explicit DataCountSection(Optional<u32> count)