瀏覽代碼

LibWasm: Check for correct NaN bit patterns in tests

Some spec-tests check the bit pattern of a returned `NaN` (i.e.
`nan:canonical`, `nan:arithmetic`, or something like `nan:0x200000`).
Previously, we just accepted any `NaN`.
Diego 1 年之前
父節點
當前提交
524e09dda1
共有 2 個文件被更改,包括 85 次插入52 次删除
  1. 50 28
      Meta/generate-libwasm-spec-test.py
  2. 35 24
      Tests/LibWasm/test-wasm.cpp

+ 50 - 28
Meta/generate-libwasm-spec-test.py

@@ -1,5 +1,4 @@
 import json
-import math
 import sys
 import struct
 import subprocess
@@ -89,6 +88,19 @@ Command = Union[
 ]
 
 
+@dataclass
+class ArithmeticNan:
+    num_bits: int
+
+
+@dataclass
+class CanonicalNan:
+    num_bits: int
+
+
+GeneratedValue = Union[str, ArithmeticNan, CanonicalNan]
+
+
 @dataclass
 class WastDescription:
     source_filename: str
@@ -200,7 +212,7 @@ def make_description(input_path: Path, name: str, out_path: Path) -> WastDescrip
     return parse(description)
 
 
-def gen_value(value: WasmValue, as_arg=False) -> str:
+def gen_value_arg(value: WasmValue) -> str:
     def unsigned_to_signed(uint: int, bits: int) -> int:
         max_value = 2**bits
         if uint >= 2 ** (bits - 1):
@@ -221,30 +233,15 @@ def gen_value(value: WasmValue, as_arg=False) -> str:
         f = struct.unpack("d", b)[0]
         return f
 
-    def float_to_str(bits: int, *, double=False, preserve_nan_sign=False) -> str:
+    def float_to_str(bits: int, *, double=False) -> str:
         f = int_to_float64_bitcast(bits) if double else int_to_float_bitcast(bits)
-
-        if math.isnan(f) and preserve_nan_sign:
-            f_bytes = bits.to_bytes(8 if double else 4, byteorder="little")
-            # -NaN does not preserve the sign bit in JavaScript land, so if
-            # we want to preserve NaN "sign", we pass in raw bytes
-            return f"new Uint8Array({list(f_bytes)})"
-
-        if math.isnan(f) and math.copysign(1.0, f) < 0:
-            return "-NaN"
-        elif math.isnan(f):
-            return "NaN"
-        elif math.isinf(f) and math.copysign(1.0, f) < 0:
-            return "-Infinity"
-        elif math.isinf(f):
-            return "Infinity"
         return str(f)
 
     if value.value.startswith("nan"):
-        return "NaN"
-    elif value.value == "inf":
+        raise GenerateException("Should not get indeterminate nan value as an argument")
+    if value.value == "inf":
         return "Infinity"
-    elif value.value == "-inf":
+    if value.value == "-inf":
         return "-Infinity"
 
     match value.kind:
@@ -253,19 +250,33 @@ def gen_value(value: WasmValue, as_arg=False) -> str:
         case "i64":
             return str(unsigned_to_signed(int(value.value), 64)) + "n"
         case "f32":
-            return float_to_str(
-                int(value.value), double=False, preserve_nan_sign=as_arg
-            )
+            return str(int(value.value)) + f" /* {float_to_str(int(value.value))} */"
         case "f64":
-            return float_to_str(int(value.value), double=True, preserve_nan_sign=as_arg)
+            return (
+                str(int(value.value))
+                + f"n /* {float_to_str(int(value.value), double=True)} */"
+            )
         case "externref" | "funcref" | "v128":
             return value.value
         case _:
             raise GenerateException(f"Not implemented: {value.kind}")
 
 
+def gen_value_result(value: WasmValue) -> GeneratedValue:
+    if (value.kind == "f32" or value.kind == "f64") and value.value.startswith("nan"):
+        num_bits = int(value.kind[1:])
+        match value.value:
+            case "nan:canonical":
+                return CanonicalNan(num_bits)
+            case "nan:arithmetic":
+                return ArithmeticNan(num_bits)
+            case _:
+                raise GenerateException(f"Unknown indeterminate nan: {value.value}")
+    return gen_value_arg(value)
+
+
 def gen_args(args: list[WasmValue]) -> str:
-    return ",".join(gen_value(arg, True) for arg in args)
+    return ",".join(gen_value_arg(arg) for arg in args)
 
 
 def gen_module_command(command: ModuleCommand, ctx: Context):
@@ -336,7 +347,18 @@ expect(_field).not.toBeUndefined();"""
     else:
         print(f"let _result = {module}.invoke(_field, {gen_args(invoke.args)});")
     if result is not None:
-        print(f"expect(_result).toBe({gen_value(result)});")
+        gen_result = gen_value_result(result)
+        match gen_result:
+            case str():
+                print(f"expect(_result).toBe({gen_result});")
+            case ArithmeticNan():
+                print(
+                    f"expect(isArithmeticNaN{gen_result.num_bits}(_result)).toBe(true);"
+                )
+            case CanonicalNan():
+                print(
+                    f"expect(isCanonicalNaN{gen_result.num_bits}(_result)).toBe(true);"
+                )
     print("});")
     if not ctx.has_unclosed:
         print("});")
@@ -351,7 +373,7 @@ def gen_get(line: int, get: Get, result: WasmValue | None, ctx: Context):
 let _field = {module}.getExport("{get.field}");"""
     )
     if result is not None:
