Skip to content

Commit 0547829

Browse files
dsharletgxnnpack-bot
authored andcommitted
Remove lo/hi as member functions of vec<T, N>
I think this is necessary as a step towards attempting to implement conditions and `select`, where the mask types might be things like `__mmask8`. PiperOrigin-RevId: 918084846
1 parent cc68da8 commit 0547829

12 files changed

Lines changed: 229 additions & 235 deletions

File tree

ynnpack/base/simd/arm_neon_base.h

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace ynn {
2424

2525
namespace simd {
2626

27+
// Half-vector wrappers
2728
template <>
2829
struct vec<uint8_t, 8> {
2930
using value_type = uint8_t;
@@ -36,8 +37,6 @@ struct vec<uint8_t, 8> {
3637
uint8x8_t v;
3738
};
3839

39-
using u8x8 = vec<uint8_t, 8>;
40-
4140
template <>
4241
struct vec<float, 2> {
4342
using value_type = float;
@@ -50,6 +49,36 @@ struct vec<float, 2> {
5049
float32x2_t v;
5150
};
5251

52+
template <>
53+
struct vec<bfloat16, 4> {
54+
using value_type = bfloat16;
55+
static constexpr std::integral_constant<size_t, 4> N = {};
56+
57+
vec() = default;
58+
explicit vec(uint16x4_t v) : v(v) {}
59+
vec(bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
60+
61+
uint16x4_t v;
62+
};
63+
64+
template <>
65+
struct vec<half, 4> {
66+
using value_type = half;
67+
static constexpr std::integral_constant<size_t, 4> N = {};
68+
69+
vec() = default;
70+
explicit vec(uint16x4_t v) : v(v) {}
71+
vec(half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
72+
73+
uint16x4_t v;
74+
};
75+
76+
using u8x8 = vec<uint8_t, 8>;
77+
using f32x2 = vec<float, 2>;
78+
using bf16x4 = vec<bfloat16, 4>;
79+
using f16x4 = vec<half, 4>;
80+
81+
// Full vector wrappers
5382
template <>
5483
struct vec<float, 4> {
5584
using value_type = float;
@@ -58,12 +87,9 @@ struct vec<float, 4> {
5887
vec() = default;
5988
explicit vec(float32x4_t v) : v(v) {}
6089
vec(float x) : v(vdupq_n_f32(x)) {} // NOLINT
61-
vec(vec<float, 2> lo, vec<float, 2> hi) : v(vcombine_f32(lo.v, hi.v)) {}
90+
vec(f32x2 lo, f32x2 hi) : v(vcombine_f32(lo.v, hi.v)) {}
6291

6392
float32x4_t v;
64-
65-
vec<float, 2> lo() const { return vec<float, 2>{vget_low_f32(v)}; }
66-
vec<float, 2> hi() const { return vec<float, 2>{vget_high_f32(v)}; }
6793
};
6894

6995
#ifdef YNN_ARCH_ARM64
@@ -104,18 +130,6 @@ struct vec<int32_t, 4> {
104130
int32x4_t v;
105131
};
106132

107-
template <>
108-
struct vec<bfloat16, 4> {
109-
using value_type = bfloat16;
110-
static constexpr std::integral_constant<size_t, 4> N = {};
111-
112-
vec() = default;
113-
explicit vec(uint16x4_t v) : v(v) {}
114-
vec(bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
115-
116-
uint16x4_t v;
117-
};
118-
119133
template <>
120134
struct vec<bfloat16, 8> {
121135
using value_type = bfloat16;
@@ -124,24 +138,9 @@ struct vec<bfloat16, 8> {
124138
vec() = default;
125139
explicit vec(uint16x8_t v) : v(v) {}
126140
vec(bfloat16 x) : v(vdupq_n_u16(x.to_bits())) {} // NOLINT
127-
vec(vec<bfloat16, 4> lo, vec<bfloat16, 4> hi) : v(vcombine_u16(lo.v, hi.v)) {}
141+
vec(bf16x4 lo, bf16x4 hi) : v(vcombine_u16(lo.v, hi.v)) {}
128142

129143
uint16x8_t v;
130-
131-
vec<bfloat16, 4> lo() const { return vec<bfloat16, 4>{vget_low_u16(v)}; }
132-
vec<bfloat16, 4> hi() const { return vec<bfloat16, 4>{vget_high_u16(v)}; }
133-
};
134-
135-
template <>
136-
struct vec<half, 4> {
137-
using value_type = half;
138-
static constexpr std::integral_constant<size_t, 4> N = {};
139-
140-
vec() = default;
141-
explicit vec(uint16x4_t v) : v(v) {}
142-
vec(half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT
143-
144-
uint16x4_t v;
145144
};
146145

147146
template <>
@@ -190,9 +189,6 @@ struct vec<uint8_t, 16> {
190189
vec(u8x8 lo, u8x8 hi) : v(vcombine_u8(lo.v, hi.v)) {}
191190
vec(uint8_t x) : v(vdupq_n_u8(x)) {} // NOLINT
192191

193-
u8x8 lo() const { return u8x8{vget_low_u8(v)}; }
194-
u8x8 hi() const { return u8x8{vget_high_u8(v)}; }
195-
196192
uint8x16_t v;
197193
};
198194

@@ -208,22 +204,26 @@ struct vec<int8_t, 16> {
208204
int8x16_t v;
209205
};
210206

211-
using f32x2 = vec<float, 2>;
212207
using f32x4 = vec<float, 4>;
213208
#ifdef YNN_ARCH_ARM64
214209
using f64x2 = vec<double, 2>;
215210
#endif
216211
using u32x4 = vec<uint32_t, 4>;
217212
using s32x4 = vec<int32_t, 4>;
218-
using bf16x4 = vec<bfloat16, 4>;
219213
using bf16x8 = vec<bfloat16, 8>;
220-
using f16x4 = vec<half, 4>;
221214
using f16x8 = vec<half, 8>;
222215
using u16x8 = vec<uint16_t, 8>;
223216
using s16x8 = vec<int16_t, 8>;
224217
using u8x16 = vec<uint8_t, 16>;
225218
using s8x16 = vec<int8_t, 16>;
226219

220+
YNN_ALWAYS_INLINE f32x2 lo(f32x4 x) { return f32x2{vget_low_f32(x.v)}; }
221+
YNN_ALWAYS_INLINE f32x2 hi(f32x4 x) { return f32x2{vget_high_f32(x.v)}; }
222+
YNN_ALWAYS_INLINE bf16x4 lo(bf16x8 x) { return bf16x4{vget_low_u16(x.v)}; }
223+
YNN_ALWAYS_INLINE bf16x4 hi(bf16x8 x) { return bf16x4{vget_high_u16(x.v)}; }
224+
YNN_ALWAYS_INLINE u8x8 lo(u8x16 x) { return u8x8{vget_low_u8(x.v)}; }
225+
YNN_ALWAYS_INLINE u8x8 hi(u8x16 x) { return u8x8{vget_high_u8(x.v)}; }
226+
227227
namespace internal {
228228

229229
YNN_ALWAYS_INLINE int32x4x2_t vtrn(int32x4_t a, int32x4_t b) {
@@ -1205,15 +1205,15 @@ YNN_ALWAYS_INLINE f32x2 cast(f64x2 a, float) {
12051205
#endif // YNN_ARCH_ARM64
12061206

12071207
YNN_ALWAYS_INLINE s16x8 cast(s32x8 a, int16_t) {
1208-
return s16x8{vcombine_s16(vqmovn_s32(a.lo().v), vqmovn_s32(a.hi().v))};
1208+
return s16x8{vcombine_s16(vqmovn_s32(lo(a).v), vqmovn_s32(hi(a).v))};
12091209
}
12101210

12111211
YNN_ALWAYS_INLINE s8x16 cast(s16x16 a, int8_t) {
1212-
return s8x16{vcombine_s8(vqmovn_s16(a.lo().v), vqmovn_s16(a.hi().v))};
1212+
return s8x16{vcombine_s8(vqmovn_s16(lo(a).v), vqmovn_s16(hi(a).v))};
12131213
}
12141214

12151215
YNN_ALWAYS_INLINE u8x16 cast(s16x16 a, uint8_t) {
1216-
return u8x16{vcombine_u8(vqmovun_s16(a.lo().v), vqmovun_s16(a.hi().v))};
1216+
return u8x16{vcombine_u8(vqmovun_s16(lo(a).v), vqmovun_s16(hi(a).v))};
12171217
}
12181218

12191219
YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {
@@ -1226,27 +1226,27 @@ YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {
12261226

12271227
YNN_ALWAYS_INLINE s16x8 cast(f32x8 f, int16_t) {
12281228
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
1229-
s32x4 a1 = cast(round(f.lo()), int32_t{});
1230-
s32x4 a2 = cast(round(f.hi()), int32_t{});
1229+
s32x4 a1 = cast(round(lo(f)), int32_t{});
1230+
s32x4 a2 = cast(round(hi(f)), int32_t{});
12311231
return cast(s32x8{a1, a2}, int16_t{});
12321232
#else
1233-
return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(f.lo().v)),
1234-
vqmovn_s32(vcvtnq_s32_f32(f.hi().v)))};
1233+
return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(lo(f).v)),
1234+
vqmovn_s32(vcvtnq_s32_f32(hi(f).v)))};
12351235
#endif
12361236
}
12371237

12381238
YNN_ALWAYS_INLINE s8x16 cast(f32x16 f, int8_t) {
12391239
s16x16 f_s16 = {
1240-
cast(f.lo(), int16_t{}),
1241-
cast(f.hi(), int16_t{}),
1240+
cast(lo(f), int16_t{}),
1241+
cast(hi(f), int16_t{}),
12421242
};
12431243
return cast(f_s16, int8_t{});
12441244
}
12451245

12461246
YNN_ALWAYS_INLINE u8x16 cast(f32x16 f, uint8_t) {
12471247
s16x16 f_s16 = {
1248-
cast(f.lo(), int16_t{}),
1249-
cast(f.hi(), int16_t{}),
1248+
cast(lo(f), int16_t{}),
1249+
cast(hi(f), int16_t{}),
12501250
};
12511251
return cast(f_s16, uint8_t{});
12521252
}

ynnpack/base/simd/arm_neonfp16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ YNN_ALWAYS_INLINE f16x4 cast(f32x4 a, half) {
3434

3535
YNN_ALWAYS_INLINE f16x8 cast(f32x8 a, half) {
3636
return f16x8{vreinterpretq_u16_f16(
37-
vcombine_f16(vcvt_f16_f32(a.lo().v), vcvt_f16_f32(a.hi().v)))};
37+
vcombine_f16(vcvt_f16_f32(lo(a).v), vcvt_f16_f32(hi(a).v)))};
3838
}
3939

4040
} // namespace simd

ynnpack/base/simd/byte_vec.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,16 @@ struct vec<uint8_t, 8> {
4646
explicit vec(uint64_t v) : v(v) {}
4747
vec(u8x4 x0, u8x4 x1) : v((static_cast<uint64_t>(x1.v) << 32) | x0.v) {}
4848

49-
u8x4 lo() const { return u8x4{static_cast<uint32_t>(v)}; }
50-
u8x4 hi() const { return u8x4{static_cast<uint32_t>(v >> 32)}; }
51-
5249
uint64_t v;
5350
};
5451

5552
using u8x8 = vec<uint8_t, 8>;
5653

54+
YNN_ALWAYS_INLINE u8x4 lo(u8x8 x) { return u8x4{static_cast<uint32_t>(x.v)}; }
55+
YNN_ALWAYS_INLINE u8x4 hi(u8x8 x) {
56+
return u8x4{static_cast<uint32_t>(x.v >> 32)};
57+
}
58+
5759
YNN_ALWAYS_INLINE u8x4 load_aligned(const uint8_t* ptr, decltype(u8x4::N),
5860
u8x4 = {}) {
5961
return u8x4{*reinterpret_cast<const uint32_t*>(ptr)};

0 commit comments

Comments
 (0)