diff --git a/src/simd/distances_sve.cc b/src/simd/distances_sve.cc
index 94211f263..96795a500 100644
--- a/src/simd/distances_sve.cc
+++ b/src/simd/distances_sve.cc
@@ -96,6 +96,42 @@ fp16_vec_L2sqr_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
     return svaddv_f32(pg_32, total_sum);
 }
 
+// Inner product of two fp16 vectors of length d, accumulated in fp32.
+// Each loop iteration consumes up to svcnth() half-precision elements.
+float
+fp16_vec_inner_product_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
+    svfloat32_t sum1 = svdup_f32(0.0f);
+    svfloat32_t sum2 = svdup_f32(0.0f);
+    size_t i = 0;
+
+    svbool_t pg_16 = svptrue_b16();
+    svbool_t pg_32 = svptrue_b32();
+
+    while (i < d) {
+        // Tail: keep only lanes with index < d active; svld1 zeroes the inactive ones.
+        if (d - i < svcnth())
+            pg_16 = svwhilelt_b16(i, d);
+
+        svfloat16_t a_fp16 = svld1_f16(pg_16, reinterpret_cast<const float16_t*>(x + i));
+        svfloat16_t b_fp16 = svld1_f16(pg_16, reinterpret_cast<const float16_t*>(y + i));
+
+        // svcvt_f32_f16 widens the even-numbered f16 elements, so TRN1/TRN2 are used to
+        // place the even ("low") and odd ("high") input elements in those positions.
+        svfloat32_t a_fp32_low = svcvt_f32_f16_z(pg_32, svtrn1_f16(a_fp16, a_fp16));
+        svfloat32_t a_fp32_high = svcvt_f32_f16_z(pg_32, svtrn2_f16(a_fp16, a_fp16));
+        svfloat32_t b_fp32_low = svcvt_f32_f16_z(pg_32, svtrn1_f16(b_fp16, b_fp16));
+        svfloat32_t b_fp32_high = svcvt_f32_f16_z(pg_32, svtrn2_f16(b_fp16, b_fp16));
+
+        sum1 = svmla_f32_m(pg_32, sum1, a_fp32_low, b_fp32_low);
+        sum2 = svmla_f32_m(pg_32, sum2, a_fp32_high, b_fp32_high);
+
+        i += svcnth();
+    }
+
+    svfloat32_t total_sum = svadd_f32_m(pg_32, sum1, sum2);
+    return svaddv_f32(pg_32, total_sum);
+}
+
 float
 fvec_L1_sve(const float* x, const float* y, size_t d) {
     svfloat32_t sum = svdup_f32(0.0f);
@@ -308,6 +344,15 @@ fvec_L2sqr_ny_sve(float* dis, const float* x, const float* y, size_t d, size_t n
     }
 }
 
+// Writes ip[i] = <x, y_i> for the ny vectors y_i of dimension d stored back to back in y.
+void
+fvec_inner_products_ny_sve(float* ip, const float* x, const float* y, size_t d, size_t ny) {
+    for (size_t i = 0; i < ny; ++i) {
+        ip[i] = fvec_inner_product_sve(x, y, d);
+        y += d;
+    }
+}
+
 }  // namespace faiss
 
 #endif
diff --git a/src/simd/distances_sve.h b/src/simd/distances_sve.h
index 4f9d79c15..a141ca51b 100644
--- a/src/simd/distances_sve.h
+++ b/src/simd/distances_sve.h
@@ -28,6 +28,9 @@ fvec_inner_product_sve(const float* x, const float* y, size_t d);
 float
 fp16_vec_L2sqr_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);
 
+float
+fp16_vec_inner_product_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);
+
 float
 fvec_L1_sve(const float* x, const float* y, size_t d);
 
@@ -60,5 +63,8 @@ fvec_L2sqr_batch_4_sve(const float* x, const float* y0, const float* y1, const f
 void
 fvec_L2sqr_ny_sve(float* dis, const float* x, const float* y, size_t d, size_t ny);
 
+void
+fvec_inner_products_ny_sve(float* ip, const float* x, const float* y, size_t d, size_t ny);
+
 }  // namespace faiss
 #endif
diff --git a/src/simd/hook.cc b/src/simd/hook.cc
index 23922db89..a342cc27c 100644
--- a/src/simd/hook.cc
+++ b/src/simd/hook.cc
@@ -447,13 +447,13 @@ fvec_hook(std::string& simd_type) {
 
         fvec_inner_product = fvec_inner_product_sve;
         fvec_L2sqr_ny = fvec_L2sqr_ny_sve;
-        fvec_inner_products_ny = fvec_inner_products_ny_neon;
+        fvec_inner_products_ny = fvec_inner_products_ny_sve;
 
         ivec_inner_product = ivec_inner_product_neon;
         ivec_L2sqr = ivec_L2sqr_neon;
 
         // fp16
-        fp16_vec_inner_product = fp16_vec_inner_product_neon;
+        fp16_vec_inner_product = fp16_vec_inner_product_sve;
         fp16_vec_L2sqr = fp16_vec_L2sqr_sve;
         fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_sve;
 