66 "testing"
77)
88
9- func TestFlatIndexOnGpu (t * testing.T ) {
9+ func TestFlatIndexOnGpuFunctionality (t * testing.T ) {
1010 index , err := NewIndexFlatL2 (1 )
1111 require .Nil (t , err )
1212
@@ -19,12 +19,22 @@ func TestFlatIndexOnGpu(t *testing.T) {
1919
2020 distances , resultIds , err := gpuIdx .Search (vectorsToAdd , 5 )
2121 require .Nil (t , err )
22+ require .Equal (t , int64 (len (vectorsToAdd )), gpuIdx .Ntotal ())
2223
2324 fmt .Println (distances , resultIds , err )
2425 for i := range vectorsToAdd {
2526 require .Equal (t , int64 (i ), resultIds [len (vectorsToAdd )* i ])
2627 require .Zero (t , distances [len (vectorsToAdd )* i ])
2728 }
29+ //This is necessary bc RemoveIDs isn't implemented for GPUIndexs
30+ cpuIdx , err := TransferToCpu (gpuIdx )
31+ require .Nil (t , err )
32+ idsSelector , err := NewIDSelectorBatch ([]int64 {0 })
33+ cpuIdx .RemoveIDs (idsSelector )
34+ gpuIdx , err = TransferToGpu (cpuIdx )
35+ require .Nil (t , err )
36+ require .Equal (t , int64 (len (vectorsToAdd )- 1 ), gpuIdx .Ntotal ())
37+
2838}
2939
3040func TestIndexIDMapOnGPU (t * testing.T ) {
@@ -55,3 +65,51 @@ func TestIndexIDMapOnGPU(t *testing.T) {
5565 require .Zero (t , distances [len (vectorsToAdd )* i ])
5666 }
5767}
68+
69+ func TestTransferToGpuAndBack (t * testing.T ) {
70+ index , err := NewIndexFlatL2 (1 )
71+ require .Nil (t , err )
72+
73+ indexMap , err := NewIndexIDMap (index )
74+ require .Nil (t , err )
75+
76+ gpuIndex , err := TransferToGpu (indexMap )
77+ require .Nil (t , err )
78+
79+ vectorsToAdd := []float32 {1 ,2 ,4 ,7 ,11 }
80+ ids := make ([]int64 , len (vectorsToAdd ))
81+ for i := 0 ; i < len (vectorsToAdd ); i ++ {
82+ ids [i ] = int64 (i )
83+ }
84+
85+ err = gpuIndex .AddWithIDs (vectorsToAdd , ids )
86+ require .Nil (t , err )
87+
88+ //This is necessary bc RemoveIDs isn't implemented for GPUIndexs
89+ cpuIdx , err := TransferToCpu (gpuIndex )
90+ require .Nil (t , err )
91+ idsSelector , err := NewIDSelectorBatch ([]int64 {0 })
92+ cpuIdx .RemoveIDs (idsSelector )
93+ gpuIndex , err = TransferToGpu (cpuIdx )
94+ require .Nil (t , err )
95+
96+ require .Equal (t , int64 (4 ), gpuIndex .Ntotal ())
97+ distances2 , resultIds2 , err := gpuIndex .Search ([]float32 {1 }, 5 )
98+ fmt .Println (distances2 , resultIds2 , gpuIndex .Ntotal ())
99+ require .Nil (t , err )
100+ require .Equal (t , float32 (1 ), distances2 [0 ])
101+
102+
103+ cpuIndex , err := TransferToCpu (gpuIndex )
104+ require .Nil (t , err )
105+ require .Equal (t , int64 (4 ), cpuIndex .Ntotal ())
106+
107+ idsSelector , err = NewIDSelectorBatch ([]int64 {0 })
108+ cpuIndex .RemoveIDs (idsSelector )
109+ distances2 , resultIds2 , err = cpuIndex .Search ([]float32 {1 }, 5 )
110+ fmt .Println (distances2 , resultIds2 , cpuIndex .Ntotal ())
111+ require .Nil (t , err )
112+ require .Equal (t , float32 (1 ), distances2 [0 ])
113+
114+ }
115+
0 commit comments