Skip to content

Commit 10cdfe4

Browse files
added gpu_bindings.go to enable transferring index to GPU
In addition added a test for the bindings which show how to use the functionality.
1 parent 6c44a77 commit 10cdfe4

2 files changed

Lines changed: 83 additions & 0 deletions

File tree

gpu_bindings.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package faiss
2+
3+
/*
4+
#include <faiss/c_api/gpu/StandardGpuResources_c.h>
5+
#include <faiss/c_api/gpu/GpuAutoTune_c.h>
6+
*/
7+
import "C"
8+
import (
9+
"errors"
10+
)
11+
12+
func TransferToGpu(index Index) (Index, error) {
13+
var gpuResources *C.FaissStandardGpuResources
14+
var gpuIndex *C.FaissGpuIndex
15+
c := C.faiss_StandardGpuResources_new(&gpuResources)
16+
if c != 0 {
17+
return nil, errors.New("error on init gpu %v")
18+
}
19+
20+
exitCode := C.faiss_index_cpu_to_gpu(gpuResources, 0, index.cPtr(), &gpuIndex)
21+
if exitCode != 0 {
22+
return nil, errors.New("error transferring to gpu")
23+
}
24+
25+
return &faissIndex{idx: gpuIndex}, nil
26+
}

gpu_bindings_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package faiss
2+
3+
import (
4+
"fmt"
5+
"github.com/stretchr/testify/require"
6+
"testing"
7+
)
8+
9+
func TestFlatIndexOnGpu(t *testing.T) {
10+
index, err := NewIndexFlatL2(1)
11+
require.Nil(t, err)
12+
13+
gpuIdx, err := TransferToGpu(index)
14+
require.Nil(t, err)
15+
16+
vectorsToAdd := []float32{1,2,3,4,5}
17+
err = gpuIdx.Add(vectorsToAdd)
18+
require.Nil(t, err)
19+
20+
distances, resultIds, err := gpuIdx.Search(vectorsToAdd, 5)
21+
require.Nil(t, err)
22+
23+
fmt.Println(distances, resultIds, err)
24+
for i := range vectorsToAdd {
25+
require.Equal(t, int64(i), resultIds[len(vectorsToAdd)*i])
26+
require.Zero(t, distances[len(vectorsToAdd)*i])
27+
}
28+
}
29+
30+
func TestIndexIDMapOnGPU(t *testing.T) {
31+
index, err := NewIndexFlatL2(1)
32+
require.Nil(t, err)
33+
34+
indexMap, err := NewIndexIDMap(index)
35+
require.Nil(t, err)
36+
37+
gpuIndex, err := TransferToGpu(indexMap)
38+
require.Nil(t, err)
39+
40+
vectorsToAdd := []float32{1,2,3,4,5}
41+
ids := make([]int64, len(vectorsToAdd))
42+
for i := 0; i < len(vectorsToAdd); i++ {
43+
ids[i] = int64(i)
44+
}
45+
46+
err = gpuIndex.AddWithIDs(vectorsToAdd, ids)
47+
require.Nil(t, err)
48+
49+
distances, resultIds, err := gpuIndex.Search(vectorsToAdd, 5)
50+
require.Nil(t, err)
51+
fmt.Println(gpuIndex.D(), gpuIndex.Ntotal())
52+
fmt.Println(distances, resultIds, err)
53+
for i := range vectorsToAdd {
54+
require.Equal(t, ids[i], resultIds[len(vectorsToAdd)*i])
55+
require.Zero(t, distances[len(vectorsToAdd)*i])
56+
}
57+
}

0 commit comments

Comments
 (0)