Kaynağa Gözat

LibWasm: Load and instantiate tables

This commit is a fairly large refactor, mainly because it unified the
two different ways that existed to represent references.
Now Reference values are also a kind of value.
It also implements a printer for values/references instead of copying
the implementation everywhere.
Ali Mohammad Pur 4 yıl önce
ebeveyn
işleme
be62e4d1d7

+ 9 - 7
Tests/LibWasm/test-wasm.cpp

@@ -185,16 +185,16 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
             arguments.append(Wasm::Value(static_cast<double>(value)));
             break;
         case Wasm::ValueType::Kind::FunctionReference:
-            arguments.append(Wasm::Value(Wasm::FunctionAddress { static_cast<u64>(value) }));
+            arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { static_cast<u64>(value) } }));
             break;
         case Wasm::ValueType::Kind::ExternReference:
-            arguments.append(Wasm::Value(Wasm::ExternAddress { static_cast<u64>(value) }));
+            arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { static_cast<u64>(value) } }));
             break;
         case Wasm::ValueType::Kind::NullFunctionReference:
-            arguments.append(Wasm::Value(Wasm::Value::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) }));
+            arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) } }));
             break;
         case Wasm::ValueType::Kind::NullExternReference:
-            arguments.append(Wasm::Value(Wasm::Value::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) }));
+            arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) } }));
             break;
         }
     }
@@ -211,8 +211,10 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
     JS::Value return_value;
     result.values().first().value().visit(
         [&](const auto& value) { return_value = JS::Value(static_cast<double>(value)); },
-        [&](const Wasm::FunctionAddress& index) { return_value = JS::Value(static_cast<double>(index.value())); },
-        [&](const Wasm::ExternAddress& index) { return_value = JS::Value(static_cast<double>(index.value())); },
-        [&](const Wasm::Value::Null&) { return_value = JS::js_null(); });
+        [&](const Wasm::Reference& reference) {
+            reference.ref().visit(
+                [&](const Wasm::Reference::Null&) { return_value = JS::js_null(); },
+                [&](const auto& ref) { return_value = JS::Value(static_cast<double>(ref.address.value())); });
+        });
     return return_value;
 }

+ 149 - 13
Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp

@@ -52,6 +52,13 @@ Optional<GlobalAddress> Store::allocate(const GlobalType& type, Value value)
     return address;
 }
 
+Optional<ElementAddress> Store::allocate(const ValueType& type, Vector<Reference> references)
+{
+    ElementAddress address { m_elements.size() };
+    m_elements.append(ElementInstance { type, move(references) });
+    return address;
+}
+
 FunctionInstance* Store::get(FunctionAddress address)
 {
     auto value = address.value();
@@ -84,6 +91,14 @@ GlobalInstance* Store::get(GlobalAddress address)
     return &m_globals[value];
 }
 
+ElementInstance* Store::get(ElementAddress address)
+{
+    auto value = address.value();
+    if (m_elements.size() <= value)
+        return nullptr;
+    return &m_elements[value];
+}
+
 InstantiationResult AbstractMachine::instantiate(const Module& module, Vector<ExternValue> externs)
 {
     auto main_module_instance_pointer = make<ModuleInstance>();
@@ -97,6 +112,7 @@ InstantiationResult AbstractMachine::instantiate(const Module& module, Vector<Ex
     // FIXME: Validate stuff
 
     Vector<Value> global_values;
+    Vector<Vector<Reference>> elements;
     ModuleInstance auxiliary_instance;
 
     // FIXME: Check that imports/extern match
@@ -118,7 +134,6 @@ InstantiationResult AbstractMachine::instantiate(const Module& module, Vector<Ex
                 1,
             });
             auto result = config.execute(interpreter);
-            // What if this traps?
             if (result.is_trap())
                 instantiation_result = InstantiationError { "Global value construction trapped" };
             else
@@ -126,16 +141,121 @@ InstantiationResult AbstractMachine::instantiate(const Module& module, Vector<Ex
         }
     });
 
-    if (auto result = allocate_all(module, main_module_instance, externs, global_values); result.has_value()) {
+    if (instantiation_result.has_value())
+        return instantiation_result.release_value();
+
+    if (auto result = allocate_all_initial_phase(module, main_module_instance, externs, global_values); result.has_value())
         return result.release_value();
-    }
 
