Skip to content

Commit 341d62d

Browse files
dsharletgxnnpack-bot
authored andcommitted
Remove lo/hi as member functions of vec<T, N>
I think this is necessary as a step towards attempting to implement conditions and `select`, where the mask types might be things like `__mmask8`. PiperOrigin-RevId: 918069127
1 parent 4b53a43 commit 341d62d

12 files changed

Lines changed: 196 additions & 204 deletions

File tree

ynnpack/base/simd/arm_neon_base.h

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ struct vec<float, 4> {
6161
vec(vec<float, 2> lo, vec<float, 2> hi) : v(vcombine_f32(lo.v, hi.v)) {}
6262

6363
float32x4_t v;
64-
65-
vec<float, 2> lo() const { return vec<float, 2>{vget_low_f32(v)}; }
66-
vec<float, 2> hi() const { return vec<float, 2>{vget_high_f32(v)}; }
6764
};
6865

6966
#ifdef YNN_ARCH_ARM64
@@ -127,9 +124,6 @@ struct vec<bfloat16, 8> {
127124
vec(vec<bfloat16, 4> lo, vec<bfloat16, 4> hi) : v(vcombine_u16(lo.v, hi.v)) {}
128125

129126
uint16x8_t v;
130-
131-
vec<bfloat16, 4> lo() const { return vec<bfloat16, 4>{vget_low_u16(v)}; }
132-
vec<bfloat16, 4> hi() const { return vec<bfloat16, 4>{vget_high_u16(v)}; }
133127
};
134128

135129
template <>
@@ -190,9 +184,6 @@ struct vec<uint8_t, 16> {
190184
vec(u8x8 lo, u8x8 hi) : v(vcombine_u8(lo.v, hi.v)) {}
191185
vec(uint8_t x) : v(vdupq_n_u8(x)) {} // NOLINT
192186

193-
u8x8 lo() const { return u8x8{vget_low_u8(v)}; }
194-
u8x8 hi() const { return u8x8{vget_high_u8(v)}; }
195-
196187
uint8x16_t v;
197188
};
198189

@@ -224,6 +215,13 @@ using s16x8 = vec<int16_t, 8>;
224215
using u8x16 = vec<uint8_t, 16>;
225216
using s8x16 = vec<int8_t, 16>;
226217

