Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 51 additions & 51 deletions ynnpack/base/simd/arm_neon_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace ynn {

namespace simd {

// Half-vector wrappers
template <>
struct vec<uint8_t, 8> {
using value_type = uint8_t;
Expand All @@ -36,8 +37,6 @@ struct vec<uint8_t, 8> {
uint8x8_t v;
};

using u8x8 = vec<uint8_t, 8>;

template <>
struct vec<float, 2> {
using value_type = float;
Expand All @@ -50,6 +49,36 @@ struct vec<float, 2> {
float32x2_t v;
};

template <>
struct vec<bfloat16, 4> {
using value_type = bfloat16;
static constexpr std::integral_constant<size_t, 4> N = {};

vec() = default;
explicit vec(uint16x4_t v) : v(v) {}
vec(bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT

uint16x4_t v;
};

template <>
struct vec<half, 4> {
using value_type = half;
static constexpr std::integral_constant<size_t, 4> N = {};

vec() = default;
explicit vec(uint16x4_t v) : v(v) {}
vec(half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT

uint16x4_t v;
};

using u8x8 = vec<uint8_t, 8>;
using f32x2 = vec<float, 2>;
using bf16x4 = vec<bfloat16, 4>;
using f16x4 = vec<half, 4>;

// Full vector wrappers
template <>
struct vec<float, 4> {
using value_type = float;
Expand All @@ -58,12 +87,9 @@ struct vec<float, 4> {
vec() = default;
explicit vec(float32x4_t v) : v(v) {}
vec(float x) : v(vdupq_n_f32(x)) {} // NOLINT
vec(vec<float, 2> lo, vec<float, 2> hi) : v(vcombine_f32(lo.v, hi.v)) {}
vec(f32x2 lo, f32x2 hi) : v(vcombine_f32(lo.v, hi.v)) {}

float32x4_t v;

vec<float, 2> lo() const { return vec<float, 2>{vget_low_f32(v)}; }
vec<float, 2> hi() const { return vec<float, 2>{vget_high_f32(v)}; }
};

#ifdef YNN_ARCH_ARM64
Expand Down Expand Up @@ -104,18 +130,6 @@ struct vec<int32_t, 4> {
int32x4_t v;
};

template <>
struct vec<bfloat16, 4> {
using value_type = bfloat16;
static constexpr std::integral_constant<size_t, 4> N = {};

vec() = default;
explicit vec(uint16x4_t v) : v(v) {}
vec(bfloat16 x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT

uint16x4_t v;
};

template <>
struct vec<bfloat16, 8> {
using value_type = bfloat16;
Expand All @@ -124,24 +138,9 @@ struct vec<bfloat16, 8> {
vec() = default;
explicit vec(uint16x8_t v) : v(v) {}
vec(bfloat16 x) : v(vdupq_n_u16(x.to_bits())) {} // NOLINT
vec(vec<bfloat16, 4> lo, vec<bfloat16, 4> hi) : v(vcombine_u16(lo.v, hi.v)) {}
vec(bf16x4 lo, bf16x4 hi) : v(vcombine_u16(lo.v, hi.v)) {}

uint16x8_t v;

vec<bfloat16, 4> lo() const { return vec<bfloat16, 4>{vget_low_u16(v)}; }
vec<bfloat16, 4> hi() const { return vec<bfloat16, 4>{vget_high_u16(v)}; }
};

template <>
struct vec<half, 4> {
using value_type = half;
static constexpr std::integral_constant<size_t, 4> N = {};

vec() = default;
explicit vec(uint16x4_t v) : v(v) {}
vec(half x) : v(vdup_n_u16(x.to_bits())) {} // NOLINT

uint16x4_t v;
};

template <>
Expand Down Expand Up @@ -190,9 +189,6 @@ struct vec<uint8_t, 16> {
vec(u8x8 lo, u8x8 hi) : v(vcombine_u8(lo.v, hi.v)) {}
vec(uint8_t x) : v(vdupq_n_u8(x)) {} // NOLINT

u8x8 lo() const { return u8x8{vget_low_u8(v)}; }
u8x8 hi() const { return u8x8{vget_high_u8(v)}; }

uint8x16_t v;
};

Expand All @@ -208,22 +204,26 @@ struct vec<int8_t, 16> {
int8x16_t v;
};

using f32x2 = vec<float, 2>;
using f32x4 = vec<float, 4>;
#ifdef YNN_ARCH_ARM64
using f64x2 = vec<double, 2>;
#endif
using u32x4 = vec<uint32_t, 4>;
using s32x4 = vec<int32_t, 4>;
using bf16x4 = vec<bfloat16, 4>;
using bf16x8 = vec<bfloat16, 8>;
using f16x4 = vec<half, 4>;
using f16x8 = vec<half, 8>;
using u16x8 = vec<uint16_t, 8>;
using s16x8 = vec<int16_t, 8>;
using u8x16 = vec<uint8_t, 16>;
using s8x16 = vec<int8_t, 16>;

YNN_ALWAYS_INLINE f32x2 lo(f32x4 x) { return f32x2{vget_low_f32(x.v)}; }
YNN_ALWAYS_INLINE f32x2 hi(f32x4 x) { return f32x2{vget_high_f32(x.v)}; }
YNN_ALWAYS_INLINE bf16x4 lo(bf16x8 x) { return bf16x4{vget_low_u16(x.v)}; }
YNN_ALWAYS_INLINE bf16x4 hi(bf16x8 x) { return bf16x4{vget_high_u16(x.v)}; }
YNN_ALWAYS_INLINE u8x8 lo(u8x16 x) { return u8x8{vget_low_u8(x.v)}; }
YNN_ALWAYS_INLINE u8x8 hi(u8x16 x) { return u8x8{vget_high_u8(x.v)}; }

namespace internal {

YNN_ALWAYS_INLINE int32x4x2_t vtrn(int32x4_t a, int32x4_t b) {
Expand Down Expand Up @@ -1205,15 +1205,15 @@ YNN_ALWAYS_INLINE f32x2 cast(f64x2 a, float) {
#endif // YNN_ARCH_ARM64

YNN_ALWAYS_INLINE s16x8 cast(s32x8 a, int16_t) {
return s16x8{vcombine_s16(vqmovn_s32(a.lo().v), vqmovn_s32(a.hi().v))};
return s16x8{vcombine_s16(vqmovn_s32(lo(a).v), vqmovn_s32(hi(a).v))};
}

YNN_ALWAYS_INLINE s8x16 cast(s16x16 a, int8_t) {
return s8x16{vcombine_s8(vqmovn_s16(a.lo().v), vqmovn_s16(a.hi().v))};
return s8x16{vcombine_s8(vqmovn_s16(lo(a).v), vqmovn_s16(hi(a).v))};
}

YNN_ALWAYS_INLINE u8x16 cast(s16x16 a, uint8_t) {
return u8x16{vcombine_u8(vqmovun_s16(a.lo().v), vqmovun_s16(a.hi().v))};
return u8x16{vcombine_u8(vqmovun_s16(lo(a).v), vqmovun_s16(hi(a).v))};
}

YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {
Expand All @@ -1226,27 +1226,27 @@ YNN_ALWAYS_INLINE s32x4 cast(f32x4 f, int32_t) {

YNN_ALWAYS_INLINE s16x8 cast(f32x8 f, int16_t) {
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
s32x4 a1 = cast(round(f.lo()), int32_t{});
s32x4 a2 = cast(round(f.hi()), int32_t{});
s32x4 a1 = cast(round(lo(f)), int32_t{});
s32x4 a2 = cast(round(hi(f)), int32_t{});
return cast(s32x8{a1, a2}, int16_t{});
#else
return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(f.lo().v)),
vqmovn_s32(vcvtnq_s32_f32(f.hi().v)))};
return s16x8{vcombine_s16(vqmovn_s32(vcvtnq_s32_f32(lo(f).v)),
vqmovn_s32(vcvtnq_s32_f32(hi(f).v)))};
#endif
}

YNN_ALWAYS_INLINE s8x16 cast(f32x16 f, int8_t) {
s16x16 f_s16 = {
cast(f.lo(), int16_t{}),
cast(f.hi(), int16_t{}),
cast(lo(f), int16_t{}),
cast(hi(f), int16_t{}),
};
return cast(f_s16, int8_t{});
}

YNN_ALWAYS_INLINE u8x16 cast(f32x16 f, uint8_t) {
s16x16 f_s16 = {
cast(f.lo(), int16_t{}),
cast(f.hi(), int16_t{}),
cast(lo(f), int16_t{}),
cast(hi(f), int16_t{}),
};
return cast(f_s16, uint8_t{});
}
Expand Down
2 changes: 1 addition & 1 deletion ynnpack/base/simd/arm_neonfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ YNN_ALWAYS_INLINE f16x4 cast(f32x4 a, half) {

YNN_ALWAYS_INLINE f16x8 cast(f32x8 a, half) {
return f16x8{vreinterpretq_u16_f16(
vcombine_f16(vcvt_f16_f32(a.lo().v), vcvt_f16_f32(a.hi().v)))};
vcombine_f16(vcvt_f16_f32(lo(a).v), vcvt_f16_f32(hi(a).v)))};
}

} // namespace simd
Expand Down
8 changes: 5 additions & 3 deletions ynnpack/base/simd/byte_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ struct vec<uint8_t, 8> {
explicit vec(uint64_t v) : v(v) {}
vec(u8x4 x0, u8x4 x1) : v((static_cast<uint64_t>(x1.v) << 32) | x0.v) {}

u8x4 lo() const { return u8x4{static_cast<uint32_t>(v)}; }
u8x4 hi() const { return u8x4{static_cast<uint32_t>(v >> 32)}; }

uint64_t v;
};

using u8x8 = vec<uint8_t, 8>;

YNN_ALWAYS_INLINE u8x4 lo(u8x8 x) { return u8x4{static_cast<uint32_t>(x.v)}; }
YNN_ALWAYS_INLINE u8x4 hi(u8x8 x) {
return u8x4{static_cast<uint32_t>(x.v >> 32)};
}

YNN_ALWAYS_INLINE u8x4 load_aligned(const uint8_t* ptr, decltype(u8x4::N),
u8x4 = {}) {
return u8x4{*reinterpret_cast<const uint32_t*>(ptr)};
Expand Down
Loading
Loading