@@ -24,6 +24,7 @@ namespace ynn {
2424
2525namespace simd {
2626
27+ // Half-vector wrappers
2728template <>
2829struct vec <uint8_t , 8 > {
2930 using value_type = uint8_t ;
@@ -36,8 +37,6 @@ struct vec<uint8_t, 8> {
3637 uint8x8_t v;
3738};
3839
39- using u8x8 = vec<uint8_t , 8 >;
40-
4140template <>
4241struct vec <float , 2 > {
4342 using value_type = float ;
@@ -50,6 +49,36 @@ struct vec<float, 2> {
5049 float32x2_t v;
5150};
5251
52+ template <>
53+ struct vec <bfloat16, 4 > {
54+ using value_type = bfloat16;
55+ static constexpr std::integral_constant<size_t , 4 > N = {};
56+
57+ vec () = default ;
58+ explicit vec (uint16x4_t v) : v(v) {}
59+ vec (bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
60+
61+ uint16x4_t v;
62+ };
63+
64+ template <>
65+ struct vec <half, 4 > {
66+ using value_type = half;
67+ static constexpr std::integral_constant<size_t , 4 > N = {};
68+
69+ vec () = default ;
70+ explicit vec (uint16x4_t v) : v(v) {}
71+ vec (half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
72+
73+ uint16x4_t v;
74+ };
75+
76+ using u8x8 = vec<uint8_t , 8 >;
77+ using f32x2 = vec<float , 2 >;
78+ using bf16x4 = vec<bfloat16, 4 >;
79+ using f16x4 = vec<half, 4 >;
80+
81+ // Full vector wrappers
5382template <>
5483struct vec <float , 4 > {
5584 using value_type = float ;
@@ -58,12 +87,9 @@ struct vec<float, 4> {
5887 vec () = default ;
5988 explicit vec (float32x4_t v) : v(v) {}
6089 vec (float x) : v(vdupq_n_f32(x)) {} // NOLINT
61- vec (vec< float , 2 > lo, vec< float , 2 > hi) : v(vcombine_f32(lo.v, hi.v)) {}
90+ vec (f32x2 lo, f32x2 hi) : v(vcombine_f32(lo.v, hi.v)) {}
6291
6392 float32x4_t v;
64-
65- vec<float , 2 > lo () const { return vec<float , 2 >{vget_low_f32 (v)}; }
66- vec<float , 2 > hi () const { return vec<float , 2 >{vget_high_f32 (v)}; }
6793};
6894
6995#ifdef YNN_ARCH_ARM64
@@ -104,18 +130,6 @@ struct vec<int32_t, 4> {
104130 int32x4_t v;
105131};
106132
107- template <>
108- struct vec <bfloat16, 4 > {
109- using value_type = bfloat16;
110- static constexpr std::integral_constant<size_t , 4 > N = {};
111-
112- vec () = default ;
113- explicit vec (uint16x4_t v) : v(v) {}
114- vec (bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
115-
116- uint16x4_t v;
117- };
118-
119133template <>
120134struct vec <bfloat16, 8 > {
121135 using value_type = bfloat16;
@@ -124,24 +138,9 @@ struct vec<bfloat16, 8> {
124138 vec () = default ;
125139 explicit vec (uint16x8_t v) : v(v) {}
126140 vec (bfloat16 x) : v(vdupq_n_u16(x.to_bits())) {} // NOLINT
127- vec (vec<bfloat16, 4 > lo, vec<bfloat16, 4 > hi) : v(vcombine_u16(lo.v, hi.v)) {}
141+ vec (bf16x4 lo, bf16x4 hi) : v(vcombine_u16(lo.v, hi.v)) {}
128142
129143 uint16x8_t v;
130-
131- vec<bfloat16, 4 > lo () const { return vec<bfloat16, 4 >{vget_low_u16 (v)}; }
132- vec<bfloat16, 4 > hi () const { return vec<bfloat16, 4 >{vget_high_u16 (v)}; }
133- };
134-
135- template <>
136- struct vec <half, 4 > {
137- using value_type = half;
138- static constexpr std::integral_constant<size_t , 4 > N = {};
139-
140- vec () = default ;
141- explicit vec (uint16x4_t v) : v(v) {}
142- vec (half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
143-
144- uint16x4_t v;
145144};
146145
147146template <>
@@ -190,9 +189,6 @@ struct vec<uint8_t, 16> {
190189 vec (u8x8 lo, u8x8 hi) : v(vcombine_u8(lo.v, hi.v)) {}
191190 vec (uint8_t x) : v(vdupq_n_u8(x)) {} // NOLINT
192191
193- u8x8 lo () const { return u8x8{vget_low_u8 (v)}; }
194- u8x8 hi () const { return u8x8{vget_high_u8 (v)}; }
195-
196192 uint8x16_t v;
197193};
198194
@@ -208,22 +204,26 @@ struct vec<int8_t, 16> {
208204 int8x16_t v;
209205};
210206
211- using f32x2 = vec<float , 2 >;
212207using f32x4 = vec<float , 4 >;
213208#ifdef YNN_ARCH_ARM64
214209using f64x2 = vec<double , 2 >;
215210#endif
216211using u32x4 = vec<uint32_t , 4 >;
217212using s32x4 = vec<int32_t , 4 >;
218- using bf16x4 = vec<bfloat16, 4 >;
219213using bf16x8 = vec<bfloat16, 8 >;
220- using f16x4 = vec<half, 4 >;
221214using f16x8 = vec<half, 8 >;
222215using u16x8 = vec<uint16_t , 8 >;
223216using s16x8 = vec<int16_t , 8 >;
224217using u8x16 = vec<uint8_t , 16 >;
225218using s8x16 = vec<int8_t , 16 >;
226219
220+ YNN_ALWAYS_INLINE f32x2 lo (f32x4 x) { return f32x2{vget_low_f32 (x.v )}; }
221+ YNN_ALWAYS_INLINE f32x2 hi (f32x4 x) { return f32x2{vget_high_f32 (x.v )}; }
222+ YNN_ALWAYS_INLINE bf16x4 lo (bf16x8 x) { return bf16x4{vget_low_u16 (x.v )}; }
223+ YNN_ALWAYS_INLINE bf16x4 hi (bf16x8 x) { return bf16x4{vget_high_u16 (x.v )}; }
224+ YNN_ALWAYS_INLINE u8x8 lo (u8x16 x) { return u8x8{vget_low_u8 (x.v )}; }
225+ YNN_ALWAYS_INLINE u8x8 hi (u8x16 x) { return u8x8{vget_high_u8 (x.v )}; }
226+
227227namespace internal {
228228
229229YNN_ALWAYS_INLINE int32x4x2_t vtrn (int32x4_t a, int32x4_t b) {
@@ -1205,15 +1205,15 @@ YNN_ALWAYS_INLINE f32x2 cast(f64x2 a, float) {
12051205#endif // YNN_ARCH_ARM64
12061206
12071207YNN_ALWAYS_INLINE s16x8 cast (s32x8 a, int16_t ) {
1208- return s16x8{vcombine_s16 (vqmovn_s32 (a. lo ().v ), vqmovn_s32 (a. hi ().v ))};
1208+ return s16x8{vcombine_s16 (vqmovn_s32 (lo (a ).v ), vqmovn_s32 (hi (a ).v ))};
12091209}
12101210
12111211YNN_ALWAYS_INLINE s8x16 cast (s16x16 a, int8_t ) {
1212- return s8x16{vcombine_s8 (vqmovn_s16 (a. lo ().v ), vqmovn_s16 (a. hi ().v ))};
1212+ return s8x16{vcombine_s8 (vqmovn_s16 (lo (a ).v ), vqmovn_s16 (hi (a ).v ))};
12131213}
12141214
12151215YNN_ALWAYS_INLINE u8x16 cast (s16x16 a, uint8_t ) {
1216- return u8x16{vcombine_u8 (vqmovun_s16 (a. lo ().v ), vqmovun_s16 (a. hi ().v ))};
1216+ return u8x16{vcombine_u8 (vqmovun_s16 (lo (a ).v ), vqmovun_s16 (hi (a ).v ))};
12171217}
12181218
12191219YNN_ALWAYS_INLINE s32x4 cast (f32x4 f, int32_t ) {
@@ -1226,27 +1226,27 @@ YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {
12261226
12271227YNN_ALWAYS_INLINE s16x8 cast (f32x8 f, int16_t ) {
12281228#if defined(__ARM_ARCH) && __ARM_ARCH < 8
1229- s32x4 a1 = cast (round (f. lo ()), int32_t {});
1230- s32x4 a2 = cast (round (f. hi ()), int32_t {});
1229+ s32x4 a1 = cast (round (lo (f )), int32_t {});
1230+ s32x4 a2 = cast (round (hi (f )), int32_t {});
12311231 return cast (s32x8{a1, a2}, int16_t {});
12321232#else
1233- return s16x8{vcombine_s16 (vqmovn_s32 (vcvtnq_s32_f32 (f. lo ().v )),
1234- vqmovn_s32 (vcvtnq_s32_f32 (f. hi ().v )))};
1233+ return s16x8{vcombine_s16 (vqmovn_s32 (vcvtnq_s32_f32 (lo (f ).v )),
1234+ vqmovn_s32 (vcvtnq_s32_f32 (hi (f ).v )))};
12351235#endif
12361236}
12371237
12381238YNN_ALWAYS_INLINE s8x16 cast (f32x16 f, int8_t ) {
12391239 s16x16 f_s16 = {
1240- cast (f. lo (), int16_t {}),
1241- cast (f. hi (), int16_t {}),
1240+ cast (lo (f ), int16_t {}),
1241+ cast (hi (f ), int16_t {}),
12421242 };
12431243 return cast (f_s16, int8_t {});
12441244}
12451245
12461246YNN_ALWAYS_INLINE u8x16 cast (f32x16 f, uint8_t ) {
12471247 s16x16 f_s16 = {
1248- cast (f. lo (), int16_t {}),
1249- cast (f. hi (), int16_t {}),
1248+ cast (lo (f ), int16_t {}),
1249+ cast (hi (f ), int16_t {}),
12501250 };
12511251 return cast (f_s16, uint8_t {});
12521252}
0 commit comments