218+
YNN_ALWAYS_INLINE f32x2 lo(f32x4 x) { return f32x2{vget_low_f32(x.v)}; }
219+
YNN_ALWAYS_INLINE f32x2 hi(f32x4 x) { return f32x2{vget_high_f32(x.v)}; }
220+
YNN_ALWAYS_INLINE bf16x4 lo(bf16x8 x) { return bf16x4{vget_low_u16(x.v)}; }
221+
YNN_ALWAYS_INLINE bf16x4 hi(bf16x8 x) { return bf16x4{vget_high_u16(x.v)}; }
222+
YNN_ALWAYS_INLINE u8x8 lo(u8x16 x) { return u8x8{vget_low_u8(x.v)}; }
223+
YNN_ALWAYS_INLINE u8x8 hi(u8x16 x) { return u8x8{vget_high_u8(x.v)}; }
224+
227225
namespace internal {
228226

229227
YNN_ALWAYS_INLINE int32x4x2_t vtrn(int32x4_t a, int32x4_t b) {
@@ -1205,15 +1203,15 @@ YNN_ALWAYS_INLINE f32x2 cast(f64x2 a, float) {
12051203
#endif // YNN_ARCH_ARM64
12061204

12071205
YNN_ALWAYS_INLINE s16x8 cast(s32x8 a, int16_t) {
1208-
return s16x8{vcombine_s16(vqmovn_s32(a.lo().v), vqmovn_s32(a.hi().v))};
1206+
return s16x8{vcombine_s16(vqmovn_s32(lo(a).v), vqmovn_s32(hi(a).v))};
12091207
}
12101208

12111209
YNN_ALWAYS_INLINE s8x16 cast(s16x16 a, int8_t) {
1212-
return s8x16{vcombine_s8(vqmovn_s16(a.lo().v), vqmovn_s16(a.hi().v))};
1210+
return s8x16{vcombine_s8(vqmovn_s16(lo(a).v), vqmovn_s16(hi(a).v))};
12131211
}
12141212

12151213
YNN_ALWAYS_INLINE u8x16 cast(s16x16 a, uint8_t) {
1216-
return u8x16{vcombine_u8(vqmovun_s16(a.lo().v), vqmovun_s16(a.hi().v))};
1214+
return u8x16{vcombine_u8(vqmovun_s16(lo(a).v), vqmovun_s16(hi(a).v))};
12171215
}
12181216

12191217
YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {
@@ -1226,27 +1224,27 @@ YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {
12261224

12271225
YNN_ALWAYS_INLINE s16x8 cast(f32x8 f, int16_t) {
12281226
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
1229-
s32x4 a1 = cast(round(f.lo()), int32_t{});
1230-
s32x4 a2 = cast(round(f.hi()), int32_t{});
1227+
s32x4 a1 = cast(round(lo(f)), int32_t{});
1228+
s32x4 a2 = cast(round(hi(f)), int32_t{});
12311229
return cast(s32x8{a1, a2}, int16_t{});
12321230
#else
1233-
return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(f.lo().v)),
1234-
vqmovn_s32(vcvtnq_s32_f32(f.hi().v)))};
1231+
return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(lo(f).v)),
1232+
vqmovn_s32(vcvtnq_s32_f32(hi(f).v)))};
12351233
#endif
12361234
}
12371235

12381236
YNN_ALWAYS_INLINE s8x16 cast(f32x16 f, int8_t) {
12391237
s16x16 f_s16 = {
1240-
cast(f.lo(), int16_t{}),
1241-
cast(f.hi(), int16_t{}),
1238+
cast(lo(f), int16_t{}),
1239+
cast(hi(f), int16_t{}),
12421240
};
12431241
return cast(f_s16, int8_t{});
12441242
}
12451243

12461244
YNN_ALWAYS_INLINE u8x16 cast(f32x16 f, uint8_t) {
12471245
s16x16 f_s16 = {
1248-
cast(f.lo(), int16_t{}),
1249-
cast(f.hi(), int16_t{}),
1246+
cast(lo(f), int16_t{}),
1247+
cast(hi(f), int16_t{}),
12501248
};
12511249
return cast(f_s16, uint8_t{});
12521250
}

ynnpack/base/simd/arm_neonfp16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ YNN_ALWAYS_INLINE f16x4 cast(f32x4 a, half) {
3434

3535
YNN_ALWAYS_INLINE f16x8 cast(f32x8 a, half) {
3636
return f16x8{vreinterpretq_u16_f16(
37-
vcombine_f16(vcvt_f16_f32(a.lo().v), vcvt_f16_f32(a.hi().v)))};
37+
vcombine_f16(vcvt_f16_f32(lo(a).v), vcvt_f16_f32(hi(a).v)))};
3838
}
3939

4040
} // namespace simd

ynnpack/base/simd/byte_vec.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,16 @@ struct vec<uint8_t, 8> {
4646
explicit vec(uint64_t v) : v(v) {}
4747
vec(u8x4 x0, u8x4 x1) : v((static_cast<uint64_t>(x1.v) << 32) | x0.v) {}
4848

49-
u8x4 lo() const { return u8x4{static_cast<uint32_t>(v)}; }
50-
u8x4 hi() const { return u8x4{static_cast<uint32_t>(v >> 32)}; }
51-
5249
uint64_t v;
5350
};
5451

5552
using u8x8 = vec<uint8_t, 8>;
5653

