diff --git a/Tests/LibWasm/test-wasm.cpp b/Tests/LibWasm/test-wasm.cpp index 7560cac1155..6a601fdedfd 100644 --- a/Tests/LibWasm/test-wasm.cpp +++ b/Tests/LibWasm/test-wasm.cpp @@ -53,7 +53,7 @@ public: Wasm::Module& module() { return *m_module; } Wasm::ModuleInstance& module_instance() { return *m_module_instance; } - static JS::ThrowCompletionOr create(JS::Realm& realm, Wasm::Module module, HashMap const& imports) + static JS::ThrowCompletionOr create(JS::Realm& realm, NonnullRefPtr module, HashMap const& imports) { auto& vm = realm.vm(); auto instance = realm.heap().allocate(realm, realm.intrinsics().object_prototype()); @@ -148,7 +148,7 @@ private: static HashMap s_spec_test_namespace; static Wasm::AbstractMachine m_machine; - Optional m_module; + RefPtr m_module; OwnPtr m_module_instance; }; @@ -379,13 +379,15 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke) arguments.append(Wasm::Value(bits)); break; } - case Wasm::ValueType::Kind::FunctionReference: + case Wasm::ValueType::Kind::FunctionReference: { if (argument.is_null()) { arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) } })); break; } - arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { static_cast(double_value) } })); + Wasm::FunctionAddress addr = static_cast(double_value); + arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { addr, machine().store().get_module_for(addr) } })); break; + } case Wasm::ValueType::Kind::ExternReference: if (argument.is_null()) { arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) } })); diff --git a/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp b/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp index cecbce36246..142e8859454 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp +++ b/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp @@ -14,14 +14,14 @@ namespace Wasm { -Optional Store::allocate(ModuleInstance& module, CodeSection::Code const& code, TypeIndex type_index) +Optional Store::allocate(ModuleInstance& instance, Module const& module, CodeSection::Code const& code, TypeIndex type_index) { FunctionAddress address { m_functions.size() }; - if (type_index.value() > module.types().size()) + if (type_index.value() > instance.types().size()) return {}; - auto& type = module.types()[type_index.value()]; - m_functions.empend(WasmFunction { type, module, code }); + auto& type = instance.types()[type_index.value()]; + m_functions.empend(WasmFunction { type, instance, module, code }); return address; } @@ -84,6 +84,14 @@ FunctionInstance* Store::get(FunctionAddress address) return &m_functions[value]; } +Module const* Store::get_module_for(Wasm::FunctionAddress address) +{ + auto* function = get(address); + if (!function || function->has()) + return nullptr; + return function->get().module_ref().ptr(); +} + TableInstance* Store::get(TableAddress address) { auto value = address.value(); @@ -223,7 +231,7 @@ InstantiationResult AbstractMachine::instantiate(Module const& module, Vector source_module; // null if host function. }; struct Extern { ExternAddress address; @@ -139,7 +140,7 @@ public: // 2: null funcref // 3: null externref ref.ref().visit( - [&](Reference::Func const& func) { m_value = u128(bit_cast(func.address), 0); }, + [&](Reference::Func const& func) { m_value = u128(bit_cast(func.address), bit_cast(func.source_module.ptr())); }, [&](Reference::Extern const& func) { m_value = u128(bit_cast(func.address), 1); }, [&](Reference::Null const& null) { m_value = u128(0, null.type.kind() == ValueType::Kind::FunctionReference ? 2 : 3); }); } @@ -177,17 +178,15 @@ public: return bit_cast(m_value.low()); } if constexpr (IsSame) { - switch (m_value.high()) { + switch (m_value.high() & 3) { case 0: - return Reference { Reference::Func { bit_cast(m_value.low()) } }; + return Reference { Reference::Func { bit_cast(m_value.low()), bit_cast(m_value.high()) } }; case 1: return Reference { Reference::Extern { bit_cast(m_value.low()) } }; case 2: return Reference { Reference::Null { ValueType(ValueType::Kind::FunctionReference) } }; case 3: return Reference { Reference::Null { ValueType(ValueType::Kind::ExternReference) } }; - default: - VERIFY_NOT_REACHED(); } } VERIFY_NOT_REACHED(); @@ -341,20 +340,23 @@ private: class WasmFunction { public: - explicit WasmFunction(FunctionType const& type, ModuleInstance const& module, CodeSection::Code const& code) + explicit WasmFunction(FunctionType const& type, ModuleInstance const& instance, Module const& module, CodeSection::Code const& code) : m_type(type) - , m_module(module) + , m_module(module.make_weak_ptr()) + , m_module_instance(instance) , m_code(code) { } auto& type() const { return m_type; } - auto& module() const { return m_module; } + auto& module() const { return m_module_instance; } auto& code() const { return m_code; } + RefPtr module_ref() const { return m_module.strong_ref(); } private: FunctionType m_type; - ModuleInstance const& m_module; + WeakPtr m_module; + ModuleInstance const& m_module_instance; CodeSection::Code const& m_code; }; @@ -554,7 +556,7 @@ class Store { public: Store() = default; - Optional allocate(ModuleInstance&, CodeSection::Code const&, TypeIndex); + Optional allocate(ModuleInstance&, Module const&, CodeSection::Code const&, TypeIndex); Optional allocate(HostFunction&&); Optional allocate(TableType const&); Optional allocate(MemoryType const&); @@ -562,6 +564,7 @@ public: Optional allocate(GlobalType const&, Value); Optional allocate(ValueType const&, Vector); + Module const* get_module_for(FunctionAddress); FunctionInstance* get(FunctionAddress); TableInstance* get(TableAddress); MemoryInstance* get(MemoryAddress); diff --git a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp index 9e441cb0f94..7e61accaaac 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp +++ b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp @@ -864,7 +864,7 @@ ALWAYS_INLINE void BytecodeInterpreter::interpret_instruction(Configuration& con auto index = instruction.arguments().get().value(); auto& functions = configuration.frame().module().functions(); auto address = functions[index]; - configuration.value_stack().append(Value(address.value())); + configuration.value_stack().append(Value(Reference { Reference::Func { address, configuration.store().get_module_for(address) } })); return; } case Instructions::ref_is_null.value(): { diff --git a/Userland/Libraries/LibWasm/Parser/Parser.cpp b/Userland/Libraries/LibWasm/Parser/Parser.cpp index 0a7831618ef..6a1d30fbe81 100644 --- a/Userland/Libraries/LibWasm/Parser/Parser.cpp +++ b/Userland/Libraries/LibWasm/Parser/Parser.cpp @@ -1248,7 +1248,7 @@ ParseResult SectionId::parse(Stream& stream) } } -ParseResult Module::parse(Stream& stream) +ParseResult> Module::parse(Stream& stream) { ScopeLogger logger("Module"sv); u8 buf[4]; @@ -1263,7 +1263,9 @@ ParseResult Module::parse(Stream& stream) return with_eof_check(stream, ParseError::InvalidModuleVersion); auto last_section_id = SectionId::SectionIdKind::Custom; - Module module; + auto module_ptr = make_ref_counted(); + auto& module = *module_ptr; + while (!stream.is_eof()) { auto section_id = TRY(SectionId::parse(stream)); size_t section_size = TRY_READ(stream, LEB128, ParseError::ExpectedSize); @@ -1324,7 +1326,7 @@ ParseResult Module::parse(Stream& stream) return ParseError::SectionSizeMismatch; } - return module; + return module_ptr; } ByteString parse_error_to_byte_string(ParseError error) diff --git a/Userland/Libraries/LibWasm/Types.h b/Userland/Libraries/LibWasm/Types.h index 3aa21516458..fd557ba5967 100644 --- a/Userland/Libraries/LibWasm/Types.h +++ b/Userland/Libraries/LibWasm/Types.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -982,7 +983,8 @@ private: Optional m_count; }; -class Module { +class Module : public RefCounted + , public Weakable { public: enum class ValidationStatus { Unchecked, @@ -1027,7 +1029,7 @@ public: StringView validation_error() const { return *m_validation_error; } void set_validation_error(ByteString error) { m_validation_error = move(error); } - static ParseResult parse(Stream& stream); + static ParseResult> parse(Stream& stream); private: void set_validation_status(ValidationStatus status) { m_validation_status = status; } diff --git a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp index bfe89b0bbf6..717847601b5 100644 --- a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp +++ b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp @@ -430,7 +430,7 @@ JS::ThrowCompletionOr to_webassembly_value(JS::VM& vm, JS::Value va auto& cache = get_cache(*vm.current_realm()); for (auto& entry : cache.function_instances()) { if (entry.value == &function) - return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key } } }; + return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key, cache.abstract_machine().store().get_module_for(entry.key) } } }; } } diff --git a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h index 41b696513b8..19397b50164 100644 --- a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h +++ b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h @@ -29,12 +29,12 @@ WebIDL::ExceptionOr instantiate(JS::VM&, Module const& module_object, namespace Detail { struct CompiledWebAssemblyModule : public RefCounted { - explicit CompiledWebAssemblyModule(Wasm::Module&& module) + explicit CompiledWebAssemblyModule(NonnullRefPtr module) : module(move(module)) { } - Wasm::Module module; + NonnullRefPtr module; }; class WebAssemblyCache { diff --git a/Userland/Utilities/wasm.cpp b/Userland/Utilities/wasm.cpp index 246051628e8..1f64387c72b 100644 --- a/Userland/Utilities/wasm.cpp +++ b/Userland/Utilities/wasm.cpp @@ -491,7 +491,7 @@ static bool pre_interpret_hook(Wasm::Configuration& config, Wasm::InstructionPoi } } -static Optional parse(StringView filename) +static RefPtr parse(StringView filename) { auto result = Core::MappedFile::map(filename); if (result.is_error()) { @@ -603,7 +603,7 @@ ErrorOr serenity_main(Main::Arguments arguments) attempt_instantiate = true; auto parse_result = parse(filename); - if (!parse_result.has_value()) + if (parse_result.is_null()) return 1; g_stdout = TRY(Core::File::standard_output()); @@ -611,7 +611,7 @@ ErrorOr serenity_main(Main::Arguments arguments) if (print && !attempt_instantiate) { Wasm::Printer printer(*g_stdout); - printer.print(parse_result.value()); + printer.print(*parse_result); } if (attempt_instantiate) { @@ -653,14 +653,14 @@ ErrorOr serenity_main(Main::Arguments arguments) // First, resolve the linked modules Vector> linked_instances; - Vector linked_modules; + Vector> linked_modules; for (auto& name : modules_to_link_in) { auto parse_result = parse(name); - if (!parse_result.has_value()) { + if (parse_result.is_null()) { warnln("Failed to parse linked module '{}'", name); return 1; } - linked_modules.append(parse_result.release_value()); + linked_modules.append(parse_result.release_nonnull()); Wasm::Linker linker { linked_modules.last() }; for (auto& instance : linked_instances) linker.link(*instance); @@ -678,7 +678,7 @@ ErrorOr serenity_main(Main::Arguments arguments) linked_instances.append(instantiation_result.release_value()); } - Wasm::Linker linker { parse_result.value() }; + Wasm::Linker linker { *parse_result }; for (auto& instance : linked_instances) linker.link(*instance); @@ -704,7 +704,7 @@ ErrorOr serenity_main(Main::Arguments arguments) for (auto& entry : linker.unresolved_imports()) { if (!entry.type.has()) continue; - auto type = parse_result.value().type_section().types()[entry.type.get().value()]; + auto type = parse_result->type_section().types()[entry.type.get().value()]; auto address = machine.store().allocate(Wasm::HostFunction( [name = entry.name, type = type](auto&, auto& arguments) -> Wasm::Result { StringBuilder argument_builder; @@ -744,7 +744,7 @@ ErrorOr serenity_main(Main::Arguments arguments) print_link_error(link_result.error()); return 1; } - auto result = machine.instantiate(parse_result.value(), link_result.release_value()); + auto result = machine.instantiate(*parse_result, link_result.release_value()); if (result.is_error()) { warnln("Module instantiation failed: {}", result.error().error); return 1;