diff --git a/ynnpack/base/simd/arm_neon_base.h b/ynnpack/base/simd/arm_neon_base.h index 6d7061a9ac0..0c2bd9df6e2 100644 --- a/ynnpack/base/simd/arm_neon_base.h +++ b/ynnpack/base/simd/arm_neon_base.h @@ -24,6 +24,7 @@ namespace ynn { namespace simd { +// Half-vector wrappers template <> struct vec { using value_type = uint8_t; @@ -36,8 +37,6 @@ struct vec { uint8x8_t v; }; -using u8x8 = vec; - template <> struct vec { using value_type = float; @@ -50,6 +49,36 @@ struct vec { float32x2_t v; }; +template <> +struct vec { + using value_type = bfloat16; + static constexpr std::integral_constant N = {}; + + vec() = default; + explicit vec(uint16x4_t v) : v(v) {} + vec(bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT + + uint16x4_t v; +}; + +template <> +struct vec { + using value_type = half; + static constexpr std::integral_constant N = {}; + + vec() = default; + explicit vec(uint16x4_t v) : v(v) {} + vec(half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT + + uint16x4_t v; +}; + +using u8x8 = vec; +using f32x2 = vec; +using bf16x4 = vec; +using f16x4 = vec; + +// Full vector wrappers template <> struct vec { using value_type = float; @@ -58,12 +87,9 @@ struct vec { vec() = default; explicit vec(float32x4_t v) : v(v) {} vec(float x) : v(vdupq_n_f32(x)) {} // NOLINT - vec(vec lo, vec hi) : v(vcombine_f32(lo.v, hi.v)) {} + vec(f32x2 lo, f32x2 hi) : v(vcombine_f32(lo.v, hi.v)) {} float32x4_t v; - - vec lo() const { return vec{vget_low_f32(v)}; } - vec hi() const { return vec{vget_high_f32(v)}; } }; #ifdef YNN_ARCH_ARM64 @@ -104,18 +130,6 @@ struct vec { int32x4_t v; }; -template <> -struct vec { - using value_type = bfloat16; - static constexpr std::integral_constant N = {}; - - vec() = default; - explicit vec(uint16x4_t v) : v(v) {} - vec(bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT - - uint16x4_t v; -}; - template <> struct vec { using value_type = bfloat16; @@ -124,24 +138,9 @@ struct vec { vec() = default; explicit vec(uint16x8_t v) : v(v) {} vec(bfloat16 x) : v(vdupq_n_u16(x.to_bits())) {} // NOLINT - vec(vec lo, vec hi) : v(vcombine_u16(lo.v, hi.v)) {} + vec(bf16x4 lo, bf16x4 hi) : v(vcombine_u16(lo.v, hi.v)) {} uint16x8_t v; - - vec lo() const { return vec{vget_low_u16(v)}; } - vec hi() const { return vec{vget_high_u16(v)}; } -}; - -template <> -struct vec { - using value_type = half; - static constexpr std::integral_constant N = {}; - - vec() = default; - explicit vec(uint16x4_t v) : v(v) {} - vec(half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT - - uint16x4_t v; }; template <> @@ -190,9 +189,6 @@ struct vec { vec(u8x8 lo, u8x8 hi) : v(vcombine_u8(lo.v, hi.v)) {} vec(uint8_t x) : v(vdupq_n_u8(x)) {} // NOLINT - u8x8 lo() const { return u8x8{vget_low_u8(v)}; } - u8x8 hi() const { return u8x8{vget_high_u8(v)}; } - uint8x16_t v; }; @@ -208,22 +204,26 @@ struct vec { int8x16_t v; }; -using f32x2 = vec; using f32x4 = vec; #ifdef YNN_ARCH_ARM64 using f64x2 = vec; #endif using u32x4 = vec; using s32x4 = vec; -using bf16x4 = vec; using bf16x8 = vec; -using f16x4 = vec; using f16x8 = vec; using u16x8 = vec; using s16x8 = vec; using u8x16 = vec; using s8x16 = vec; +YNN_ALWAYS_INLINE f32x2 lo(f32x4 x) { return f32x2{vget_low_f32(x.v)}; } +YNN_ALWAYS_INLINE f32x2 hi(f32x4 x) { return f32x2{vget_high_f32(x.v)}; } +YNN_ALWAYS_INLINE bf16x4 lo(bf16x8 x) { return bf16x4{vget_low_u16(x.v)}; } +YNN_ALWAYS_INLINE bf16x4 hi(bf16x8 x) { return bf16x4{vget_high_u16(x.v)}; } +YNN_ALWAYS_INLINE u8x8 lo(u8x16 x) { return u8x8{vget_low_u8(x.v)}; } +YNN_ALWAYS_INLINE u8x8 hi(u8x16 x) { return u8x8{vget_high_u8(x.v)}; } + namespace internal { YNN_ALWAYS_INLINE int32x4x2_t vtrn(int32x4_t a, int32x4_t b) { @@ -1205,15 +1205,15 @@ YNN_ALWAYS_INLINE f32x2 cast(f64x2 a, float) { #endif // YNN_ARCH_ARM64 YNN_ALWAYS_INLINE s16x8 cast(s32x8 a, int16_t) { - return s16x8{vcombine_s16(vqmovn_s32(a.lo().v), vqmovn_s32(a.hi().v))}; + return s16x8{vcombine_s16(vqmovn_s32(lo(a).v), vqmovn_s32(hi(a).v))}; } YNN_ALWAYS_INLINE s8x16 cast(s16x16 a, int8_t) { - return s8x16{vcombine_s8(vqmovn_s16(a.lo().v), vqmovn_s16(a.hi().v))}; + return s8x16{vcombine_s8(vqmovn_s16(lo(a).v), vqmovn_s16(hi(a).v))}; } YNN_ALWAYS_INLINE u8x16 cast(s16x16 a, uint8_t) { - return u8x16{vcombine_u8(vqmovun_s16(a.lo().v), vqmovun_s16(a.hi().v))}; + return u8x16{vcombine_u8(vqmovun_s16(lo(a).v), vqmovun_s16(hi(a).v))}; } YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) { @@ -1226,27 +1226,27 @@ YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) { YNN_ALWAYS_INLINE s16x8 cast(f32x8 f, int16_t) { #if defined(__ARM_ARCH) && __ARM_ARCH < 8 - s32x4 a1 = cast(round(f.lo()), int32_t{}); - s32x4 a2 = cast(round(f.hi()), int32_t{}); + s32x4 a1 = cast(round(lo(f)), int32_t{}); + s32x4 a2 = cast(round(hi(f)), int32_t{}); return cast(s32x8{a1, a2}, int16_t{}); #else - return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(f.lo().v)), - vqmovn_s32(vcvtnq_s32_f32(f.hi().v)))}; + return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(lo(f).v)), + vqmovn_s32(vcvtnq_s32_f32(hi(f).v)))}; #endif } YNN_ALWAYS_INLINE s8x16 cast(f32x16 f, int8_t) { s16x16 f_s16 = { - cast(f.lo(), int16_t{}), - cast(f.hi(), int16_t{}), + cast(lo(f), int16_t{}), + cast(hi(f), int16_t{}), }; return cast(f_s16, int8_t{}); } YNN_ALWAYS_INLINE u8x16 cast(f32x16 f, uint8_t) { s16x16 f_s16 = { - cast(f.lo(), int16_t{}), - cast(f.hi(), int16_t{}), + cast(lo(f), int16_t{}), + cast(hi(f), int16_t{}), }; return cast(f_s16, uint8_t{}); } diff --git a/ynnpack/base/simd/arm_neonfp16.h b/ynnpack/base/simd/arm_neonfp16.h index 3eee4904684..25d91fe0bc1 100644 --- a/ynnpack/base/simd/arm_neonfp16.h +++ b/ynnpack/base/simd/arm_neonfp16.h @@ -34,7 +34,7 @@ YNN_ALWAYS_INLINE f16x4 cast(f32x4 a, half) { YNN_ALWAYS_INLINE f16x8 cast(f32x8 a, half) { return f16x8{vreinterpretq_u16_f16( - vcombine_f16(vcvt_f16_f32(a.lo().v), vcvt_f16_f32(a.hi().v)))}; + vcombine_f16(vcvt_f16_f32(lo(a).v), vcvt_f16_f32(hi(a).v)))}; } } // namespace simd diff --git a/ynnpack/base/simd/byte_vec.h b/ynnpack/base/simd/byte_vec.h index 757f8c22565..a12abae98d6 100644 --- a/ynnpack/base/simd/byte_vec.h +++ b/ynnpack/base/simd/byte_vec.h @@ -46,14 +46,16 @@ struct vec { explicit vec(uint64_t v) : v(v) {} vec(u8x4 x0, u8x4 x1) : v((static_cast(x1.v) << 32) | x0.v) {} - u8x4 lo() const { return u8x4{static_cast(v)}; } - u8x4 hi() const { return u8x4{static_cast(v >> 32)}; } - uint64_t v; }; using u8x8 = vec; +YNN_ALWAYS_INLINE u8x4 lo(u8x8 x) { return u8x4{static_cast(x.v)}; } +YNN_ALWAYS_INLINE u8x4 hi(u8x8 x) { + return u8x4{static_cast(x.v >> 32)}; +} + YNN_ALWAYS_INLINE u8x4 load_aligned(const uint8_t* ptr, decltype(u8x4::N), u8x4 = {}) { return u8x4{*reinterpret_cast(ptr)}; diff --git a/ynnpack/base/simd/generic.inc b/ynnpack/base/simd/generic.inc index 7a385c63415..ffe1ba628f1 100644 --- a/ynnpack/base/simd/generic.inc +++ b/ynnpack/base/simd/generic.inc @@ -52,9 +52,9 @@ template YNN_ALWAYS_INLINE vec load(const T* ptr, size_t n, vec src) { std::integral_constant n2 = {}; if (n < n2) { - return {load(ptr, n, src.lo()), src.hi()}; + return {load(ptr, n, lo(src)), hi(src)}; } else { - return {load(ptr, n2), load(ptr + n2, n - n2, src.hi())}; + return {load(ptr, n2), load(ptr + n2, n - n2, hi(src))}; } } template @@ -81,43 +81,43 @@ template YNN_ALWAYS_INLINE void store(T* ptr, vec value, std::integral_constant n) { std::integral_constant n2 = {}; - store(ptr, value.lo(), n2); - store(ptr + n2, value.hi(), n2); + store(ptr, lo(value), n2); + store(ptr + n2, hi(value), n2); } template YNN_ALWAYS_INLINE void store_aligned(T* ptr, vec value, std::integral_constant n) { std::integral_constant n2 = {}; - store_aligned(ptr, value.lo(), n2); - store_aligned(ptr + n2, value.hi(), n2); + store_aligned(ptr, lo(value), n2); + store_aligned(ptr + n2, hi(value), n2); } template YNN_ALWAYS_INLINE void store(T* ptr, vec value, size_t n) { std::integral_constant n2 = {}; if (n < n2) { - store(ptr, value.lo(), n); + store(ptr, lo(value), n); } else { - store(ptr, value.lo(), n2); - store(ptr + n2, value.hi(), n - n2); + store(ptr, lo(value), n2); + store(ptr + n2, hi(value), n - n2); } } // Arithmetic operators. template YNN_ALWAYS_INLINE vec operator+(vec a, vec b) { - return {a.lo() + b.lo(), a.hi() + b.hi()}; + return {lo(a) + lo(b), hi(a) + hi(b)}; } template YNN_ALWAYS_INLINE vec operator-(vec a, vec b) { - return {a.lo() - b.lo(), a.hi() - b.hi()}; + return {lo(a) - lo(b), hi(a) - hi(b)}; } template YNN_ALWAYS_INLINE vec operator*(vec a, vec b) { - return {a.lo() * b.lo(), a.hi() * b.hi()}; + return {lo(a) * lo(b), hi(a) * hi(b)}; } template YNN_ALWAYS_INLINE vec operator/(vec a, vec b) { - return {a.lo() / b.lo(), a.hi() / b.hi()}; + return {lo(a) / lo(b), hi(a) / hi(b)}; } template @@ -144,23 +144,23 @@ YNN_ALWAYS_INLINE vec& operator/=(vec& a, vec b) { // Boolean operators. template YNN_ALWAYS_INLINE vec operator&(vec a, vec b) { - return {a.lo() & b.lo(), a.hi() & b.hi()}; + return {lo(a) & lo(b), hi(a) & hi(b)}; } template YNN_ALWAYS_INLINE vec operator|(vec a, vec b) { - return {a.lo() | b.lo(), a.hi() | b.hi()}; + return {lo(a) | lo(b), hi(a) | hi(b)}; } template YNN_ALWAYS_INLINE vec operator^(vec a, vec b) { - return {a.lo() ^ b.lo(), a.hi() ^ b.hi()}; + return {lo(a) ^ lo(b), hi(a) ^ hi(b)}; } template YNN_ALWAYS_INLINE vec operator~(vec a) { - return {~a.lo(), ~a.hi()}; + return {~lo(a), ~hi(a)}; } template YNN_ALWAYS_INLINE vec operator<<(vec a, int b) { - return {a.lo() << b, a.hi() << b}; + return {lo(a) << b, hi(a) << b}; } template @@ -180,60 +180,60 @@ YNN_ALWAYS_INLINE vec& operator^=(vec& a, vec b) { } template YNN_ALWAYS_INLINE vec min(vec a, vec b) { - return {min(a.lo(), b.lo()), min(a.hi(), b.hi())}; + return {min(lo(a), lo(b)), min(hi(a), hi(b))}; } template YNN_ALWAYS_INLINE vec max(vec a, vec b) { - return {max(a.lo(), b.lo()), max(a.hi(), b.hi())}; + return {max(lo(a), lo(b)), max(hi(a), hi(b))}; } template YNN_ALWAYS_INLINE vec copysign(vec mag, vec sgn) { - return {copysign(mag.lo(), sgn.lo()), copysign(mag.hi(), sgn.hi())}; + return {copysign(lo(mag), lo(sgn)), copysign(hi(mag), hi(sgn))}; }; template YNN_ALWAYS_INLINE vec abs(vec a) { - return {abs(a.lo()), abs(a.hi())}; + return {abs(lo(a)), abs(hi(a))}; } template YNN_ALWAYS_INLINE vec add_sat(vec a, vec b) { - return {add_sat(a.lo(), b.lo()), add_sat(a.hi(), b.hi())}; + return {add_sat(lo(a), lo(b)), add_sat(hi(a), hi(b))}; } template YNN_ALWAYS_INLINE vec sub_sat(vec a, vec b) { - return {sub_sat(a.lo(), b.lo()), sub_sat(a.hi(), b.hi())}; + return {sub_sat(lo(a), lo(b)), sub_sat(hi(a), hi(b))}; } template YNN_ALWAYS_INLINE vec floor(vec a) { - return {floor(a.lo()), floor(a.hi())}; + return {floor(lo(a)), floor(hi(a))}; } template YNN_ALWAYS_INLINE vec floor_log2(vec a) { - return {floor_log2(a.lo()), floor_log2(a.hi())}; + return {floor_log2(lo(a)), floor_log2(hi(a))}; } template YNN_ALWAYS_INLINE vec exp2_round(vec a) { - return {exp2_round(a.lo()), exp2_round(a.hi())}; + return {exp2_round(lo(a)), exp2_round(hi(a))}; } template YNN_ALWAYS_INLINE vec copynan(vec x, vec nan) { - return {copynan(x.lo(), nan.lo()), copynan(x.hi(), nan.hi())}; + return {copynan(lo(x), lo(nan)), copynan(hi(x), hi(nan))}; } template YNN_ALWAYS_INLINE vec ceil(vec a) { - return {ceil(a.lo()), ceil(a.hi())}; + return {ceil(lo(a)), ceil(hi(a))}; } template YNN_ALWAYS_INLINE vec round(vec a) { - return {round(a.lo()), round(a.hi())}; + return {round(lo(a)), round(hi(a))}; } template YNN_ALWAYS_INLINE vec sqrt(vec a) { - return {sqrt(a.lo()), sqrt(a.hi())}; + return {sqrt(lo(a)), sqrt(hi(a))}; } template YNN_ALWAYS_INLINE vec fma(vec a, vec b, vec acc) { - return {fma(a.lo(), b.lo(), acc.lo()), fma(a.hi(), b.hi(), acc.hi())}; + return {fma(lo(a), lo(b), lo(acc)), fma(hi(a), hi(b), hi(acc))}; } template @@ -246,7 +246,7 @@ template YNN_ALWAYS_INLINE vec extract(vec x, std::integral_constant) { static_assert(Index == 0 || Index == 1, ""); - return Index == 0 ? x.lo() : x.hi(); + return Index == 0 ? lo(x) : hi(x); } template YNN_ALWAYS_INLINE vec extract(vec x, @@ -263,31 +263,31 @@ YNN_ALWAYS_INLINE vec concat(vec a, vec b) { template YNN_ALWAYS_INLINE vec cast(vec from, To) { - return {cast(from.lo(), To()), cast(from.hi(), To())}; + return {cast(lo(from), To()), cast(hi(from), To())}; } template YNN_ALWAYS_INLINE T horizontal_sum(vec x) { - return horizontal_sum(x.lo() + x.hi()); + return horizontal_sum(lo(x) + hi(x)); } template YNN_ALWAYS_INLINE T horizontal_min(vec x) { - return horizontal_min(min(x.lo(), x.hi())); + return horizontal_min(min(lo(x), hi(x))); } template YNN_ALWAYS_INLINE T horizontal_max(vec x) { - return horizontal_max(max(x.lo(), x.hi())); + return horizontal_max(max(lo(x), hi(x))); } template YNN_ALWAYS_INLINE void kahan_sum(vec a, vec& acc, vec& error) { - vec acc_lo = acc.lo(); - vec acc_hi = acc.hi(); - vec error_lo = error.lo(); - vec error_hi = error.hi(); - kahan_sum(a.lo(), acc_lo, error_lo); - kahan_sum(a.hi(), acc_hi, error_hi); + vec acc_lo = lo(acc); + vec acc_hi = hi(acc); + vec error_lo = lo(error); + vec error_hi = hi(error); + kahan_sum(lo(a), acc_lo, error_lo); + kahan_sum(hi(a), acc_hi, error_hi); acc = concat(acc_lo, acc_hi); error = concat(error_lo, error_hi); } diff --git a/ynnpack/base/simd/vec.h b/ynnpack/base/simd/vec.h index 9ffb3427c7c..c078d734e24 100644 --- a/ynnpack/base/simd/vec.h +++ b/ynnpack/base/simd/vec.h @@ -68,11 +68,6 @@ struct vec { subvec v[2]; - subvec& lo() { return v[0]; } - const subvec& lo() const { return v[0]; } - subvec& hi() { return v[1]; } - const subvec& hi() const { return v[1]; } - vec() = default; YNN_ALWAYS_INLINE explicit vec(value_type x) : v{subvec{x}, subvec{x}} {} YNN_ALWAYS_INLINE vec(subvec v0, subvec v1) : v{v0, v1} {} @@ -81,6 +76,26 @@ struct vec { YNN_ALWAYS_INLINE const subvec& operator[](size_t i) const { return v[i]; } }; +template +YNN_ALWAYS_INLINE vec& lo(vec& x) { + return x.v[0]; +} + +template +YNN_ALWAYS_INLINE const vec& lo(const vec& x) { + return x.v[0]; +} + +template +YNN_ALWAYS_INLINE vec& hi(vec& x) { + return x.v[1]; +} + +template +YNN_ALWAYS_INLINE const vec& hi(const vec& x) { + return x.v[1]; +} + template YNN_ALWAYS_INLINE vec broadcast(T x) { return vec{x}; diff --git a/ynnpack/base/simd/wasm_simd128.h b/ynnpack/base/simd/wasm_simd128.h index 318c267c84f..2de8e2a7e79 100644 --- a/ynnpack/base/simd/wasm_simd128.h +++ b/ynnpack/base/simd/wasm_simd128.h @@ -616,15 +616,15 @@ YNN_ALWAYS_INLINE f32x4 cast(s32x4 x, float) { } YNN_ALWAYS_INLINE s16x8 cast(s32x8 a, int16_t) { - return s16x8{wasm_i16x8_narrow_i32x4(a.lo().v, a.hi().v)}; + return s16x8{wasm_i16x8_narrow_i32x4(lo(a).v, hi(a).v)}; } YNN_ALWAYS_INLINE s8x16 cast(s16x16 a, int8_t) { - return s8x16{wasm_i8x16_narrow_i16x8(a.lo().v, a.hi().v)}; + return s8x16{wasm_i8x16_narrow_i16x8(lo(a).v, hi(a).v)}; } YNN_ALWAYS_INLINE u8x16 cast(s16x16 a, uint8_t) { - return u8x16{wasm_u8x16_narrow_i16x8(a.lo().v, a.hi().v)}; + return u8x16{wasm_u8x16_narrow_i16x8(lo(a).v, hi(a).v)}; } YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) { @@ -632,26 +632,22 @@ YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) { } YNN_ALWAYS_INLINE s16x8 cast(f32x8 f, int16_t) { - const v128_t i0 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(f.lo().v)); - const v128_t i1 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(f.hi().v)); + const v128_t i0 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(lo(f).v)); + const v128_t i1 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(hi(f).v)); return cast(s32x8(s32x4(i0), s32x4(i1)), int16_t()); } YNN_ALWAYS_INLINE s8x16 cast(f32x16 f, int8_t) { - const s16x8 i01 = cast(f32x8(f.lo().lo(), f.lo().hi()), int16_t()); - const s16x8 i23 = cast(f32x8(f.hi().lo(), f.hi().hi()), int16_t()); + const s16x8 i01 = cast(f32x8(lo(lo(f)), hi(lo(f))), int16_t()); + const s16x8 i23 = cast(f32x8(lo(hi(f)), hi(hi(f))), int16_t()); return cast(s16x16(i01, i23), int8_t()); } YNN_ALWAYS_INLINE u8x16 cast(f32x16 f, uint8_t) { - const v128_t i0 = - wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(f.lo().lo().v)); - const v128_t i1 = - wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(f.lo().hi().v)); - const v128_t i2 = - wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(f.hi().lo().v)); - const v128_t i3 = - wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(f.hi().hi().v)); + const v128_t i0 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(lo(lo(f)).v)); + const v128_t i1 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(hi(lo(f)).v)); + const v128_t i2 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(lo(hi(f)).v)); + const v128_t i3 = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_nearest(hi(hi(f)).v)); const v128_t i01_16 = wasm_i16x8_narrow_i32x4(i0, i1); const v128_t i23_16 = wasm_i16x8_narrow_i32x4(i2, i3); return u8x16{wasm_u8x16_narrow_i16x8(i01_16, i23_16)}; diff --git a/ynnpack/base/simd/x86_avx2.h b/ynnpack/base/simd/x86_avx2.h index 36b3094a18c..e29a55773ca 100644 --- a/ynnpack/base/simd/x86_avx2.h +++ b/ynnpack/base/simd/x86_avx2.h @@ -47,8 +47,8 @@ YNN_ALWAYS_INLINE f32x8 cast(s32x8 x, float) { } YNN_ALWAYS_INLINE bf16x16 cast(f32x16 a, bfloat16) { - __m256 nan_mask_lo = _mm256_cmp_ps(a.lo().v, a.lo().v, _CMP_UNORD_Q); - __m256i u_lo = _mm256_castps_si256(a.lo().v); + __m256 nan_mask_lo = _mm256_cmp_ps(lo(a).v, lo(a).v, _CMP_UNORD_Q); + __m256i u_lo = _mm256_castps_si256(lo(a).v); __m256i lsb_lo = _mm256_and_si256(_mm256_srli_epi32(u_lo, 16), _mm256_set1_epi32(1)); __m256i bias_lo = _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), lsb_lo); @@ -58,8 +58,8 @@ YNN_ALWAYS_INLINE bf16x16 cast(f32x16 a, bfloat16) { nan_mask_lo)); __m256i c1 = _mm256_srli_epi32(res_lo, 16); - __m256 nan_mask_hi = _mm256_cmp_ps(a.hi().v, a.hi().v, _CMP_UNORD_Q); - __m256i u_hi = _mm256_castps_si256(a.hi().v); + __m256 nan_mask_hi = _mm256_cmp_ps(hi(a).v, hi(a).v, _CMP_UNORD_Q); + __m256i u_hi = _mm256_castps_si256(hi(a).v); __m256i lsb_hi = _mm256_and_si256(_mm256_srli_epi32(u_hi, 16), _mm256_set1_epi32(1)); __m256i bias_hi = _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), lsb_hi); @@ -74,17 +74,17 @@ YNN_ALWAYS_INLINE bf16x16 cast(f32x16 a, bfloat16) { } YNN_ALWAYS_INLINE s16x16 cast(s32x16 a, int16_t) { - const __m256i r = _mm256_packs_epi32(a.lo().v, a.hi().v); + const __m256i r = _mm256_packs_epi32(lo(a).v, hi(a).v); return s16x16{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; } YNN_ALWAYS_INLINE s8x32 cast(s16x32 a, int8_t) { - const __m256i r = _mm256_packs_epi16(a.lo().v, a.hi().v); + const __m256i r = _mm256_packs_epi16(lo(a).v, hi(a).v); return s8x32{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; } YNN_ALWAYS_INLINE u8x32 cast(s16x32 a, uint8_t) { - const __m256i r = _mm256_packus_epi16(a.lo().v, a.hi().v); + const __m256i r = _mm256_packus_epi16(lo(a).v, hi(a).v); return u8x32{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; } @@ -99,22 +99,22 @@ YNN_ALWAYS_INLINE s32x8 cast(f32x8 f, int32_t) { } YNN_ALWAYS_INLINE s16x16 cast(f32x16 f, int16_t) { - const s32x8 i0 = cast(f.lo(), int32_t()); - const s32x8 i1 = cast(f.hi(), int32_t()); + const s32x8 i0 = cast(lo(f), int32_t()); + const s32x8 i1 = cast(hi(f), int32_t()); return cast(s32x16(i0, i1), int16_t()); } YNN_ALWAYS_INLINE s8x32 cast(f32x32 f, int8_t) { - const s16x16 i01 = cast(f.lo(), int16_t()); - const s16x16 i23 = cast(f.hi(), int16_t()); + const s16x16 i01 = cast(lo(f), int16_t()); + const s16x16 i23 = cast(hi(f), int16_t()); return cast(s16x32(i01, i23), int8_t()); } YNN_ALWAYS_INLINE u8x32 cast(f32x32 f, uint8_t) { - const s32x8 i0 = cast(f.lo().lo(), int32_t()); - const s32x8 i1 = cast(f.lo().hi(), int32_t()); - const s32x8 i2 = cast(f.hi().lo(), int32_t()); - const s32x8 i3 = cast(f.hi().hi(), int32_t()); + const s32x8 i0 = cast(lo(lo(f)), int32_t()); + const s32x8 i1 = cast(hi(lo(f)), int32_t()); + const s32x8 i2 = cast(lo(hi(f)), int32_t()); + const s32x8 i3 = cast(hi(hi(f)), int32_t()); const __m256i i01_16 = _mm256_packs_epi32(i0.v, i1.v); const __m256i i23_16 = _mm256_packs_epi32(i2.v, i3.v); const __m256i r = _mm256_packus_epi16(i01_16, i23_16); diff --git a/ynnpack/base/simd/x86_avx512.h b/ynnpack/base/simd/x86_avx512.h index a9b55148e1c..48b788965d1 100644 --- a/ynnpack/base/simd/x86_avx512.h +++ b/ynnpack/base/simd/x86_avx512.h @@ -69,9 +69,6 @@ struct vec { vec(double x) : v(_mm512_set1_pd(x)) {} // NOLINT __m512d v; - - YNN_ALWAYS_INLINE f64x4 lo() const { return f64x4{internal::lo(v)}; } - YNN_ALWAYS_INLINE f64x4 hi() const { return f64x4{internal::hi(v)}; } }; template <> @@ -85,9 +82,6 @@ struct vec { vec(float x) : v(_mm512_set1_ps(x)) {} // NOLINT __m512 v; - - YNN_ALWAYS_INLINE f32x8 lo() const { return f32x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE f32x8 hi() const { return f32x8{internal::hi(v)}; } }; template <> @@ -101,9 +95,6 @@ struct vec { vec(uint32_t x) : v(_mm512_set1_epi32(x)) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE u32x8 lo() const { return u32x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE u32x8 hi() const { return u32x8{internal::hi(v)}; } }; template <> @@ -117,9 +108,6 @@ struct vec { vec(int32_t x) : v(_mm512_set1_epi32(x)) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE s32x8 lo() const { return s32x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE s32x8 hi() const { return s32x8{internal::hi(v)}; } }; template <> @@ -133,9 +121,6 @@ struct vec { vec(bfloat16 x) : v(_mm512_set1_epi16(x.to_bits())) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE bf16x16 lo() const { return bf16x16{internal::lo(v)}; } - YNN_ALWAYS_INLINE bf16x16 hi() const { return bf16x16{internal::hi(v)}; } }; template <> @@ -149,9 +134,6 @@ struct vec { vec(half x) : v(_mm512_set1_epi16(x.to_bits())) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE f16x16 lo() const { return f16x16{internal::lo(v)}; } - YNN_ALWAYS_INLINE f16x16 hi() const { return f16x16{internal::hi(v)}; } }; template <> @@ -165,9 +147,6 @@ struct vec { vec(uint16_t x) : v(_mm512_set1_epi16(x)) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE u16x16 lo() const { return u16x16{internal::lo(v)}; } - YNN_ALWAYS_INLINE u16x16 hi() const { return u16x16{internal::hi(v)}; } }; template <> @@ -181,9 +160,6 @@ struct vec { vec(int16_t x) : v(_mm512_set1_epi16(x)) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE s16x16 lo() const { return s16x16{internal::lo(v)}; } - YNN_ALWAYS_INLINE s16x16 hi() const { return s16x16{internal::hi(v)}; } }; template <> @@ -197,9 +173,6 @@ struct vec { vec(uint8_t x) : v(_mm512_set1_epi8(x)) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE u8x32 lo() const { return u8x32{internal::lo(v)}; } - YNN_ALWAYS_INLINE u8x32 hi() const { return u8x32{internal::hi(v)}; } }; template <> @@ -213,9 +186,6 @@ struct vec { vec(int8_t x) : v(_mm512_set1_epi8(x)) {} // NOLINT __m512i v; - - YNN_ALWAYS_INLINE s8x32 lo() const { return s8x32{internal::lo(v)}; } - YNN_ALWAYS_INLINE s8x32 hi() const { return s8x32{internal::hi(v)}; } }; using f64x8 = vec; @@ -230,6 +200,27 @@ using u8x64 = vec; using s8x64 = vec; using f32x64 = vec; +YNN_ALWAYS_INLINE f64x4 lo(f64x8 x) { return f64x4{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE f64x4 hi(f64x8 x) { return f64x4{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE f32x8 lo(f32x16 x) { return f32x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE f32x8 hi(f32x16 x) { return f32x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE u32x8 lo(u32x16 x) { return u32x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE u32x8 hi(u32x16 x) { return u32x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE s32x8 lo(s32x16 x) { return s32x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE s32x8 hi(s32x16 x) { return s32x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE bf16x16 lo(bf16x32 x) { return bf16x16{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE bf16x16 hi(bf16x32 x) { return bf16x16{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE f16x16 lo(f16x32 x) { return f16x16{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE f16x16 hi(f16x32 x) { return f16x16{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE u16x16 lo(u16x32 x) { return u16x16{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE u16x16 hi(u16x32 x) { return u16x16{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE s16x16 lo(s16x32 x) { return s16x16{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE s16x16 hi(s16x32 x) { return s16x16{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE u8x32 lo(u8x64 x) { return u8x32{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE u8x32 hi(u8x64 x) { return u8x32{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE s8x32 lo(s8x64 x) { return s8x32{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE s8x32 hi(s8x64 x) { return s8x32{internal::hi(x.v)}; } + YNN_ALWAYS_INLINE f64x8 load_aligned(const double* ptr, decltype(f64x8::N), f64x8 = {}) { return f64x8{_mm512_load_pd(ptr)}; @@ -1094,10 +1085,10 @@ YNN_ALWAYS_INLINE f32x16 cast(bf16x16 a, float) { YNN_ALWAYS_INLINE bf16x32 cast(f32x32 a, bfloat16) { #ifdef YNN_ARCH_X86_AVX512BF16 - return bf16x32{(__m512i)_mm512_cvtne2ps_pbh(a.hi().v, a.lo().v)}; + return bf16x32{(__m512i)_mm512_cvtne2ps_pbh(hi(a).v, lo(a).v)}; #else - __m512i u_lo = _mm512_castps_si512(a.lo().v); - __mmask16 nan_mask_lo = _mm512_cmp_ps_mask(a.lo().v, a.lo().v, _CMP_UNORD_Q); + __m512i u_lo = _mm512_castps_si512(lo(a).v); + __mmask16 nan_mask_lo = _mm512_cmp_ps_mask(lo(a).v, lo(a).v, _CMP_UNORD_Q); __m512i lsb_lo = _mm512_and_si512(_mm512_srli_epi32(u_lo, 16), _mm512_set1_epi32(1)); __m512i bias_lo = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb_lo); @@ -1106,8 +1097,8 @@ YNN_ALWAYS_INLINE bf16x32 cast(f32x32 a, bfloat16) { _mm512_set1_epi32(0x00010000)); __m512i c1 = _mm512_srli_epi32(res_lo, 16); - __m512i u_hi = _mm512_castps_si512(a.hi().v); - __mmask16 nan_mask_hi = _mm512_cmp_ps_mask(a.hi().v, a.hi().v, _CMP_UNORD_Q); + __m512i u_hi = _mm512_castps_si512(hi(a).v); + __mmask16 nan_mask_hi = _mm512_cmp_ps_mask(hi(a).v, hi(a).v, _CMP_UNORD_Q); __m512i lsb_hi = _mm512_and_si512(_mm512_srli_epi32(u_hi, 16), _mm512_set1_epi32(1)); __m512i bias_hi = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb_hi); @@ -1166,19 +1157,19 @@ YNN_ALWAYS_INLINE f32x8 cast(f64x8 a, float) { } YNN_ALWAYS_INLINE s16x32 cast(s32x32 a, int16_t) { - const __m512i r = _mm512_packs_epi32(a.lo().v, a.hi().v); + const __m512i r = _mm512_packs_epi32(lo(a).v, hi(a).v); return s16x32{ _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), r)}; } YNN_ALWAYS_INLINE s8x64 cast(s16x64 a, int8_t) { - const __m512i r = _mm512_packs_epi16(a.lo().v, a.hi().v); + const __m512i r = _mm512_packs_epi16(lo(a).v, hi(a).v); return s8x64{ _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), r)}; } YNN_ALWAYS_INLINE u8x64 cast(s16x64 a, uint8_t) { - const __m512i r = _mm512_packus_epi16(a.lo().v, a.hi().v); + const __m512i r = _mm512_packus_epi16(lo(a).v, hi(a).v); return u8x64{ _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), r)}; } @@ -1193,16 +1184,16 @@ YNN_ALWAYS_INLINE s32x16 cast(f32x16 f, int32_t) { } YNN_ALWAYS_INLINE s16x32 cast(f32x32 f, int16_t) { - const s32x16 i0 = cast(f.lo(), int32_t()); - const s32x16 i1 = cast(f.hi(), int32_t()); + const s32x16 i0 = cast(lo(f), int32_t()); + const s32x16 i1 = cast(hi(f), int32_t()); return cast(s32x32(i0, i1), int16_t()); } YNN_ALWAYS_INLINE u8x64 cast(f32x64 f, uint8_t) { - const s32x16 i0 = cast(f.lo().lo(), int32_t()); - const s32x16 i1 = cast(f.lo().hi(), int32_t()); - const s32x16 i2 = cast(f.hi().lo(), int32_t()); - const s32x16 i3 = cast(f.hi().hi(), int32_t()); + const s32x16 i0 = cast(lo(lo(f)), int32_t()); + const s32x16 i1 = cast(hi(lo(f)), int32_t()); + const s32x16 i2 = cast(lo(hi(f)), int32_t()); + const s32x16 i3 = cast(hi(hi(f)), int32_t()); const __m512i i01_16 = _mm512_packs_epi32(i0.v, i1.v); const __m512i i23_16 = _mm512_packs_epi32(i2.v, i3.v); const __m512i r = _mm512_packus_epi16(i01_16, i23_16); @@ -1212,10 +1203,10 @@ YNN_ALWAYS_INLINE u8x64 cast(f32x64 f, uint8_t) { } YNN_ALWAYS_INLINE s8x64 cast(f32x64 f, int8_t) { - const s32x16 i0 = cast(f.lo().lo(), int32_t()); - const s32x16 i1 = cast(f.lo().hi(), int32_t()); - const s32x16 i2 = cast(f.hi().lo(), int32_t()); - const s32x16 i3 = cast(f.hi().hi(), int32_t()); + const s32x16 i0 = cast(lo(lo(f)), int32_t()); + const s32x16 i1 = cast(hi(lo(f)), int32_t()); + const s32x16 i2 = cast(lo(hi(f)), int32_t()); + const s32x16 i3 = cast(hi(hi(f)), int32_t()); const __m512i i01_16 = _mm512_packs_epi32(i0.v, i1.v); const __m512i i23_16 = _mm512_packs_epi32(i2.v, i3.v); const __m512i r = _mm512_packs_epi16(i01_16, i23_16); diff --git a/ynnpack/base/simd/x86_avx_base.h b/ynnpack/base/simd/x86_avx_base.h index c86a96e55ce..c34148df3d0 100644 --- a/ynnpack/base/simd/x86_avx_base.h +++ b/ynnpack/base/simd/x86_avx_base.h @@ -38,6 +38,9 @@ YNN_ALWAYS_INLINE __m128i hi(__m256i x) { return _mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(x), 1)); } +YNN_ALWAYS_INLINE __m128d lo(__m256d x) {return _mm256_castpd256_pd128(x); } +YNN_ALWAYS_INLINE __m128d hi(__m256d x) { return _mm256_extractf128_pd(x, 1); } + YNN_ALWAYS_INLINE __m256 concat(__m128 lo, __m128 hi) { return _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1); } @@ -62,13 +65,6 @@ struct vec { vec(double x) : v(_mm256_set1_pd(x)) {} // NOLINT __m256d v; - - YNN_ALWAYS_INLINE f64x2 lo() const { - return f64x2{_mm256_castpd256_pd128(v)}; - } - YNN_ALWAYS_INLINE f64x2 hi() const { - return f64x2{_mm256_extractf128_pd(v, 1)}; - } }; template <> @@ -82,9 +78,6 @@ struct vec { vec(float x) : v(_mm256_set1_ps(x)) {} // NOLINT __m256 v; - - YNN_ALWAYS_INLINE f32x4 lo() const { return f32x4{internal::lo(v)}; } - YNN_ALWAYS_INLINE f32x4 hi() const { return f32x4{internal::hi(v)}; } }; template <> @@ -98,9 +91,6 @@ struct vec { vec(uint32_t x) : v(_mm256_set1_epi32(x)) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE u32x4 lo() const { return u32x4{internal::lo(v)}; } - YNN_ALWAYS_INLINE u32x4 hi() const { return u32x4{internal::hi(v)}; } }; template <> @@ -114,9 +104,6 @@ struct vec { vec(int32_t x) : v(_mm256_set1_epi32(x)) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE s32x4 lo() const { return s32x4{internal::lo(v)}; } - YNN_ALWAYS_INLINE s32x4 hi() const { return s32x4{internal::hi(v)}; } }; template <> @@ -130,9 +117,6 @@ struct vec { vec(bfloat16 x) : v(_mm256_set1_epi16(x.to_bits())) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE bf16x8 lo() const { return bf16x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE bf16x8 hi() const { return bf16x8{internal::hi(v)}; } }; template <> @@ -146,9 +130,6 @@ struct vec { vec(half x) : v(_mm256_set1_epi16(x.to_bits())) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE f16x8 lo() const { return f16x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE f16x8 hi() const { return f16x8{internal::hi(v)}; } }; template <> @@ -162,9 +143,6 @@ struct vec { vec(uint16_t x) : v(_mm256_set1_epi16(x)) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE u16x8 lo() const { return u16x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE u16x8 hi() const { return u16x8{internal::hi(v)}; } }; template <> @@ -178,9 +156,6 @@ struct vec { vec(int16_t x) : v(_mm256_set1_epi16(x)) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE s16x8 lo() const { return s16x8{internal::lo(v)}; } - YNN_ALWAYS_INLINE s16x8 hi() const { return s16x8{internal::hi(v)}; } }; template <> @@ -194,9 +169,6 @@ struct vec { vec(uint8_t x) : v(_mm256_set1_epi8(x)) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE u8x16 lo() const { return u8x16{internal::lo(v)}; } - YNN_ALWAYS_INLINE u8x16 hi() const { return u8x16{internal::hi(v)}; } }; template <> @@ -210,9 +182,6 @@ struct vec { vec(int8_t x) : v(_mm256_set1_epi8(x)) {} // NOLINT __m256i v; - - YNN_ALWAYS_INLINE s8x16 lo() const { return s8x16{internal::lo(v)}; } - YNN_ALWAYS_INLINE s8x16 hi() const { return s8x16{internal::hi(v)}; } }; struct s2x128 { @@ -234,6 +203,27 @@ using s16x16 = vec; using u8x32 = vec; using s8x32 = vec; +YNN_ALWAYS_INLINE f64x2 lo(f64x4 x) { return f64x2{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE f64x2 hi(f64x4 x) { return f64x2{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE f32x4 lo(f32x8 x) { return f32x4{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE f32x4 hi(f32x8 x) { return f32x4{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE u32x4 lo(u32x8 x) { return u32x4{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE u32x4 hi(u32x8 x) { return u32x4{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE s32x4 lo(s32x8 x) { return s32x4{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE s32x4 hi(s32x8 x) { return s32x4{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE bf16x8 lo(bf16x16 x) { return bf16x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE bf16x8 hi(bf16x16 x) { return bf16x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE f16x8 lo(f16x16 x) { return f16x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE f16x8 hi(f16x16 x) { return f16x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE u16x8 lo(u16x16 x) { return u16x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE u16x8 hi(u16x16 x) { return u16x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE s16x8 lo(s16x16 x) { return s16x8{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE s16x8 hi(s16x16 x) { return s16x8{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE u8x16 lo(u8x32 x) { return u8x16{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE u8x16 hi(u8x32 x) { return u8x16{internal::hi(x.v)}; } +YNN_ALWAYS_INLINE s8x16 lo(s8x32 x) { return s8x16{internal::lo(x.v)}; } +YNN_ALWAYS_INLINE s8x16 hi(s8x32 x) { return s8x16{internal::hi(x.v)}; } + namespace internal { // These overloads are x86-specific helpers for implementing templated diff --git a/ynnpack/base/simd/x86_sse2_base.h b/ynnpack/base/simd/x86_sse2_base.h index 65ab8693ad4..ba5d27c4b64 100644 --- a/ynnpack/base/simd/x86_sse2_base.h +++ b/ynnpack/base/simd/x86_sse2_base.h @@ -158,15 +158,15 @@ struct vec { vec(double x) : v(_mm_set1_pd(x)) {} // NOLINT __m128d v; - - YNN_ALWAYS_INLINE vec lo() const { - return vec{_mm_cvtsd_f64(v)}; - } - YNN_ALWAYS_INLINE vec hi() const { - return vec{_mm_cvtsd_f64(_mm_unpackhi_pd(v, v))}; - } }; +YNN_ALWAYS_INLINE vec lo(vec x) { + return vec{_mm_cvtsd_f64(x.v)}; +} +YNN_ALWAYS_INLINE vec hi(vec x) { + return vec{_mm_cvtsd_f64(_mm_unpackhi_pd(x.v, x.v))}; +} + using f64x2 = vec; using f32x4 = vec; using u32x4 = vec; diff --git a/ynnpack/base/simd/x86_sse2_cast.h b/ynnpack/base/simd/x86_sse2_cast.h index 15cb8f4c150..ef3d63dfcad 100644 --- a/ynnpack/base/simd/x86_sse2_cast.h +++ b/ynnpack/base/simd/x86_sse2_cast.h @@ -34,34 +34,34 @@ using f32x16 = vec; using s32x16 = vec; YNN_ALWAYS_INLINE s16x8 cast(s32x8 a, int16_t) { - return s16x8{_mm_packs_epi32(a.lo().v, a.hi().v)}; + return s16x8{_mm_packs_epi32(lo(a).v, hi(a).v)}; } YNN_ALWAYS_INLINE s8x16 cast(s16x16 a, int8_t) { - return s8x16{_mm_packs_epi16(a.lo().v, a.hi().v)}; + return s8x16{_mm_packs_epi16(lo(a).v, hi(a).v)}; } YNN_ALWAYS_INLINE u8x16 cast(s16x16 a, uint8_t) { - return u8x16{_mm_packus_epi16(a.lo().v, a.hi().v)}; + return u8x16{_mm_packus_epi16(lo(a).v, hi(a).v)}; } YNN_ALWAYS_INLINE s16x8 cast(f32x8 f, int16_t) { - const s32x4 i0 = cast(f.lo(), int32_t()); - const s32x4 i1 = cast(f.hi(), int32_t()); + const s32x4 i0 = cast(lo(f), int32_t()); + const s32x4 i1 = cast(hi(f), int32_t()); return cast(s32x8(i0, i1), int16_t()); } YNN_ALWAYS_INLINE s8x16 cast(f32x16 f, int8_t) { - const s16x8 i01 = cast(f.lo(), int16_t()); - const s16x8 i23 = cast(f.hi(), int16_t()); + const s16x8 i01 = cast(lo(f), int16_t()); + const s16x8 i23 = cast(hi(f), int16_t()); return cast(s16x16(i01, i23), int8_t()); } YNN_ALWAYS_INLINE u8x16 cast(f32x16 f, uint8_t) { - const s32x4 i0 = cast(f.lo().lo(), int32_t()); - const s32x4 i1 = cast(f.lo().hi(), int32_t()); - const s32x4 i2 = cast(f.hi().lo(), int32_t()); - const s32x4 i3 = cast(f.hi().hi(), int32_t()); + const s32x4 i0 = cast(lo(lo(f)), int32_t()); + const s32x4 i1 = cast(hi(lo(f)), int32_t()); + const s32x4 i2 = cast(lo(hi(f)), int32_t()); + const s32x4 i3 = cast(hi(hi(f)), int32_t()); const __m128i i01_16 = _mm_packs_epi32(i0.v, i1.v); const __m128i i23_16 = _mm_packs_epi32(i2.v, i3.v); return u8x16{_mm_packus_epi16(i01_16, i23_16)}; diff --git a/ynnpack/kernels/elementwise/compiler.py b/ynnpack/kernels/elementwise/compiler.py index 9f127b7305d..941c2521949 100644 --- a/ynnpack/kernels/elementwise/compiler.py +++ b/ynnpack/kernels/elementwise/compiler.py @@ -846,8 +846,8 @@ def __init__( template YNN_INTRINSIC simd::vec select_greater_than(simd::vec a, simd::vec b, simd::vec c, simd::vec d) { - return simd::vec(select_greater_than(a.lo(), b.lo(), c.lo(), d.lo()), - select_greater_than(a.hi(), b.hi(), c.hi(), d.hi())); + return simd::vec(select_greater_than(lo(a), lo(b), lo(c), lo(d)), + select_greater_than(hi(a), hi(b), hi(c), hi(d))); } } // namespace