Skip to content

Commit 73e5334

Browse files
gpu_bindings.go: added TransferToCpu method
This is potentially necessary for various things, but specifically is critical for us NOW because RemoveIDs isn't implemented for GPUIndexs, so until we implement that functionality in faiss, the only solution is to transfer the index back and forth to the CPU and remove ids which the index is on CPU.
1 parent 10cdfe4 commit 73e5334

2 files changed

Lines changed: 71 additions & 1 deletion

File tree

gpu_bindings.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,15 @@ func TransferToGpu(index Index) (Index, error) {
2424

2525
return &faissIndex{idx: gpuIndex}, nil
2626
}
27+
28+
func TransferToCpu(gpuIndex Index) (Index, error) {
29+
var cpuIndex *C.FaissIndex
30+
31+
exitCode := C.faiss_index_gpu_to_cpu(gpuIndex.cPtr(), &cpuIndex)
32+
if exitCode != 0 {
33+
return nil, errors.New("error transferring to gpu")
34+
}
35+
36+
return &faissIndex{idx: cpuIndex}, nil
37+
}
38+

gpu_bindings_test.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"testing"
77
)
88

9-
func TestFlatIndexOnGpu(t *testing.T) {
9+
func TestFlatIndexOnGpuFunctionality(t *testing.T) {
1010
index, err := NewIndexFlatL2(1)
1111
require.Nil(t, err)
1212

@@ -19,12 +19,22 @@ func TestFlatIndexOnGpu(t *testing.T) {
1919

2020
distances, resultIds, err := gpuIdx.Search(vectorsToAdd, 5)
2121
require.Nil(t, err)
22+
require.Equal(t, int64(len(vectorsToAdd)), gpuIdx.Ntotal())
2223

2324
fmt.Println(distances, resultIds, err)
2425
for i := range vectorsToAdd {
2526
require.Equal(t, int64(i), resultIds[len(vectorsToAdd)*i])
2627
require.Zero(t, distances[len(vectorsToAdd)*i])
2728
}
29+
//This is necessary bc RemoveIDs isn't implemented for GPUIndexs
30+
cpuIdx, err := TransferToCpu(gpuIdx)
31+
require.Nil(t, err)
32+
idsSelector, err := NewIDSelectorBatch([]int64{0})
33+
cpuIdx.RemoveIDs(idsSelector)
34+
gpuIdx, err = TransferToGpu(cpuIdx)
35+
require.Nil(t, err)
36+
require.Equal(t, int64(len(vectorsToAdd)-1), gpuIdx.Ntotal())
37+
2838
}
2939

3040
func TestIndexIDMapOnGPU(t *testing.T) {
@@ -55,3 +65,51 @@ func TestIndexIDMapOnGPU(t *testing.T) {
5565
require.Zero(t, distances[len(vectorsToAdd)*i])
5666
}
5767
}
68+
69+
func TestTransferToGpuAndBack(t *testing.T) {
70+
index, err := NewIndexFlatL2(1)
71+
require.Nil(t, err)
72+
73+
indexMap, err := NewIndexIDMap(index)
74+
require.Nil(t, err)
75+
76+
gpuIndex, err := TransferToGpu(indexMap)
77+
require.Nil(t, err)
78+
79+
vectorsToAdd := []float32{1,2,4,7,11}
80+
ids := make([]int64, len(vectorsToAdd))
81+
for i := 0; i < len(vectorsToAdd); i++ {
82+
ids[i] = int64(i)
83+
}
84+
85+
err = gpuIndex.AddWithIDs(vectorsToAdd, ids)
86+
require.Nil(t, err)
87+
88+
//This is necessary bc RemoveIDs isn't implemented for GPUIndexs
89+
cpuIdx, err := TransferToCpu(gpuIndex)
90+
require.Nil(t, err)
91+
idsSelector, err := NewIDSelectorBatch([]int64{0})
92+
cpuIdx.RemoveIDs(idsSelector)
93+
gpuIndex, err = TransferToGpu(cpuIdx)
94+
require.Nil(t, err)
95+
96+
require.Equal(t, int64(4), gpuIndex.Ntotal())
97+
distances2, resultIds2, err := gpuIndex.Search([]float32{1}, 5)
98+
fmt.Println(distances2, resultIds2, gpuIndex.Ntotal())
99+
require.Nil(t, err)
100+
require.Equal(t, float32(1), distances2[0])
101+
102+
103+
cpuIndex, err := TransferToCpu(gpuIndex)
104+
require.Nil(t, err)
105+
require.Equal(t, int64(4), cpuIndex.Ntotal())
106+
107+
idsSelector, err = NewIDSelectorBatch([]int64{0})
108+
cpuIndex.RemoveIDs(idsSelector)
109+
distances2, resultIds2, err = cpuIndex.Search([]float32{1}, 5)
110+
fmt.Println(distances2, resultIds2, cpuIndex.Ntotal())
111+
require.Nil(t, err)
112+
require.Equal(t, float32(1), distances2[0])
113+
114+
}
115+

0 commit comments

Comments
 (0)