From 07e30f69140929c63642e2ae84e250caab8c1c68 Mon Sep 17 00:00:00 2001 From: "Thatikonda.Varunreddy" Date: Mon, 14 Apr 2025 11:51:58 +0530 Subject: [PATCH] Implement fvec_inner_product_sve for FP32 inner product using SVE Signed-off-by: Thatikonda.Varunreddy --- src/simd/distances_sve.cc | 22 ++++++++++++++++++++++ src/simd/distances_sve.h | 3 +++ src/simd/hook.cc | 2 +- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/simd/distances_sve.cc b/src/simd/distances_sve.cc index 3b774f191..94211f263 100644 --- a/src/simd/distances_sve.cc +++ b/src/simd/distances_sve.cc @@ -40,6 +40,28 @@ fvec_L2sqr_sve(const float* x, const float* y, size_t d) { return svaddv_f32(svptrue_b32(), sum); } +float +fvec_inner_product_sve(const float* x, const float* y, size_t d) { + svfloat32_t sum = svdup_f32(0.0f); + size_t i = 0; + + svbool_t pg = svptrue_b32(); + + while (i < d) { + if (d - i < svcntw()) + pg = svwhilelt_b32(i, d); + + svfloat32_t a = svld1_f32(pg, x + i); + svfloat32_t b = svld1_f32(pg, y + i); + sum = svmla_f32_m(pg, sum, a, b); + i += svcntw(); + } + + float result = svaddv_f32(svptrue_b32(), sum); + + return result; +} + float fp16_vec_L2sqr_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { svfloat32_t sum1 = svdup_f32(0.0f); diff --git a/src/simd/distances_sve.h b/src/simd/distances_sve.h index 99af193a8..4f9d79c15 100644 --- a/src/simd/distances_sve.h +++ b/src/simd/distances_sve.h @@ -22,6 +22,9 @@ namespace faiss { float fvec_L2sqr_sve(const float* x, const float* y, size_t d); +float +fvec_inner_product_sve(const float* x, const float* y, size_t d); + float fp16_vec_L2sqr_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d); diff --git a/src/simd/hook.cc b/src/simd/hook.cc index f16caa366..23922db89 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -445,7 +445,7 @@ fvec_hook(std::string& simd_type) { fvec_madd = fvec_madd_sve; fvec_madd_and_argmin = fvec_madd_and_argmin_sve; - fvec_inner_product = fvec_inner_product_neon; + fvec_inner_product = fvec_inner_product_sve; fvec_L2sqr_ny = fvec_L2sqr_ny_sve; fvec_inner_products_ny = fvec_inner_products_ny_neon;