Ver código fonte

LibWasm: Implement integer conversion and narrowing SIMD instructions

Diego Frias 11 meses atrás
pai
commit
616048c67e

+ 14 - 6
Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp

@@ -1712,18 +1712,26 @@ void BytecodeInterpreter::interpret(Configuration& configuration, InstructionPoi
         return unary_operation<u128, i32, Operators::VectorBitmask<4>>(configuration);
     case Instructions::i64x2_bitmask.value():
         return unary_operation<u128, i32, Operators::VectorBitmask<2>>(configuration);
-    case Instructions::f32x4_demote_f64x2_zero.value():
-    case Instructions::f64x2_promote_low_f32x4.value():
+    case Instructions::i32x4_dot_i16x8_s.value():
+        return binary_numeric_operation<u128, u128, Operators::VectorDotProduct<4>>(configuration);
     case Instructions::i8x16_narrow_i16x8_s.value():
+        return binary_numeric_operation<u128, u128, Operators::VectorNarrow<16, i8>>(configuration);
     case Instructions::i8x16_narrow_i16x8_u.value():
-    case Instructions::i16x8_q15mulr_sat_s.value():
+        return binary_numeric_operation<u128, u128, Operators::VectorNarrow<16, u8>>(configuration);
     case Instructions::i16x8_narrow_i32x4_s.value():
+        return binary_numeric_operation<u128, u128, Operators::VectorNarrow<8, i16>>(configuration);
     case Instructions::i16x8_narrow_i32x4_u.value():
-    case Instructions::i32x4_dot_i16x8_s.value():
-    case Instructions::i32x4_trunc_sat_f64x2_s_zero.value():
-    case Instructions::i32x4_trunc_sat_f64x2_u_zero.value():
+        return binary_numeric_operation<u128, u128, Operators::VectorNarrow<8, u16>>(configuration);
+    case Instructions::i16x8_q15mulr_sat_s.value():
+        return binary_numeric_operation<u128, u128, Operators::VectorIntegerBinaryOp<8, Operators::SaturatingOp<i16, Operators::Q15Mul>, MakeSigned>>(configuration);
     case Instructions::f32x4_convert_i32x4_s.value():
+        return unary_operation<u128, u128, Operators::VectorIntegerConvertOp<4, Operators::Convert<f32>, MakeSigned>>(configuration);
     case Instructions::f32x4_convert_i32x4_u.value():
+        return unary_operation<u128, u128, Operators::VectorIntegerConvertOp<4, Operators::Convert<f32>, MakeUnsigned>>(configuration);
+    case Instructions::f32x4_demote_f64x2_zero.value():
+    case Instructions::f64x2_promote_low_f32x4.value():
+    case Instructions::i32x4_trunc_sat_f64x2_s_zero.value():
+    case Instructions::i32x4_trunc_sat_f64x2_u_zero.value():
     case Instructions::f64x2_convert_low_i32x4_s.value():
     case Instructions::f64x2_convert_low_i32x4_u.value():
         dbgln_if(WASM_TRACE_DEBUG, "Instruction '{}' not implemented", instruction_name(instruction.opcode()));

+ 93 - 0
Userland/Libraries/LibWasm/AbstractMachine/Operators.h

@@ -94,6 +94,16 @@ struct Average {
     static StringView name() { return "avgr"sv; }
 };
 
+struct Q15Mul {
+    template<typename Lhs, typename Rhs>
+    auto operator()(Lhs lhs, Rhs rhs) const
+    {
+        return (lhs * rhs + 0x4000) >> 15;
+    }
+
+    static StringView name() { return "q15mul"sv; }
+};
+
 struct BitShiftLeft {
     template<typename Lhs, typename Rhs>
     auto operator()(Lhs lhs, Rhs rhs) const { return lhs << (rhs % (sizeof(lhs) * 8)); }
@@ -727,6 +737,62 @@ struct VectorBitmask {
     static StringView name() { return "bitmask"sv; }
 };
 
+template<size_t VectorSize>
+struct VectorDotProduct {
+    auto operator()(u128 lhs, u128 rhs) const
+    {
+        using VectorInput = NativeVectorType<128 / (VectorSize * 2), VectorSize * 2, MakeSigned>;
+        using VectorResult = NativeVectorType<128 / VectorSize, VectorSize, MakeSigned>;
+        auto v1 = bit_cast<VectorInput>(lhs);
+        auto v2 = bit_cast<VectorInput>(rhs);
+        VectorResult result;
+
+        using ResultType = MakeUnsigned<NativeIntegralType<128 / VectorSize>>;
+        for (size_t i = 0; i < VectorSize; ++i) {
+            ResultType low = v1[i * 2] * v2[i * 2];
+            ResultType high = v1[(i * 2) + 1] * v2[(i * 2) + 1];
+            result[i] = low + high;
+        }
+
+        return bit_cast<u128>(result);
+    }
+
+    static StringView name() { return "dot"sv; }
+};
+
+template<size_t VectorSize, typename Element>
+struct VectorNarrow {
+    auto operator()(u128 lhs, u128 rhs) const
+    {
+        using VectorInput = NativeVectorType<128 / (VectorSize / 2), VectorSize / 2, MakeSigned>;
+        using VectorResult = NativeVectorType<128 / VectorSize, VectorSize, MakeUnsigned>;
+        auto v1 = bit_cast<VectorInput>(lhs);
+        auto v2 = bit_cast<VectorInput>(rhs);
+        VectorResult result;
+
+        for (size_t i = 0; i < (VectorSize / 2); ++i) {
+            if (v1[i] <= NumericLimits<Element>::min())
+                result[i] = NumericLimits<Element>::min();
+            else if (v1[i] >= NumericLimits<Element>::max())
+                result[i] = NumericLimits<Element>::max();
+            else
+                result[i] = v1[i];
+        }
+        for (size_t i = 0; i < (VectorSize / 2); ++i) {
+            if (v2[i] <= NumericLimits<Element>::min())
+                result[i + VectorSize / 2] = NumericLimits<Element>::min();
+            else if (v2[i] >= NumericLimits<Element>::max())
+                result[i + VectorSize / 2] = NumericLimits<Element>::max();
+            else
+                result[i + VectorSize / 2] = v2[i];
+        }
+
+        return bit_cast<u128>(result);
+    }
+
+    static StringView name() { return "narrow"sv; }
+};
+
 template<size_t VectorSize, typename Op, template<typename> typename SetSign = MakeSigned>
 struct VectorIntegerUnaryOp {
     auto operator()(u128 lhs) const
@@ -844,6 +910,33 @@ struct VectorFloatConvertOp {
     }
 };
 
+template<size_t VectorSize, typename Op, template<typename> typename SetSign = MakeSigned>
+struct VectorIntegerConvertOp {
+    auto operator()(u128 lhs) const
+    {
+        using VectorInput = NativeVectorType<128 / VectorSize, VectorSize, SetSign>;
+        using VectorResult = NativeFloatingVectorType<128, VectorSize, NativeFloatingType<128 / VectorSize>>;
+        auto value = bit_cast<VectorInput>(lhs);
+        VectorResult result;
+        Op op;
+        for (size_t i = 0; i < VectorSize; ++i)
+            result[i] = op(value[i]);
+        return bit_cast<u128>(result);
+    }
+
+    static StringView name()
+    {
+        switch (VectorSize) {
+        case 4:
+            return "vec(32x4).cvt_op"sv;
+        case 2:
+            return "vec(64x2).cvt_op"sv;
+        default:
+            VERIFY_NOT_REACHED();
+        }
+    }
+};
+
 struct Floor {
     template<typename Lhs>
     auto operator()(Lhs lhs) const