@@ -83,6 +83,48 @@ impl Index {
8383 ) )
8484 }
8585 }
86+
87+ /// Perform a filtered Approximate Nearest Neighbors search on the Index
88+ ///
89+ /// Like [`search`](Self::search), but accepts a bitset filter to exclude
90+ /// vectors during graph traversal. Filtered vectors are never visited,
91+ /// giving better recall than post-filtering.
92+ ///
93+ /// # Arguments
94+ ///
95+ /// * `res` - Resources to use
96+ /// * `params` - Parameters to use in searching the index
97+ /// * `queries` - A matrix in device memory to query for
98+ /// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
99+ /// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
100+ /// * `bitset` - A 1-D `uint32` device tensor with `ceil(n_rows / 32)` elements.
101+ /// Each bit corresponds to a dataset row: bit 1 = include, bit 0 = exclude.
102+ pub fn search_with_filter (
103+ & self ,
104+ res : & Resources ,
105+ params : & SearchParams ,
106+ queries : & ManagedTensor ,
107+ neighbors : & ManagedTensor ,
108+ distances : & ManagedTensor ,
109+ bitset : & ManagedTensor ,
110+ ) -> Result < ( ) > {
111+ unsafe {
112+ let prefilter = ffi:: cuvsFilter {
113+ addr : bitset. as_ptr ( ) as usize ,
114+ type_ : ffi:: cuvsFilterType:: BITSET ,
115+ } ;
116+
117+ check_cuvs ( ffi:: cuvsCagraSearch (
118+ res. 0 ,
119+ params. 0 ,
120+ self . 0 ,
121+ queries. as_ptr ( ) ,
122+ neighbors. as_ptr ( ) ,
123+ distances. as_ptr ( ) ,
124+ prefilter,
125+ ) )
126+ }
127+ }
86128}
87129
88130impl Drop for Index {
@@ -168,6 +210,76 @@ mod tests {
168210 test_cagra ( build_params) ;
169211 }
170212
213+ /// Test bitset-filtered search: exclude odd-indexed rows, verify they don't appear.
214+ #[ test]
215+ fn test_cagra_search_with_filter ( ) {
216+ let res = Resources :: new ( ) . unwrap ( ) ;
217+ let build_params = IndexParams :: new ( ) . unwrap ( ) ;
218+
219+ let n_datapoints = 256 ;
220+ let n_features = 16 ;
221+ let dataset =
222+ ndarray:: Array :: < f32 , _ > :: random ( ( n_datapoints, n_features) , Uniform :: new ( 0. , 1.0 ) ) ;
223+
224+ let index =
225+ Index :: build ( & res, & build_params, & dataset) . expect ( "failed to create cagra index" ) ;
226+
227+ // Build a bitset that includes only even-indexed rows
228+ let n_words = ( n_datapoints + 31 ) / 32 ;
229+ let mut bitset_host = ndarray:: Array :: < u32 , _ > :: zeros ( ndarray:: Ix1 ( n_words) ) ;
230+ for i in 0 ..n_datapoints {
231+ if i % 2 == 0 {
232+ bitset_host[ i / 32 ] |= 1u32 << ( i % 32 ) ;
233+ }
234+ }
235+ let bitset = ManagedTensor :: from ( & bitset_host) . to_device ( & res) . unwrap ( ) ;
236+
237+ // Query with the first 4 even-indexed rows
238+ let n_queries = 4 ;
239+ let queries = dataset. slice ( s ! [ 0 ..n_queries * 2 ; 2 , ..] ) ; // rows 0, 2, 4, 6
240+ let queries = ManagedTensor :: from ( & queries) . to_device ( & res) . unwrap ( ) ;
241+
242+ let k = 10 ;
243+ let mut neighbors_host = ndarray:: Array :: < u32 , _ > :: zeros ( ( n_queries, k) ) ;
244+ let neighbors = ManagedTensor :: from ( & neighbors_host)
245+ . to_device ( & res)
246+ . unwrap ( ) ;
247+ let mut distances_host = ndarray:: Array :: < f32 , _ > :: zeros ( ( n_queries, k) ) ;
248+ let distances = ManagedTensor :: from ( & distances_host)
249+ . to_device ( & res)
250+ . unwrap ( ) ;
251+
252+ let search_params = SearchParams :: new ( ) . unwrap ( ) ;
253+
254+ index
255+ . search_with_filter (
256+ & res,
257+ & search_params,
258+ & queries,
259+ & neighbors,
260+ & distances,
261+ & bitset,
262+ )
263+ . unwrap ( ) ;
264+
265+ neighbors. to_host ( & res, & mut neighbors_host) . unwrap ( ) ;
266+
267+ // All returned neighbors must be even-indexed (odd rows are filtered out).
268+ for q in 0 ..n_queries {
269+ for n in 0 ..k {
270+ let neighbor_id = neighbors_host[ [ q, n] ] ;
271+ assert_eq ! (
272+ neighbor_id % 2 ,
273+ 0 ,
274+ "query {q}, neighbor {n}: got odd index {neighbor_id}, expected only even"
275+ ) ;
276+ }
277+ }
278+
279+ // First query (row 0) should find itself as the nearest neighbor.
280+ assert_eq ! ( neighbors_host[ [ 0 , 0 ] ] , 0 ) ;
281+ }
282+
171283 /// Test that an index can be searched multiple times without rebuilding.
172284 /// This validates that search() takes &self instead of self.
173285 #[ test]
0 commit comments