Skip to content

Commit 97a1ada

Browse files
committed
Optimize size == 2
1 parent b530cb7 commit 97a1ada

1 file changed

Lines changed: 60 additions & 2 deletions

File tree

jvector-native/src/main/c/jvector_simd.c

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,36 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b
262262
JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
263263
int codebookBase = codebookIndex * clusterCount;
264264
float tempdat[16];
265-
if (size == 4) {
265+
if (size == 2) {
266+
int i = 0;
267+
// use a zmm register to calculate 8 partial sums in parallel:
268+
__m128 q_lo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(query + queryOffset)));
269+
__m512 qq = _mm512_broadcast_f32x2(q_lo); // broadcast 2 query floats to all 8 x 64-bit positions
270+
for (; i + 8 <= clusterCount; i += 8) {
271+
// load eight consecutive centroids (16 floats) from the codebook into zmm
272+
__m512 c = _mm512_loadu_ps(codebook + i * size);
273+
__m512 prod = _mm512_mul_ps(c, qq);
274+
// horizontal reduce: sum the two products within each 64-bit centroid slot
275+
// shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c]
276+
__m512 temp = _mm512_shuffle_ps(prod, prod, _MM_SHUFFLE(2, 3, 0, 1));
277+
__m512 sum = _mm512_add_ps(prod, temp);
278+
// results sit at even positions (0,2,4,6,8,10,12,14)
279+
// regular store and load seem to be better than vcompress or vpermutex2var for extracting the results
280+
_mm512_storeu_ps(tempdat, sum);
281+
partialSums[codebookBase + i] = tempdat[0];
282+
partialSums[codebookBase + i + 1] = tempdat[2];
283+
partialSums[codebookBase + i + 2] = tempdat[4];
284+
partialSums[codebookBase + i + 3] = tempdat[6];
285+
partialSums[codebookBase + i + 4] = tempdat[8];
286+
partialSums[codebookBase + i + 5] = tempdat[10];
287+
partialSums[codebookBase + i + 6] = tempdat[12];
288+
partialSums[codebookBase + i + 7] = tempdat[14];
289+
}
290+
for (; i < clusterCount; i++) {
291+
partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size);
292+
}
293+
}
294+
else if (size == 4) {
266295
int i = 0;
267296
// use a zmm register to calculate 4 partial sums in parallel:
268297
__m128 q = _mm_loadu_ps(query + queryOffset);
@@ -339,7 +368,36 @@ JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int cod
339368
JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
340369
int codebookBase = codebookIndex * clusterCount;
341370
float tempdat[16];
342-
if (size == 4) {
371+
if (size == 2) {
372+
int i = 0;
373+
// use a zmm register to calculate 8 partial sums in parallel:
374+
__m128 q_lo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(query + queryOffset)));
375+
__m512 qq = _mm512_broadcast_f32x2(q_lo); // broadcast 2 query floats to all 8 x 64-bit positions
376+
for (; i + 8 <= clusterCount; i += 8) {
377+
// load eight consecutive centroids (16 floats) from the codebook into zmm
378+
__m512 c = _mm512_loadu_ps(codebook + i * size);
379+
__m512 diff = _mm512_sub_ps(c, qq);
380+
__m512 sq = _mm512_mul_ps(diff, diff);
381+
// horizontal reduce: sum the two squared diffs within each 64-bit centroid slot
382+
// shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c]
383+
__m512 temp = _mm512_shuffle_ps(sq, sq, _MM_SHUFFLE(2, 3, 0, 1));
384+
__m512 sum = _mm512_add_ps(sq, temp);
385+
// results sit at even positions (0,2,4,6,8,10,12,14)
386+
_mm512_storeu_ps(tempdat, sum);
387+
partialSums[codebookBase + i] = tempdat[0];
388+
partialSums[codebookBase + i + 1] = tempdat[2];
389+
partialSums[codebookBase + i + 2] = tempdat[4];
390+
partialSums[codebookBase + i + 3] = tempdat[6];
391+
partialSums[codebookBase + i + 4] = tempdat[8];
392+
partialSums[codebookBase + i + 5] = tempdat[10];
393+
partialSums[codebookBase + i + 6] = tempdat[12];
394+
partialSums[codebookBase + i + 7] = tempdat[14];
395+
}
396+
for (; i < clusterCount; i++) {
397+
partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size);
398+
}
399+
}
400+
else if (size == 4) {
343401
int i = 0;
344402
// use a zmm register to calculate 4 partial sums in parallel:
345403
__m128 q = _mm_loadu_ps(query + queryOffset);

0 commit comments

Comments
 (0)