Skip to content

Commit 0276bf1

Browse files
committed
Remove calculate_partial_sums_f32_512 dispatcher, expose variants directly
1 parent 97a1ada commit 0276bf1

4 files changed

Lines changed: 65 additions & 551 deletions

File tree

jvector-native/src/main/c/jvector_simd.c

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b
259259
: euclidean_f32_256(a, aoffset, b, boffset, length);
260260
}
261261

262-
JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
262+
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
263263
int codebookBase = codebookIndex * clusterCount;
264264
float tempdat[16];
265265
if (size == 2) {
@@ -365,7 +365,7 @@ JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int cod
365365
}
366366
}
367367

368-
JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
368+
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
369369
int codebookBase = codebookIndex * clusterCount;
370370
float tempdat[16];
371371
if (size == 2) {
@@ -728,15 +728,4 @@ float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int
728728
return sumResult / sqrtf(aMagnitudeResult * bMagnitude);
729729
}
730730

731-
void calculate_partial_sums_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, int similarityFunction, float* partialSums) {
732-
switch (similarityFunction) {
733-
case 0:
734-
calculate_partial_sums_euclidean_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums);
735-
break;
736-
case 1:
737-
calculate_partial_sums_dot_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums);
738-
break;
739-
default:
740-
break;
741-
}
742-
}
731+

jvector-native/src/main/c/jvector_simd.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@ bool check_avx512_compatibility(void);
2727
// APIs exposed to Java via FFI
2828
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength);
2929
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
30-
void calculate_partial_sums_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, int similarityFunction, float* partialSums);
30+
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
31+
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
3132
#endif

jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,10 @@ public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int
110110
var nativeCodebook = ((MemorySegmentVectorFloat) codebook).get();
111111
var nativeQuery = ((MemorySegmentVectorFloat) query).get();
112112
var nativePartialSums = ((MemorySegmentVectorFloat) partialSums).get();
113-
int similarityFunction = switch (vsf) {
114-
case EUCLIDEAN -> 0;
115-
case DOT_PRODUCT -> 1;
113+
switch (vsf) {
114+
case EUCLIDEAN -> NativeSimdOps.calculate_partial_sums_euclidean_f32_512(nativeCodebook, codebookIndex, size, clusterCount, nativeQuery, queryOffset, nativePartialSums);
115+
case DOT_PRODUCT -> NativeSimdOps.calculate_partial_sums_dot_f32_512(nativeCodebook, codebookIndex, size, clusterCount, nativeQuery, queryOffset, nativePartialSums);
116116
default -> throw new UnsupportedOperationException("Unsupported similarity function " + vsf);
117-
};
118-
NativeSimdOps.calculate_partial_sums_f32_512(nativeCodebook, codebookIndex, size, clusterCount, nativeQuery, queryOffset, similarityFunction, nativePartialSums);
117+
}
119118
}
120119
}

0 commit comments

Comments
 (0)