diff --git a/Directory.Packages.props b/Directory.Packages.props
index 971d5960e27..cd7546526c3 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -32,6 +32,6 @@
-
+
\ No newline at end of file
diff --git a/libs/server/Resp/Vector/DiskANNService.cs b/libs/server/Resp/Vector/DiskANNService.cs
index 37c7ace6c31..439d8a50214 100644
--- a/libs/server/Resp/Vector/DiskANNService.cs
+++ b/libs/server/Resp/Vector/DiskANNService.cs
@@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
+using Garnet.common;
using Tsavorite.core;
namespace Garnet.server
@@ -34,11 +35,15 @@ public nint CreateIndex(
delegate* unmanaged[Cdecl] readModifyWriteCallback
)
{
- // TODO: actually pass distance metric
-
unsafe
{
- return NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback);
+ var index = NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, (int)distanceMetric, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback);
+ if (index == nint.Zero)
+ {
+ throw new GarnetException("Failed to create DiskANN index, native create_index returned null");
+ }
+
+ return index;
}
}
@@ -308,6 +313,7 @@ public static partial nint create_index(
uint dimensions,
uint reduceDims,
VectorQuantType quantType,
+ int metricType,
uint buildExplorationFactor,
uint numLinks,
nint readCallback,
diff --git a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs
index 43171efc53e..f9ed2e53113 100644
--- a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs
+++ b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs
@@ -170,7 +170,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength
var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel);
var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel);
- var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
+ var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
Span id = [0, 1, 2, 3];
Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray();
@@ -365,7 +365,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength
var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel);
var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel);
- var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
+ var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
Span id = [0, 1, 2, 3];
Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray();
@@ -410,7 +410,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength
{
NativeDiskANNMethods.drop_index(Context, rawIndex);
- rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
+ rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
}
// Search value