Skip to content

Commit 8453d57

Browse files
author
Raghuveer Devulapalli
committed
Wire calculatePartialSums to native SIMD via Panama FFI downcall
1 parent 0602b8c commit 8453d57

2 files changed

Lines changed: 45 additions & 0 deletions

File tree

jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,17 @@ public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffse
104104
// encoded is a pointer into a PQ chunk - we need to index into it by encodedOffset and provide encodedLength to the native code
105105
return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encodedOffset, encodedLength, clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
106106
}
107+
108+
@Override
109+
public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int size, int clusterCount, VectorFloat<?> query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat<?> partialSums) {
110+
var nativeCodebook = ((MemorySegmentVectorFloat) codebook).get();
111+
var nativeQuery = ((MemorySegmentVectorFloat) query).get();
112+
var nativePartialSums = ((MemorySegmentVectorFloat) partialSums).get();
113+
int similarityFunction = switch (vsf) {
114+
case EUCLIDEAN -> 0;
115+
case DOT_PRODUCT -> 1;
116+
default -> throw new UnsupportedOperationException("Unsupported similarity function " + vsf);
117+
};
118+
NativeSimdOps.calculate_partial_sums_f32_512(nativeCodebook, codebookIndex, size, clusterCount, nativeQuery, queryOffset, similarityFunction, nativePartialSums);
119+
}
107120
}

jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,4 +847,36 @@ public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment c
847847
throw new AssertionError("should not reach here", ex$);
848848
}
849849
}
850+
851+
private static class calculate_partial_sums_f32_512 {
852+
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
853+
NativeSimdOps.C_POINTER,
854+
NativeSimdOps.C_INT,
855+
NativeSimdOps.C_INT,
856+
NativeSimdOps.C_INT,
857+
NativeSimdOps.C_POINTER,
858+
NativeSimdOps.C_INT,
859+
NativeSimdOps.C_INT,
860+
NativeSimdOps.C_POINTER
861+
);
862+
public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_f32_512");
863+
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC, Linker.Option.critical(true));
864+
}
865+
866+
/**
867+
* {@snippet lang=c :
868+
* void calculate_partial_sums_f32_512(const float *codebook, int codebookIndex, int size, int clusterCount, const float *query, int queryOffset, int similarityFunction, float *partialSums)
869+
* }
870+
*/
871+
public static void calculate_partial_sums_f32_512(MemorySegment codebook, int codebookIndex, int size, int clusterCount, MemorySegment query, int queryOffset, int similarityFunction, MemorySegment partialSums) {
872+
var mh$ = calculate_partial_sums_f32_512.HANDLE;
873+
try {
874+
if (TRACE_DOWNCALLS) {
875+
traceDowncall("calculate_partial_sums_f32_512", codebook, codebookIndex, size, clusterCount, query, queryOffset, similarityFunction, partialSums);
876+
}
877+
mh$.invokeExact(codebook, codebookIndex, size, clusterCount, query, queryOffset, similarityFunction, partialSums);
878+
} catch (Throwable ex$) {
879+
throw new AssertionError("should not reach here", ex$);
880+
}
881+
}
850882
}

0 commit comments

Comments
 (0)