-        print(f"expect(_field).toBe({gen_value(result)});")
+        print(f"expect(_field).toBe({gen_value_result(result)});")
     print("});")
 
 

+ 35 - 24
Tests/LibWasm/test-wasm.cpp

@@ -199,6 +199,30 @@ TESTJS_GLOBAL_FUNCTION(compare_typed_arrays, compareTypedArrays)
     return JS::Value(lhs_array.viewed_array_buffer()->buffer() == rhs_array.viewed_array_buffer()->buffer());
 }
 
+TESTJS_GLOBAL_FUNCTION(is_canonical_nan32, isCanonicalNaN32)
+{
+    auto value = TRY(vm.argument(0).to_u32(vm));
+    return value == 0x7FC00000 || value == 0xFFC00000;
+}
+
+TESTJS_GLOBAL_FUNCTION(is_canonical_nan64, isCanonicalNaN64)
+{
+    auto value = TRY(vm.argument(0).to_bigint_uint64(vm));
+    return value == 0x7FF8000000000000 || value == 0xFFF8000000000000;
+}
+
+TESTJS_GLOBAL_FUNCTION(is_arithmetic_nan32, isArithmeticNaN32)
+{
+    auto value = bit_cast<float>(TRY(vm.argument(0).to_u32(vm)));
+    return isnan(value);
+}
+
+TESTJS_GLOBAL_FUNCTION(is_arithmetic_nan64, isArithmeticNaN64)
+{
+    auto value = bit_cast<double>(TRY(vm.argument(0).to_bigint_uint64(vm)));
+    return isnan(value);
+}
+
 void WebAssemblyModule::initialize(JS::Realm& realm)
 {
     Base::initialize(realm);
@@ -257,17 +281,7 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
     for (auto& param : type->parameters()) {
         auto argument = vm.argument(index++);
         double double_value = 0;
-        if (argument.is_object()) {
-            auto object = MUST(argument.to_object(vm));
-            // Uint8Array allows for raw bytes to be passed into Wasm. This is
-            // particularly useful for NaN bit patterns
-            if (!is<JS::Uint8Array>(*object))
-                return vm.throw_completion<JS::TypeError>("Expected a Uint8Array object"sv);
-            auto& array = static_cast<JS::Uint8Array&>(*object);
-            if (array.array_length().length() > 8)
-                return vm.throw_completion<JS::TypeError>("Expected a Uint8Array of size <= 8"sv);
-            memcpy(&double_value, array.data().data(), array.array_length().length());
-        } else if (!argument.is_bigint())
+        if (!argument.is_bigint())
             double_value = TRY(argument.to_double(vm));
         switch (param.kind()) {
         case Wasm::ValueType::Kind::I32:
@@ -282,20 +296,15 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
             }
             break;
         case Wasm::ValueType::Kind::F32:
-            // double_value should contain up to 8 bytes of information,
-            // if we were passed a Uint8Array. If the expected arg is a
-            // float, we were probably passed a Uint8Array of size 4. So
-            // we copy those bytes into a float value.
-            if (argument.is_object()) {
-                float float_value = 0;
-                memcpy(&float_value, &double_value, sizeof(float));
-                arguments.append(Wasm::Value(float_value));
-            } else {
-                arguments.append(Wasm::Value(static_cast<float>(double_value)));
-            }
+            arguments.append(Wasm::Value(bit_cast<float>(static_cast<u32>(double_value))));
             break;
         case Wasm::ValueType::Kind::F64:
-            arguments.append(Wasm::Value(static_cast<double>(double_value)));
+            if (argument.is_bigint()) {
+                auto value = TRY(argument.to_bigint_uint64(vm));
+                arguments.append(Wasm::Value(param, bit_cast<double>(value)));
+            } else {
+                arguments.append(Wasm::Value(param, double_value));
+            }
             break;
         case Wasm::ValueType::Kind::V128: {
             if (!argument.is_bigint()) {
@@ -344,7 +353,9 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
 
     auto to_js_value = [&](Wasm::Value const& value) {
         return value.value().visit(
-            [](auto const& value) { return JS::Value(static_cast<double>(value)); },
+            // For floating point values, we're testing with their bit representation, so we bit_cast them
+            [](f32 value) { return JS::Value(static_cast<double>(bit_cast<u32>(value))); },
+            [&](f64 value) { return JS::Value(JS::BigInt::create(vm, Crypto::SignedBigInteger { Crypto::UnsignedBigInteger { bit_cast<u64>(value) } })); },
             [](i32 value) { return JS::Value(static_cast<double>(value)); },
             [&](i64 value) { return JS::Value(JS::BigInt::create(vm, Crypto::SignedBigInteger { value })); },
             [&](u128 value) {