-    module.for_each_section_of_type<ElementSection>([&](const ElementSection&) {
-        // FIXME: Implement me
-        // https://webassembly.github.io/spec/core/bikeshed/#element-segments%E2%91%A0
-        // https://webassembly.github.io/spec/core/bikeshed/#instantiation%E2%91%A1 step 9
+    module.for_each_section_of_type<ElementSection>([&](const ElementSection& section) {
+        for (auto& segment : section.segments()) {
+            Vector<Reference> references;
+            for (auto& entry : segment.init) {
+                Configuration config { m_store };
+                config.set_frame(Frame {
+                    main_module_instance,
+                    Vector<Value> {},
+                    entry,
+                    entry.instructions().size(),
+                });
+                auto result = config.execute(interpreter);
+                if (result.is_trap()) {
+                    instantiation_result = InstantiationError { "Element construction trapped" };
+                    return IterationDecision::Continue;
+                }
+
+                for (auto& value : result.values()) {
+                    if (!value.type().is_reference()) {
+                        instantiation_result = InstantiationError { "Evaluated element entry is not a reference" };
+                        return IterationDecision::Continue;
+                    }
+                    auto reference = value.to<Reference>();
+                    if (!reference.has_value()) {
+                        instantiation_result = InstantiationError { "Evaluated element entry does not contain a reference" };
+                        return IterationDecision::Continue;
+                    }
+                    // FIXME: type-check the reference.
+                    references.prepend(reference.release_value());
+                }
+            }
+            elements.append(move(references));
+        }
+
+        return IterationDecision::Continue;
     });
 
+    if (instantiation_result.has_value())
+        return instantiation_result.release_value();
+
+    if (auto result = allocate_all_final_phase(module, main_module_instance, elements); result.has_value())
+        return result.release_value();
+
+    module.for_each_section_of_type<ElementSection>([&](const ElementSection& section) {
+        size_t index = 0;
+        for (auto& segment : section.segments()) {
+            auto current_index = index;
+            ++index;
+            auto active_ptr = segment.mode.get_pointer<ElementSection::Active>();
+            if (!active_ptr)
+                continue;
+            if (active_ptr->index.value() != 0) {
+                instantiation_result = InstantiationError { "Non-zero table referenced by active element segment" };
+                return IterationDecision::Break;
+            }
+            Configuration config { m_store };
+            config.set_frame(Frame {
+                main_module_instance,
+                Vector<Value> {},
+                active_ptr->expression,
+                1,
+            });
+            auto result = config.execute(interpreter);
+            if (result.is_trap()) {
+                instantiation_result = InstantiationError { "Element section initialisation trapped" };
+                return IterationDecision::Break;
+            }
+            auto d = result.values().first().to<i32>();
+            if (!d.has_value()) {
+                instantiation_result = InstantiationError { "Element section initialisation returned invalid table initial offset" };
+                return IterationDecision::Break;
+            }
+            if (main_module_instance.tables().size() < 1) {
+                instantiation_result = InstantiationError { "Element section initialisation references nonexistent table" };
+                return IterationDecision::Break;
+            }
+            auto table_instance = m_store.get(main_module_instance.tables()[0]);
+            if (current_index >= main_module_instance.elements().size()) {
+                instantiation_result = InstantiationError { "Invalid element referenced by active element segment" };
+                return IterationDecision::Break;
+            }
+            auto elem_instance = m_store.get(main_module_instance.elements()[current_index]);
+            if (!table_instance || !elem_instance) {
+                instantiation_result = InstantiationError { "Invalid element referenced by active element segment" };
+                return IterationDecision::Break;
+            }
+
+            auto total_required_size = elem_instance->references().size() + d.value();
+
+            if (table_instance->type().limits().max().value_or(total_required_size) < total_required_size) {
+                instantiation_result = InstantiationError { "Table limit overflow in active element segment" };
+                return IterationDecision::Break;
+            }
+
+            if (table_instance->elements().size() < total_required_size)
+                table_instance->elements().resize(total_required_size);
+
+            size_t i = 0;
+            for (auto it = elem_instance->references().begin(); it < elem_instance->references().end(); ++i, ++it) {
+                table_instance->elements()[i + d.value()] = *it;
+            }
+        }
+
+        return IterationDecision::Continue;
+    });
+
+    if (instantiation_result.has_value())
+        return instantiation_result.release_value();
+
     module.for_each_section_of_type<DataSection>([&](const DataSection& data_section) {
         for (auto& segment : data_section.data()) {
             segment.value().visit(
@@ -148,12 +268,14 @@ InstantiationResult AbstractMachine::instantiate(const Module& module, Vector<Ex
                         1,
                     });
                     auto result = config.execute(interpreter);
+                    if (result.is_trap()) {
+                        instantiation_result = InstantiationError { "Data section initialisation trapped" };
+                        return;
+                    }
                     size_t offset = 0;
                     result.values().first().value().visit(
                         [&](const auto& value) { offset = value; },
-                        [&](const FunctionAddress&) { instantiation_result = InstantiationError { "Data segment offset returned an address" }; },
-                        [&](const ExternAddress&) { instantiation_result = InstantiationError { "Data segment offset returned an address" }; },
-                        [&](const Value::Null&) { instantiation_result = InstantiationError { "Data segment offset returned a null reference" }; });
+                        [&](const Reference&) { instantiation_result = InstantiationError { "Data segment offset returned a reference" }; });
                     if (instantiation_result.has_value() && instantiation_result->is_error())
                         return;
                     if (main_module_instance.memories().size() <= data.index.value()) {
@@ -193,7 +315,7 @@ InstantiationResult AbstractMachine::instantiate(const Module& module, Vector<Ex
     return InstantiationResult { move(main_module_instance_pointer) };
 }
 
-Optional<InstantiationError> AbstractMachine::allocate_all(const Module& module, ModuleInstance& module_instance, Vector<ExternValue>& externs, Vector<Value>& global_values)
+Optional<InstantiationError> AbstractMachine::allocate_all_initial_phase(const Module& module, ModuleInstance& module_instance, Vector<ExternValue>& externs, Vector<Value>& global_values)
 {
     Optional<InstantiationError> result;
 
@@ -232,13 +354,12 @@ Optional<InstantiationError> AbstractMachine::allocate_all(const Module& module,
     module.for_each_section_of_type<GlobalSection>([&](const GlobalSection& section) {
         size_t index = 0;
         for (auto& entry : section.entries()) {
-            auto address = m_store.allocate(entry.type(), global_values[index]);
+            auto address = m_store.allocate(entry.type(), move(global_values[index]));
             VERIFY(address.has_value());
             module_instance.globals().append(*address);
             index++;
         }
     });
-
     module.for_each_section_of_type<ExportSection>([&](const ExportSection& section) {
         for (auto& entry : section.entries()) {
             Variant<FunctionAddress, TableAddress, MemoryAddress, GlobalAddress, Empty> address { Empty {} };
@@ -283,6 +404,21 @@ Optional<InstantiationError> AbstractMachine::allocate_all(const Module& module,
     return result;
 }
 
+Optional<InstantiationError> AbstractMachine::allocate_all_final_phase(const Module& module, ModuleInstance& module_instance, Vector<Vector<Reference>>& elements)
+{
+    module.for_each_section_of_type<ElementSection>([&](const ElementSection& section) {
+        size_t index = 0;
+        for (auto& segment : section.segments()) {
+            auto address = m_store.allocate(segment.type, move(elements[index]));
+            VERIFY(address.has_value());
+            module_instance.elements().append(*address);
+            index++;
+        }
+    });
+
+    return {};
+}
+
 Result AbstractMachine::invoke(FunctionAddress address, Vector<Value> arguments)
 {
     BytecodeInterpreter interpreter;

+ 71 - 49
Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.h

@@ -33,10 +33,35 @@ TYPEDEF_DISTINCT_NUMERIC_GENERAL(u64, true, true, false, false, false, true, Fun
 TYPEDEF_DISTINCT_NUMERIC_GENERAL(u64, true, true, false, false, false, true, ExternAddress);
 TYPEDEF_DISTINCT_NUMERIC_GENERAL(u64, true, true, false, false, false, true, TableAddress);
 TYPEDEF_DISTINCT_NUMERIC_GENERAL(u64, true, true, false, false, false, true, GlobalAddress);
+TYPEDEF_DISTINCT_NUMERIC_GENERAL(u64, true, true, false, false, false, true, ElementAddress);
 TYPEDEF_DISTINCT_NUMERIC_GENERAL(u64, true, true, false, false, false, true, MemoryAddress);
 
 // FIXME: These should probably be made generic/virtual if/when we decide to do something more
 //        fancy than just a dumb interpreter.
+class Reference {
+public:
+    struct Null {
+        ValueType type;
+    };
+    struct Func {
+        FunctionAddress address;
+    };
+    struct Extern {
+        ExternAddress address;
+    };
+
+    using RefType = Variant<Null, Func, Extern>;
+    explicit Reference(RefType ref)
+        : m_ref(move(ref))
+    {
+    }
+
+    auto& ref() const { return m_ref; }
+
+private:
+    RefType m_ref;
+};
+
 class Value {
 public:
     Value()
@@ -45,11 +70,7 @@ public:
     {
     }
 
-    struct Null {
-        ValueType type;
-    };
-
-    using AnyValueType = Variant<i32, i64, float, double, FunctionAddress, ExternAddress, Null>;
+    using AnyValueType = Variant<i32, i64, float, double, Reference>;
     explicit Value(AnyValueType value)
         : m_value(move(value))
         , m_type(ValueType::I32)
@@ -62,12 +83,12 @@ public:
             m_type = ValueType { ValueType::F32 };
         else if (m_value.has<double>())
             m_type = ValueType { ValueType::F64 };
-        else if (m_value.has<FunctionAddress>())
+        else if (m_value.has<Reference>() && m_value.get<Reference>().ref().has<Reference::Func>())
             m_type = ValueType { ValueType::FunctionReference };
-        else if (m_value.has<ExternAddress>())
+        else if (m_value.has<Reference>() && m_value.get<Reference>().ref().has<Reference::Extern>())
             m_type = ValueType { ValueType::ExternReference };
-        else if (m_value.has<Null>())
-            m_type = ValueType { m_value.get<Null>().type.kind() == ValueType::ExternReference ? ValueType::NullExternReference : ValueType::NullFunctionReference };
+        else if (m_value.has<Reference>())
+            m_type = m_value.get<Reference>().ref().get<Reference::Null>().type;
         else
             VERIFY_NOT_REACHED();
     }
@@ -79,10 +100,10 @@ public:
     {
         switch (type.kind()) {
         case ValueType::Kind::ExternReference:
-            m_value = ExternAddress { bit_cast<u64>(raw_value) };
+            m_value = Reference { Reference::Extern { { bit_cast<u64>(raw_value) } } };
             break;
         case ValueType::Kind::FunctionReference:
-            m_value = FunctionAddress { bit_cast<u64>(raw_value) };
+            m_value = Reference { Reference::Func { { bit_cast<u64>(raw_value) } } };
             break;
         case ValueType::Kind::I32:
             m_value = static_cast<i32>(bit_cast<i64>(raw_value));
@@ -98,11 +119,11 @@ public:
             break;
         case ValueType::Kind::NullFunctionReference:
             VERIFY(raw_value == 0);
-            m_value = Null { ValueType(ValueType::Kind::FunctionReference) };
+            m_value = Reference { Reference::Null { ValueType(ValueType::Kind::FunctionReference) } };
             break;
         case ValueType::Kind::NullExternReference:
             VERIFY(raw_value == 0);
-            m_value = Null { ValueType(ValueType::Kind::ExternReference) };
+            m_value = Reference { Reference::Null { ValueType(ValueType::Kind::ExternReference) } };
             break;
         default:
             VERIFY_NOT_REACHED();
@@ -146,17 +167,19 @@ public:
                 else if constexpr (!IsFloatingPoint<T> && IsSame<decltype(value), MakeSigned<T>>)
                     result = value;
             },
-            [&](const FunctionAddress& address) {
-                if constexpr (IsSame<T, FunctionAddress>)
-                    result = address;
-            },
-            [&](const ExternAddress& address) {
-                if constexpr (IsSame<T, ExternAddress>)
-                    result = address;
-            },
-            [&](const Null& null) {
-                if constexpr (IsSame<T, Null>)
-                    result = null;
+            [&](const Reference& value) {
+                if constexpr (IsSame<T, Reference>) {
+                    result = value;
+                } else if constexpr (IsSame<T, Reference::Func>) {
+                    if (auto ptr = value.ref().template get_pointer<Reference::Func>())
+                        result = *ptr;
+                } else if constexpr (IsSame<T, Reference::Extern>) {
+                    if (auto ptr = value.ref().template get_pointer<Reference::Extern>())
+                        result = *ptr;
+                } else if constexpr (IsSame<T, Reference::Null>) {
+                    if (auto ptr = value.ref().template get_pointer<Reference::Null>())
+                        result = *ptr;
+                }
             });
         return result;
     }
@@ -233,6 +256,7 @@ public:
     auto& tables() const { return m_tables; }
     auto& memories() const { return m_memories; }
     auto& globals() const { return m_globals; }
+    auto& elements() const { return m_elements; }
     auto& exports() const { return m_exports; }
 
     auto& types() { return m_types; }
@@ -240,6 +264,7 @@ public:
     auto& tables() { return m_tables; }
     auto& memories() { return m_memories; }
     auto& globals() { return m_globals; }
+    auto& elements() { return m_elements; }
     auto& exports() { return m_exports; }
 
 private:
@@ -248,6 +273,7 @@ private:
     Vector<TableAddress> m_tables;
     Vector<MemoryAddress> m_memories;
     Vector<GlobalAddress> m_globals;
+    Vector<ElementAddress> m_elements;
     Vector<ExportInstance> m_exports;
 };
 
@@ -288,30 +314,6 @@ private:
 
 using FunctionInstance = Variant<WasmFunction, HostFunction>;
 
-class Reference {
-public:
-    struct Null {
-        ValueType type;
-    };
-    struct Func {
-        FunctionAddress address;
-    };
-    struct Extern {
-        ExternAddress address;
-    };
-
-    using RefType = Variant<Null, Func, Extern>;
-    explicit Reference(RefType ref)
-        : m_ref(move(ref))
-    {
-    }
-
-    auto& ref() const { return m_ref; }
-
-private:
-    RefType m_ref;
-};
-
 class TableInstance {
 public:
     explicit TableInstance(const TableType& type, Vector<Optional<Reference>> elements)
@@ -384,6 +386,22 @@ private:
     Value m_value;
 };
 
+class ElementInstance {
+public:
+    explicit ElementInstance(ValueType type, Vector<Reference> references)
+        : m_type(move(type))
+        , m_references(move(references))
+    {
+    }
+
+    auto& type() const { return m_type; }
+    auto& references() const { return m_references; }
+
+private:
+    ValueType m_type;
+    Vector<Reference> m_references;
+};
+
 class Store {
 public:
     Store() = default;
@@ -393,17 +411,20 @@ public:
     Optional<TableAddress> allocate(const TableType&);
     Optional<MemoryAddress> allocate(const MemoryType&);
     Optional<GlobalAddress> allocate(const GlobalType&, Value);
+    Optional<ElementAddress> allocate(const ValueType&, Vector<Reference>);
 
     FunctionInstance* get(FunctionAddress);
     TableInstance* get(TableAddress);
     MemoryInstance* get(MemoryAddress);
     GlobalInstance* get(GlobalAddress);
+    ElementInstance* get(ElementAddress);
 
 private:
     Vector<FunctionInstance> m_functions;
     Vector<TableInstance> m_tables;
     Vector<MemoryInstance> m_memories;
     Vector<GlobalInstance> m_globals;
+    Vector<ElementInstance> m_elements;
 };
 
 class Label {
@@ -479,7 +500,8 @@ public:
     auto& store() { return m_store; }
 
 private:
-    Optional<InstantiationError> allocate_all(const Module&, ModuleInstance&, Vector<ExternValue>&, Vector<Value>& global_values);
+    Optional<InstantiationError> allocate_all_initial_phase(const Module&, ModuleInstance&, Vector<ExternValue>&, Vector<Value>& global_values);
+    Optional<InstantiationError> allocate_all_final_phase(const Module&, ModuleInstance&, Vector<Vector<Reference>>& elements);
     Store m_store;
 };
 

+ 14 - 18
Userland/Libraries/LibWasm/AbstractMachine/Configuration.cpp

@@ -6,6 +6,7 @@
 
 #include <LibWasm/AbstractMachine/Configuration.h>
 #include <LibWasm/AbstractMachine/Interpreter.h>
+#include <LibWasm/Printer/Printer.h>
 
 namespace Wasm {
 
@@ -24,6 +25,9 @@ Optional<Label> Configuration::nth_label(size_t i)
 
 void Configuration::unwind(Badge<CallFrameHandle>, const CallFrameHandle& frame_handle)
 {
+    if (m_stack.size() == frame_handle.stack_size && frame_handle.frame_index == m_current_frame_index)
+        return;
+
     VERIFY(m_stack.size() > frame_handle.stack_size);
     m_stack.entries().remove(frame_handle.stack_size, m_stack.size() - frame_handle.stack_size);
     m_current_frame_index = frame_handle.frame_index;
@@ -82,29 +86,21 @@ Result Configuration::execute(Interpreter& interpreter)
 
 void Configuration::dump_stack()
 {
+    auto print_value = []<typename... Ts>(CheckedFormatString<Ts...> format, Ts... vs)
+    {
+        DuplexMemoryStream memory_stream;
+        Printer { memory_stream }.print(vs...);
+        dbgln(format.view(), StringView(memory_stream.copy_into_contiguous_buffer()).trim_whitespace());
+    };
     for (const auto& entry : stack().entries()) {
         entry.visit(
-            [](const Value& v) {
-                v.value().visit([]<typename T>(const T& v) {
-                    if constexpr (IsIntegral<T> || IsFloatingPoint<T>)
-                        dbgln("    {}", v);
-                    else if constexpr (IsSame<Value::Null, T>)
-                        dbgln("    *null");
-                    else
-                        dbgln("    *{}", v.value());
-                });
+            [&](const Value& v) {
+                print_value("    {}", v);
             },
-            [](const Frame& f) {
+            [&](const Frame& f) {
                 dbgln("    frame({})", f.arity());
                 for (auto& local : f.locals()) {
-                    local.value().visit([]<typename T>(const T& v) {
-                        if constexpr (IsIntegral<T> || IsFloatingPoint<T>)
-                            dbgln("        {}", v);
-                        else if constexpr (IsSame<Value::Null, T>)
-                            dbgln("    *null");
-                        else
-                            dbgln("        *{}", v.value());
-                    });
+                    print_value("        {}", local);
                 }
             },
             [](const Label& l) {

+ 7 - 12
Userland/Libraries/LibWasm/AbstractMachine/Interpreter.cpp

@@ -32,6 +32,7 @@ namespace Wasm {
 
 void BytecodeInterpreter::interpret(Configuration& configuration)
 {
+    m_do_trap = false;
     auto& instructions = configuration.frame().expression().instructions();
     auto max_ip_value = InstructionPointer { instructions.size() };
     auto& current_ip_value = configuration.ip();
@@ -534,17 +535,11 @@ void BytecodeInterpreter::interpret(Configuration& configuration, InstructionPoi
         auto table_instance = configuration.store().get(table_address);
         auto index = configuration.stack().pop().get<Value>().to<i32>();
         TRAP_IF_NOT(index.has_value());
-        if (index.value() < 0 || static_cast<size_t>(index.value()) >= table_instance->elements().size()) {
-            dbgln("LibWasm: Element access out of bounds, expected {0} > 0 and {0} < {1}", index.value(), table_instance->elements().size());
-            m_do_trap = true;
-            return;
-        }
+        TRAP_IF_NOT(index.value() >= 0);
+        TRAP_IF_NOT(static_cast<size_t>(index.value()) < table_instance->elements().size());
         auto element = table_instance->elements()[index.value()];
-        if (!element.has_value() || !element->ref().has<Reference::Func>()) {
-            dbgln("LibWasm: call_indirect attempted with invalid address element (not a function)");
-            m_do_trap = true;
-            return;
-        }
+        TRAP_IF_NOT(element.has_value());
+        TRAP_IF_NOT(element->ref().has<Reference::Func>());
         auto address = element->ref().get<Reference::Func>().address;
         dbgln_if(WASM_TRACE_DEBUG, "call_indirect({} -> {})", index.value(), address.value());
         call_address(configuration, address);
@@ -652,7 +647,7 @@ void BytecodeInterpreter::interpret(Configuration& configuration, InstructionPoi
     case Instructions::ref_null.value(): {
         auto type = instruction.arguments().get<ValueType>();
         TRAP_IF_NOT(type.is_reference());
-        configuration.stack().push(Value(Value::Null { type }));
+        configuration.stack().push(Value(Reference(Reference::Null { type })));
         return;
     };
     case Instructions::ref_func.value(): {
@@ -667,7 +662,7 @@ void BytecodeInterpreter::interpret(Configuration& configuration, InstructionPoi
         auto top = configuration.stack().peek().get_pointer<Value>();
         TRAP_IF_NOT(top);
         TRAP_IF_NOT(top->type().is_reference());
-        auto is_null = top->to<Value::Null>().has_value();
+        auto is_null = top->to<Reference::Null>().has_value();
         configuration.stack().peek() = Value(ValueType(ValueType::I32), static_cast<u64>(is_null ? 1 : 0));
         return;
     }

+ 17 - 11
Userland/Libraries/LibWasm/Parser/Parser.cpp

@@ -948,7 +948,7 @@ ParseResult<ElementSection::SegmentType0> ElementSection::SegmentType0::parse(In
     if (indices.is_error())
         return indices.error();
 
-    return SegmentType0 { ValueType(ValueType::FunctionReference), indices.release_value(), Active { 0, expression.release_value() } };
+    return SegmentType0 { indices.release_value(), Active { 0, expression.release_value() } };
 }
 
 ParseResult<ElementSection::SegmentType1> ElementSection::SegmentType1::parse(InputStream& stream)
@@ -963,7 +963,7 @@ ParseResult<ElementSection::SegmentType1> ElementSection::SegmentType1::parse(In
     if (indices.is_error())
         return indices.error();
 
-    return SegmentType1 { ValueType(ValueType::FunctionReference), indices.release_value() };
+    return SegmentType1 { indices.release_value() };
 }
 
 ParseResult<ElementSection::SegmentType2> ElementSection::SegmentType2::parse(InputStream& stream)
@@ -1008,7 +1008,7 @@ ParseResult<ElementSection::SegmentType7> ElementSection::SegmentType7::parse(In
     return ParseError::NotImplemented;
 }
 
-ParseResult<ElementSection::AnyElementType> ElementSection::Element::parse(InputStream& stream)
+ParseResult<ElementSection::Element> ElementSection::Element::parse(InputStream& stream)
 {
     ScopeLogger<WASM_BINPARSER_DEBUG> logger("Element");
     u8 tag;
@@ -1021,49 +1021,55 @@ ParseResult<ElementSection::AnyElementType> ElementSection::Element::parse(Input
         if (auto result = SegmentType0::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            Vector<Instruction> instructions;
+            for (auto& index : result.value().function_indices)
+                instructions.empend(Instructions::ref_func, index);
+            return Element { ValueType(ValueType::FunctionReference), { Expression { move(instructions) } }, move(result.value().mode) };
         }
     case 0x01:
         if (auto result = SegmentType1::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            Vector<Instruction> instructions;
+            for (auto& index : result.value().function_indices)
+                instructions.empend(Instructions::ref_func, index);
+            return Element { ValueType(ValueType::FunctionReference), { Expression { move(instructions) } }, Passive {} };
         }
     case 0x02:
         if (auto result = SegmentType2::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            return ParseError::NotImplemented;
         }
     case 0x03:
         if (auto result = SegmentType3::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            return ParseError::NotImplemented;
         }
     case 0x04:
         if (auto result = SegmentType4::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            return ParseError::NotImplemented;
         }
     case 0x05:
         if (auto result = SegmentType5::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            return ParseError::NotImplemented;
         }
     case 0x06:
         if (auto result = SegmentType6::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            return ParseError::NotImplemented;
         }
     case 0x07:
         if (auto result = SegmentType7::parse(stream); result.is_error()) {
             return result.error();
         } else {
-            return AnyElementType { result.release_value() };
+            return ParseError::NotImplemented;
         }
     default:
         return ParseError::InvalidTag;

+ 65 - 37
Userland/Libraries/LibWasm/Printer/Printer.cpp

@@ -6,6 +6,7 @@
 
 #include <AK/HashMap.h>
 #include <AK/TemporaryChange.h>
+#include <LibWasm/AbstractMachine/AbstractMachine.h>
 #include <LibWasm/Printer/Printer.h>
 
 namespace Wasm {
@@ -170,53 +171,52 @@ void Printer::print(const Wasm::ElementSection& section)
     {
         TemporaryChange change { m_indent, m_indent + 1 };
         for (auto& entry : section.segments())
-            entry.visit([this](auto& segment) { print(segment); });
+            print(entry);
     }
     print_indent();
     print(")\n");
 }
 
-void Printer::print(const Wasm::ElementSection::SegmentType0&)
-{
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType1& segment)
+void Printer::print(const Wasm::ElementSection::Element& element)
 {
     print_indent();
-    print("(element segment kind 1\n");
+    print("(element ");
+    {
+        TemporaryChange<size_t> change { m_indent, 0 };
+        print(element.type);
+    }
     {
         TemporaryChange change { m_indent, m_indent + 1 };
-        for (auto& index : segment.function_indices) {
-            print_indent();
-            print("(function index {})\n", index.value());
+        print_indent();
+        print("(init\n");
+        {
+            TemporaryChange change { m_indent, m_indent + 1 };
+            for (auto& entry : element.init)
+                print(entry);
         }
+        print_indent();
+        print(")\n");
+        print_indent();
+        print("(mode ");
+        element.mode.visit(
+            [this](const ElementSection::Active& active) {
+                print("\n");
+                {
+                    TemporaryChange change { m_indent, m_indent + 1 };
+                    print_indent();
+                    print("(active index {}\n", active.index.value());
+                    {
+                        print(active.expression);
+                    }
+                    print_indent();
+                    print(")\n");
+                }
+                print_indent();
+            },
+            [this](const ElementSection::Passive&) { print("passive"); },
+            [this](const ElementSection::Declarative&) { print("declarative"); });
+        print(")\n");
     }
-    print_indent();
-    print(")\n");
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType2&)
-{
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType3&)
-{
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType4&)
-{
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType5&)
-{
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType6&)
-{
-}
-
-void Printer::print(const Wasm::ElementSection::SegmentType7&)
-{
 }
 
 void Printer::print(const Wasm::ExportSection& section)
@@ -312,7 +312,7 @@ void Printer::print(const Wasm::FunctionType& type)
                 print(param);
         }
         print_indent();
-        print("\n");
+        print(")\n");
     }
     {
         TemporaryChange change { m_indent, m_indent + 1 };
@@ -622,6 +622,34 @@ void Printer::print(const Wasm::ValueType& type)
     print_indent();
     print("(type {})\n", ValueType::kind_name(type.kind()));
 }
+
+void Printer::print(const Wasm::Value& value)
+{
+    print_indent();
+    print("{} ", value.value().visit([&]<typename T>(const T& value) {
+        if constexpr (IsSame<Wasm::Reference, T>)
+            return String::formatted(
+                "addr({})",
+                value.ref().visit(
+                    [](const Wasm::Reference::Null&) { return String("null"); },
+                    [](const auto& ref) { return String::number(ref.address.value()); }));
+        else
+            return String::formatted("{}", value);
+    }));
+    TemporaryChange<size_t> change { m_indent, 0 };
+    print(value.type());
+}
+
+void Printer::print(const Wasm::Reference& value)
+{
+    print_indent();
+    print(
+        "addr({})\n",
+        value.ref().visit(
+            [](const Wasm::Reference::Null&) { return String("null"); },
+            [](const auto& ref) { return String::number(ref.address.value()); }));
+}
+
 }
 
 HashMap<Wasm::OpCode, String> Wasm::Names::instruction_names {

+ 5 - 8
Userland/Libraries/LibWasm/Printer/Printer.h

@@ -10,6 +10,8 @@
 
 namespace Wasm {
 
+class Value;
+
 String instruction_name(const OpCode& opcode);
 
 struct Printer {
@@ -28,14 +30,7 @@ struct Printer {
     void print(const Wasm::DataSection&);
     void print(const Wasm::DataSection::Data&);
     void print(const Wasm::ElementSection&);
-    void print(const Wasm::ElementSection::SegmentType0&);
-    void print(const Wasm::ElementSection::SegmentType1&);
-    void print(const Wasm::ElementSection::SegmentType2&);
-    void print(const Wasm::ElementSection::SegmentType3&);
-    void print(const Wasm::ElementSection::SegmentType4&);
-    void print(const Wasm::ElementSection::SegmentType5&);
-    void print(const Wasm::ElementSection::SegmentType6&);
-    void print(const Wasm::ElementSection::SegmentType7&);
+    void print(const Wasm::ElementSection::Element&);
     void print(const Wasm::ExportSection&);
     void print(const Wasm::ExportSection::Export&);
     void print(const Wasm::Expression&);
@@ -54,6 +49,7 @@ struct Printer {
     void print(const Wasm::MemoryType&);
     void print(const Wasm::Module&);
     void print(const Wasm::Module::Function&);
+    void print(const Wasm::Reference&);
     void print(const Wasm::StartSection&);
     void print(const Wasm::StartSection::StartFunction&);
     void print(const Wasm::TableSection&);
@@ -61,6 +57,7 @@ struct Printer {
     void print(const Wasm::TableType&);
     void print(const Wasm::TypeSection&);
     void print(const Wasm::ValueType&);
+    void print(const Wasm::Value&);
 
 private:
     void print_indent();

+ 8 - 15
Userland/Libraries/LibWasm/Types.h

@@ -763,13 +763,12 @@ public:
         // FIXME: Implement me!
         static ParseResult<SegmentType0> parse(InputStream& stream);
 
-        ValueType type;
         Vector<FunctionIndex> function_indices;
         Active mode;
     };
     struct SegmentType1 {
         static ParseResult<SegmentType1> parse(InputStream& stream);
-        ValueType type;
+
         Vector<FunctionIndex> function_indices;
     };
     struct SegmentType2 {
@@ -797,23 +796,17 @@ public:
         static ParseResult<SegmentType7> parse(InputStream& stream);
     };
 
-    using AnyElementType = Variant<
-        SegmentType0,
-        SegmentType1,
-        SegmentType2,
-        SegmentType3,
-        SegmentType4,
-        SegmentType5,
-        SegmentType6,
-        SegmentType7>;
-
     struct Element {
-        static ParseResult<AnyElementType> parse(InputStream&);
+        static ParseResult<Element> parse(InputStream&);
+
+        ValueType type;
+        Vector<Expression> init;
+        Variant<Active, Passive, Declarative> mode;
     };
 
     static constexpr u8 section_id = 9;
 
-    explicit ElementSection(Vector<AnyElementType> segs)
+    explicit ElementSection(Vector<Element> segs)
         : m_segments(move(segs))
     {
     }
@@ -823,7 +816,7 @@ public:
     static ParseResult<ElementSection> parse(InputStream& stream);
 
 private:
-    Vector<AnyElementType> m_segments;
+    Vector<Element> m_segments;
 };
 
 class Locals {

+ 4 - 21
Userland/Utilities/wasm.cpp

@@ -204,17 +204,8 @@ static bool pre_interpret_hook(Wasm::Configuration& config, Wasm::InstructionPoi
             if (!result.values().is_empty())
                 warnln("Returned:");
             for (auto& value : result.values()) {
-                auto str = value.value().visit(
-                    [&]<typename T>(const T& value) {
-                        if constexpr (requires { value.value(); })
-                            return String::formatted("  -> addr{} ", value.value());
-                        else if constexpr (IsSame<Wasm::Value::Null, T>)
-                            return String::formatted("  ->addr(null)");
-                        else
-                            return String::formatted("  -> {} ", value);
-                    });
-                g_stdout.write(str.bytes());
-                g_printer.print(value.type());
+                g_stdout.write("  -> "sv.bytes());
+                g_printer.print(value);
             }
             continue;
         }
@@ -541,17 +532,9 @@ int main(int argc, char* argv[])
             if (!result.values().is_empty())
                 warnln("Returned:");
             for (auto& value : result.values()) {
-                value.value().visit(
-                    [&]<typename T>(const T& value) {
-                        if constexpr (requires { value.value(); })
-                            out("  -> addr{} ", value.value());
-                        else if constexpr (IsSame<Wasm::Value::Null, T>)
-                            out("  ->addr(null)");
-                        else
-                            out("  -> {} ", value);
-                    });
                 Wasm::Printer printer { stream };
-                printer.print(value.type());
+                g_stdout.write("  -> "sv.bytes());
+                g_printer.print(value);
             }
         }
     }