54+
YNN_ALWAYS_INLINE u8x4 lo(u8x8 x) { return u8x4{static_cast<uint32_t>(x.v)}; }
55+
YNN_ALWAYS_INLINE u8x4 hi(u8x8 x) {
56+
return u8x4{static_cast<uint32_t>(x.v >> 32)};
57+
}
58+
5759
YNN_ALWAYS_INLINE u8x4 load_aligned(const uint8_t* ptr, decltype(u8x4::N),
5860
u8x4 = {}) {
5961
return u8x4{*reinterpret_cast<const uint32_t*>(ptr)};

ynnpack/base/simd/generic.inc

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ template <typename T, size_t N>
5252
YNN_ALWAYS_INLINE vec<T, N> load(const T* ptr, size_t n, vec<T, N> src) {
5353
std::integral_constant<size_t, N / 2> n2 = {};
5454
if (n < n2) {
55-
return {load(ptr, n, src.lo()), src.hi()};
55+
return {load(ptr, n, lo(src)), hi(src)};
5656
} else {
57-
return {load(ptr, n2), load(ptr + n2, n - n2, src.hi())};
57+
return {load(ptr, n2), load(ptr + n2, n - n2, hi(src))};
5858
}
5959
}
6060
template <typename T, size_t N>
@@ -81,43 +81,43 @@ template <typename T, size_t N>
8181
YNN_ALWAYS_INLINE void store(T* ptr, vec<T, N> value,
8282
std::integral_constant<size_t, N> n) {
8383
std::integral_constant<size_t, N / 2> n2 = {};
84-
store(ptr, value.lo(), n2);
85-
store(ptr + n2, value.hi(), n2);
84+
store(ptr, lo(value), n2);
85+
store(ptr + n2, hi(value), n2);
8686
}
8787
template <typename T, size_t N>
8888
YNN_ALWAYS_INLINE void store_aligned(T* ptr, vec<T, N> value,
8989
std::integral_constant<size_t, N> n) {
9090
std::integral_constant<size_t, N / 2> n2 = {};
91-
store_aligned(ptr, value.lo(), n2);
92-
store_aligned(ptr + n2, value.hi(), n2);
91+
store_aligned(ptr, lo(value), n2);
92+
store_aligned(ptr + n2, hi(value), n2);
9393
}
9494
template <typename T, size_t N>
9595
YNN_ALWAYS_INLINE void store(T* ptr, vec<T, N> value, size_t n) {
9696
std::integral_constant<size_t, N / 2> n2 = {};
9797
if (n < n2) {
98-
store(ptr, value.lo(), n);
98+
store(ptr, lo(value), n);
9999
} else {
100-
store(ptr, value.lo(), n2);
101-
store(ptr + n2, value.hi(), n - n2);
100+
store(ptr, lo(value), n2);
101+
store(ptr + n2, hi(value), n - n2);
102102
}
103103
}
104104

105105
// Arithmetic operators.
106106
template <typename T, size_t N>
107107
YNN_ALWAYS_INLINE vec<T, N> operator+(vec<T, N> a, vec<T, N> b) {
108-
return {a.lo() + b.lo(), a.hi() + b.hi()};
108+
return {lo(a) + lo(b), hi(a) + hi(b)};
109109
}
110110
template <typename T, size_t N>
111111
YNN_ALWAYS_INLINE vec<T, N> operator-(vec<T, N> a, vec<T, N> b) {
112-
return {a.lo() - b.lo(), a.hi() - b.hi()};
112+
return {lo(a) - lo(b), hi(a) - hi(b)};
113113
}
114114
template <typename T, size_t N>
115115
YNN_ALWAYS_INLINE vec<T, N> operator*(vec<T, N> a, vec<T, N> b) {
116-
return {a.lo() * b.lo(), a.hi() * b.hi()};
116+
return {lo(a) * lo(b), hi(a) * hi(b)};
117117
}
118118
template <typename T, size_t N>
119119
YNN_ALWAYS_INLINE vec<T, N> operator/(vec<T, N> a, vec<T, N> b) {
120-
return {a.lo() / b.lo(), a.hi() / b.hi()};
120+
return {lo(a) / lo(b), hi(a) / hi(b)};
121121
}
122122

123123
template <typename T, size_t N>
@@ -144,23 +144,23 @@ YNN_ALWAYS_INLINE vec<T, N>& operator/=(vec<T, N>& a, vec<T, N> b) {
144144
// Boolean operators.
145145
template <typename T, size_t N>
146146
YNN_ALWAYS_INLINE vec<T, N> operator&(vec<T, N> a, vec<T, N> b) {
147-
return {a.lo() & b.lo(), a.hi() & b.hi()};
147+
return {lo(a) & lo(b), hi(a) & hi(b)};
148148
}
149149
template <typename T, size_t N>
150150
YNN_ALWAYS_INLINE vec<T, N> operator|(vec<T, N> a, vec<T, N> b) {
151-
return {a.lo() | b.lo(), a.hi() | b.hi()};
151+
return {lo(a) | lo(b), hi(a) | hi(b)};
152152
}
153153
template <typename T, size_t N>
154154
YNN_ALWAYS_INLINE vec<T, N> operator^(vec<T, N> a, vec<T, N> b) {
155-
return {a.lo() ^ b.lo(), a.hi() ^ b.hi()};
155+
return {lo(a) ^ lo(b), hi(a) ^ hi(b)};
156156
}
157157
template <typename T, size_t N>
158158
YNN_ALWAYS_INLINE vec<T, N> operator~(vec<T, N> a) {
159-
return {~a.lo(), ~a.hi()};
159+
return {~lo(a), ~hi(a)};
160160
}
161161
template <typename T, size_t N>
162162
YNN_ALWAYS_INLINE vec<T, N> operator<<(vec<T, N> a, int b) {
163-
return {a.lo() << b, a.hi() << b};
163+
return {lo(a) << b, hi(a) << b};
164164
}
165165

166166
template <typename T, size_t N>
@@ -180,60 +180,60 @@ YNN_ALWAYS_INLINE vec<T, N>& operator^=(vec<T, N>& a, vec<T, N> b) {
180180
}
181181
template <typename T, size_t N>
182182
YNN_ALWAYS_INLINE vec<T, N> min(vec<T, N> a, vec<T, N> b) {
183-
return {min(a.lo(), b.lo()), min(a.hi(), b.hi())};
183+
return {min(lo(a), lo(b)), min(hi(a), hi(b))};
184184
}
185185
template <typename T, size_t N>
186186
YNN_ALWAYS_INLINE vec<T, N> max(vec<T, N> a, vec<T, N> b) {
187-
return {max(a.lo(), b.lo()), max(a.hi(), b.hi())};
187+
return {max(lo(a), lo(b)), max(hi(a), hi(b))};
188188
}
189189
template <typename T, size_t N>
190190
YNN_ALWAYS_INLINE vec<T, N> copysign(vec<T, N> mag, vec<T, N> sgn) {
191-
return {copysign(mag.lo(), sgn.lo()), copysign(mag.hi(), sgn.hi())};
191+
return {copysign(lo(mag), lo(sgn)), copysign(hi(mag), hi(sgn))};
192192
};
193193
template <typename T, size_t N>
194194
YNN_ALWAYS_INLINE vec<T, N> abs(vec<T, N> a) {
195-
return {abs(a.lo()), abs(a.hi())};
195+
return {abs(lo(a)), abs(hi(a))};
196196
}
197197
template <typename T, size_t N>
198198
YNN_ALWAYS_INLINE vec<T, N> add_sat(vec<T, N> a, vec<T, N> b) {
199-
return {add_sat(a.lo(), b.lo()), add_sat(a.hi(), b.hi())};
199+
return {add_sat(lo(a), lo(b)), add_sat(hi(a), hi(b))};
200200
}
201201
template <typename T, size_t N>
202202
YNN_ALWAYS_INLINE vec<T, N> sub_sat(vec<T, N> a, vec<T, N> b) {
203-
return {sub_sat(a.lo(), b.lo()), sub_sat(a.hi(), b.hi())};
203+
return {sub_sat(lo(a), lo(b)), sub_sat(hi(a), hi(b))};
204204
}
205205
template <typename T, size_t N>
206206
YNN_ALWAYS_INLINE vec<T, N> floor(vec<T, N> a) {
207-
return {floor(a.lo()), floor(a.hi())};
207+
return {floor(lo(a)), floor(hi(a))};
208208
}
209209
template <typename T, size_t N>
210210
YNN_ALWAYS_INLINE vec<T, N> floor_log2(vec<T, N> a) {
211-
return {floor_log2(a.lo()), floor_log2(a.hi())};
211+
return {floor_log2(lo(a)), floor_log2(hi(a))};
212212
}
213213

214214
template <typename T, size_t N>
215215
YNN_ALWAYS_INLINE vec<T, N> exp2_round(vec<T, N> a) {
216-
return {exp2_round(a.lo()), exp2_round(a.hi())};
216+
return {exp2_round(lo(a)), exp2_round(hi(a))};
217217
}
218218
template <typename T, size_t N>
219219
YNN_ALWAYS_INLINE vec<T, N> copynan(vec<T, N> x, vec<T, N> nan) {
220-
return {copynan(x.lo(), nan.lo()), copynan(x.hi(), nan.hi())};
220+
return {copynan(lo(x), lo(nan)), copynan(hi(x), hi(nan))};
221221
}
222222
template <typename T, size_t N>
223223
YNN_ALWAYS_INLINE vec<T, N> ceil(vec<T, N> a) {
224-
return {ceil(a.lo()), ceil(a.hi())};
224+
return {ceil(lo(a)), ceil(hi(a))};
225225
}
226226
template <typename T, size_t N>
227227
YNN_ALWAYS_INLINE vec<T, N> round(vec<T, N> a) {
228-
return {round(a.lo()), round(a.hi())};
228+
return {round(lo(a)), round(hi(a))};
229229
}
230230
template <typename T, size_t N>
231231
YNN_ALWAYS_INLINE vec<T, N> sqrt(vec<T, N> a) {
232-
return {sqrt(a.lo()), sqrt(a.hi())};
232+
return {sqrt(lo(a)), sqrt(hi(a))};
233233
}
234234
template <typename T, size_t N>
235235
YNN_ALWAYS_INLINE vec<T, N> fma(vec<T, N> a, vec<T, N> b, vec<T, N> acc) {
236-
return {fma(a.lo(), b.lo(), acc.lo()), fma(a.hi(), b.hi(), acc.hi())};
236+
return {fma(lo(a), lo(b), lo(acc)), fma(hi(a), hi(b), hi(acc))};
237237
}
238238

239239
template <int Index, typename T, size_t N>
@@ -246,7 +246,7 @@ template <int Index, typename T, size_t N>
246246
YNN_ALWAYS_INLINE vec<T, N / 2> extract(vec<T, N> x,
247247
std::integral_constant<size_t, N / 2>) {
248248
static_assert(Index == 0 || Index == 1, "");
249-
return Index == 0 ? x.lo() : x.hi();
249+
return Index == 0 ? lo(x) : hi(x);
250250
}
251251
template <int Index, typename T, size_t N>
252252
YNN_ALWAYS_INLINE vec<T, N / 4> extract(vec<T, N> x,
@@ -263,31 +263,31 @@ YNN_ALWAYS_INLINE vec<T, N*2> concat(vec<T, N> a, vec<T, N> b) {
263263

264264
template <typename To, typename From, size_t N>
265265
YNN_ALWAYS_INLINE vec<To, N> cast(vec<From, N> from, To) {
266-
return {cast(from.lo(), To()), cast(from.hi(), To())};
266+
return {cast(lo(from), To()), cast(hi(from), To())};
267267
}
268268

269269
template <typename T, size_t N>
270270
YNN_ALWAYS_INLINE T horizontal_sum(vec<T, N> x) {
271-
return horizontal_sum(x.lo() + x.hi());
271+
return horizontal_sum(lo(x) + hi(x));
272272
}
273273
template <typename T, size_t N>
274274
YNN_ALWAYS_INLINE T horizontal_min(vec<T, N> x) {
275-
return horizontal_min(min(x.lo(), x.hi()));
275+
return horizontal_min(min(lo(x), hi(x)));
276276
}
277277
template <typename T, size_t N>
278278
YNN_ALWAYS_INLINE T horizontal_max(vec<T, N> x) {
279-
return horizontal_max(max(x.lo(), x.hi()));
279+
return horizontal_max(max(lo(x), hi(x)));
280280
}
281281

282282
template <typename T, size_t N>
283283
YNN_ALWAYS_INLINE void kahan_sum(vec<T, N> a, vec<T, N>& acc,
284284
vec<T, N>& error) {
285-
vec<T, N / 2> acc_lo = acc.lo();
286-
vec<T, N / 2> acc_hi = acc.hi();
287-
vec<T, N / 2> error_lo = error.lo();
288-
vec<T, N / 2> error_hi = error.hi();
289-
kahan_sum(a.lo(), acc_lo, error_lo);
290-
kahan_sum(a.hi(), acc_hi, error_hi);
285+
vec<T, N / 2> acc_lo = lo(acc);
286+
vec<T, N / 2> acc_hi = hi(acc);
287+
vec<T, N / 2> error_lo = lo(error);
288+
vec<T, N / 2> error_hi = hi(error);
289+
kahan_sum(lo(a), acc_lo, error_lo);
290+
kahan_sum(hi(a), acc_hi, error_hi);
291291
acc = concat(acc_lo, acc_hi);
292292
error = concat(error_lo, error_hi);
293293
}

ynnpack/base/simd/vec.h

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,6 @@ struct vec {
6868

6969
subvec v[2];
7070

71-
subvec& lo() { return v[0]; }
72-
const subvec& lo() const { return v[0]; }
73-
subvec& hi() { return v[1]; }
74-
const subvec& hi() const { return v[1]; }
75-
7671
vec() = default;
7772
YNN_ALWAYS_INLINE explicit vec(value_type x) : v{subvec{x}, subvec{x}} {}
7873
YNN_ALWAYS_INLINE vec(subvec v0, subvec v1) : v{v0, v1} {}
@@ -81,6 +76,26 @@ struct vec {
8176
YNN_ALWAYS_INLINE const subvec& operator[](size_t i) const { return v[i]; }
8277
};
8378

79+
template <typename T, size_t N>
80+
YNN_ALWAYS_INLINE vec<T, N / 2>& lo(vec<T, N>& x) {
81+
return x.v[0];
82+
}
83+
84+
template <typename T, size_t N>
85+
YNN_ALWAYS_INLINE const vec<T, N / 2>& lo(const vec<T, N>& x) {
86+
return x.v[0];
87+
}
88+
89+
template <typename T, size_t N>
90+
YNN_ALWAYS_INLINE vec<T, N / 2>& hi(vec<T, N>& x) {
91+
return x.v[1];
92+
}
93+
94+
template <typename T, size_t N>
95+
YNN_ALWAYS_INLINE const vec<T, N / 2>& hi(const vec<T, N>& x) {
96+
return x.v[1];
97+
}
98+
8499
template <size_t N, typename T>
85100
YNN_ALWAYS_INLINE vec<T, N> broadcast(T x) {
86101
return vec<T, N>{x};

0 commit comments

Comments
 (0)