@@ -262,7 +262,36 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b
262262JV_INLINE void calculate_partial_sums_dot_f32_512 (const float * codebook , int codebookIndex , int size , int clusterCount , const float * query , int queryOffset , float * partialSums ) {
263263 int codebookBase = codebookIndex * clusterCount ;
264264 float tempdat [16 ];
265- if (size == 4 ) {
265+ if (size == 2 ) {
266+ int i = 0 ;
267+ // use a zmm register to calculate 8 partial sums in parallel:
268+ __m128 q_lo = _mm_castsi128_ps (_mm_loadl_epi64 ((__m128i * )(query + queryOffset )));
269+ __m512 qq = _mm512_broadcast_f32x2 (q_lo ); // broadcast 2 query floats to all 8 x 64-bit positions
270+ for (; i + 8 <= clusterCount ; i += 8 ) {
271+ // load eight consecutive centroids (16 floats) from the codebook into zmm
272+ __m512 c = _mm512_loadu_ps (codebook + i * size );
273+ __m512 prod = _mm512_mul_ps (c , qq );
274+ // horizontal reduce: sum the two products within each 64-bit centroid slot
275+ // shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c]
276+ __m512 temp = _mm512_shuffle_ps (prod , prod , _MM_SHUFFLE (2 , 3 , 0 , 1 ));
277+ __m512 sum = _mm512_add_ps (prod , temp );
278+ // results sit at even positions (0,2,4,6,8,10,12,14)
279+ // resgular store and load seem to be better tha vcompress or vpermutex2var for extracting the results
280+ _mm512_storeu_ps (tempdat , sum );
281+ partialSums [codebookBase + i ] = tempdat [0 ];
282+ partialSums [codebookBase + i + 1 ] = tempdat [2 ];
283+ partialSums [codebookBase + i + 2 ] = tempdat [4 ];
284+ partialSums [codebookBase + i + 3 ] = tempdat [6 ];
285+ partialSums [codebookBase + i + 4 ] = tempdat [8 ];
286+ partialSums [codebookBase + i + 5 ] = tempdat [10 ];
287+ partialSums [codebookBase + i + 6 ] = tempdat [12 ];
288+ partialSums [codebookBase + i + 7 ] = tempdat [14 ];
289+ }
290+ for (; i < clusterCount ; i ++ ) {
291+ partialSums [codebookBase + i ] = dot_product_f32 (codebook , i * size , query , queryOffset , size );
292+ }
293+ }
294+ else if (size == 4 ) {
266295 int i = 0 ;
267296 // use a zmm register to calculate 4 partial sums in parallel:
268297 __m128 q = _mm_loadu_ps (query + queryOffset );
@@ -339,7 +368,36 @@ JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int cod
339368JV_INLINE void calculate_partial_sums_euclidean_f32_512 (const float * codebook , int codebookIndex , int size , int clusterCount , const float * query , int queryOffset , float * partialSums ) {
340369 int codebookBase = codebookIndex * clusterCount ;
341370 float tempdat [16 ];
342- if (size == 4 ) {
371+ if (size == 2 ) {
372+ int i = 0 ;
373+ // use a zmm register to calculate 8 partial sums in parallel:
374+ __m128 q_lo = _mm_castsi128_ps (_mm_loadl_epi64 ((__m128i * )(query + queryOffset )));
375+ __m512 qq = _mm512_broadcast_f32x2 (q_lo ); // broadcast 2 query floats to all 8 x 64-bit positions
376+ for (; i + 8 <= clusterCount ; i += 8 ) {
377+ // load eight consecutive centroids (16 floats) from the codebook into zmm
378+ __m512 c = _mm512_loadu_ps (codebook + i * size );
379+ __m512 diff = _mm512_sub_ps (c , qq );
380+ __m512 sq = _mm512_mul_ps (diff , diff );
381+ // horizontal reduce: sum the two squared diffs within each 64-bit centroid slot
382+ // shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c]
383+ __m512 temp = _mm512_shuffle_ps (sq , sq , _MM_SHUFFLE (2 , 3 , 0 , 1 ));
384+ __m512 sum = _mm512_add_ps (sq , temp );
385+ // results sit at even positions (0,2,4,6,8,10,12,14)
386+ _mm512_storeu_ps (tempdat , sum );
387+ partialSums [codebookBase + i ] = tempdat [0 ];
388+ partialSums [codebookBase + i + 1 ] = tempdat [2 ];
389+ partialSums [codebookBase + i + 2 ] = tempdat [4 ];
390+ partialSums [codebookBase + i + 3 ] = tempdat [6 ];
391+ partialSums [codebookBase + i + 4 ] = tempdat [8 ];
392+ partialSums [codebookBase + i + 5 ] = tempdat [10 ];
393+ partialSums [codebookBase + i + 6 ] = tempdat [12 ];
394+ partialSums [codebookBase + i + 7 ] = tempdat [14 ];
395+ }
396+ for (; i < clusterCount ; i ++ ) {
397+ partialSums [codebookBase + i ] = euclidean_f32 (codebook , i * size , query , queryOffset , size );
398+ }
399+ }
400+ else if (size == 4 ) {
343401 int i = 0 ;
344402 // use a zmm register to calculate 4 partial sums in parallel:
345403 __m128 q = _mm_loadu_ps (query + queryOffset );
0 commit comments