@@ -96,6 +96,37 @@ fp16_vec_L2sqr_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
9696 return svaddv_f32 (pg_32, total_sum);
9797}
9898
99+ float
100+ fp16_vec_inner_product_sve (const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
101+ svfloat32_t sum1 = svdup_f32 (0 .0f );
102+ svfloat32_t sum2 = svdup_f32 (0 .0f );
103+ size_t i = 0 ;
104+
105+ svbool_t pg_16 = svptrue_b16 ();
106+ svbool_t pg_32 = svptrue_b32 ();
107+
108+ while (i < d) {
109+ if (d - i < svcnth ())
110+ pg_16 = svwhilelt_b16 (i, d);
111+
112+ svfloat16_t a_fp16 = svld1_f16 (pg_16, reinterpret_cast <const __fp16*>(x + i));
113+ svfloat16_t b_fp16 = svld1_f16 (pg_16, reinterpret_cast <const __fp16*>(y + i));
114+
115+ svfloat32_t a_fp32_low = svcvt_f32_f16_z (pg_32, svtrn1_f16 (a_fp16, a_fp16));
116+ svfloat32_t a_fp32_high = svcvt_f32_f16_z (pg_32, svtrn2_f16 (a_fp16, a_fp16));
117+ svfloat32_t b_fp32_low = svcvt_f32_f16_z (pg_32, svtrn1_f16 (b_fp16, b_fp16));
118+ svfloat32_t b_fp32_high = svcvt_f32_f16_z (pg_32, svtrn2_f16 (b_fp16, b_fp16));
119+
120+ sum1 = svmla_f32_m (pg_32, sum1, a_fp32_low, b_fp32_low);
121+ sum2 = svmla_f32_m (pg_32, sum2, a_fp32_high, b_fp32_high);
122+
123+ i += svcnth ();
124+ }
125+
126+ svfloat32_t total_sum = svadd_f32_m (pg_32, sum1, sum2);
127+ return svaddv_f32 (pg_32, total_sum);
128+ }
129+
99130float
100131fvec_L1_sve (const float * x, const float * y, size_t d) {
101132 svfloat32_t sum = svdup_f32 (0 .0f );
@@ -308,6 +339,14 @@ fvec_L2sqr_ny_sve(float* dis, const float* x, const float* y, size_t d, size_t n
308339 }
309340}
310341
342+ void
343+ fvec_inner_products_ny_sve (float * ip, const float * x, const float * y, size_t d, size_t ny) {
344+ for (size_t i = 0 ; i < ny; ++i) {
345+ ip[i] = fvec_inner_product_sve (x, y, d);
346+ y += d;
347+ }
348+ }
349+
311350} // namespace faiss
312351
313352#endif
0 commit comments