diff --git a/Directory.Packages.props b/Directory.Packages.props index 971d5960e27..cd7546526c3 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -32,6 +32,6 @@ - + \ No newline at end of file diff --git a/libs/server/Resp/Vector/DiskANNService.cs b/libs/server/Resp/Vector/DiskANNService.cs index 37c7ace6c31..439d8a50214 100644 --- a/libs/server/Resp/Vector/DiskANNService.cs +++ b/libs/server/Resp/Vector/DiskANNService.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using Garnet.common; using Tsavorite.core; namespace Garnet.server @@ -34,11 +35,15 @@ public nint CreateIndex( delegate* unmanaged[Cdecl] readModifyWriteCallback ) { - // TODO: actually pass distance metric - unsafe { - return NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback); + var index = NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, (int)distanceMetric, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback); + if (index == nint.Zero) + { + throw new GarnetException("Failed to create DiskANN index, native create_index returned null"); + } + + return index; } } @@ -308,6 +313,7 @@ public static partial nint create_index( uint dimensions, uint reduceDims, VectorQuantType quantType, + int metricType, uint buildExplorationFactor, uint numLinks, nint readCallback, diff --git a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs index 43171efc53e..f9ed2e53113 100644 --- a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs +++ b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs @@ -170,7 +170,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel); var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); - var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); Span id = [0, 1, 2, 3]; Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); @@ -365,7 +365,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel); var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); - var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); Span id = [0, 1, 2, 3]; Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); @@ -410,7 +410,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength { NativeDiskANNMethods.drop_index(Context, rawIndex); - rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); } // Search value