@@ -97,6 +97,48 @@ impl Index {
9797 }
9898 }
9999
100+ /// Perform a filtered Approximate Nearest Neighbors search on the Index
101+ ///
102+ /// Like [`search`](Self::search), but accepts a bitset filter to exclude
103+ /// vectors during graph traversal. Filtered vectors are never visited,
104+ /// giving better recall than post-filtering.
105+ ///
106+ /// # Arguments
107+ ///
108+ /// * `res` - Resources to use
109+ /// * `params` - Parameters to use in searching the index
110+ /// * `queries` - A matrix in device memory to query for
111+ /// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
112+ /// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
113+ /// * `bitset` - A 1-D `uint32` device tensor with `ceil(n_rows / 32)` elements.
114+ /// Each bit corresponds to a dataset row: bit 1 = include, bit 0 = exclude.
115+ pub fn search_with_filter (
116+ & self ,
117+ res : & Resources ,
118+ params : & SearchParams ,
119+ queries : & ManagedTensor ,
120+ neighbors : & ManagedTensor ,
121+ distances : & ManagedTensor ,
122+ bitset : & ManagedTensor ,
123+ ) -> Result < ( ) > {
124+ unsafe {
125+ let prefilter = ffi:: cuvsFilter {
126+ addr : bitset. as_ptr ( ) as usize ,
127+ type_ : ffi:: cuvsFilterType:: BITSET ,
128+ } ;
129+
130+ check_cuvs ( ffi:: cuvsCagraSearch (
131+ res. 0 ,
132+ params. 0 ,
133+ self . 0 ,
134+ queries. as_ptr ( ) ,
135+ neighbors. as_ptr ( ) ,
136+ distances. as_ptr ( ) ,
137+ prefilter,
138+ ) )
139+ }
140+ }
141+
100142 /// Save the CAGRA index to file.
101143 ///
102144 /// Experimental, both the API and the serialization format are subject to change.
@@ -254,6 +296,76 @@ mod tests {
254296 test_cagra ( build_params) ;
255297 }
256298
/// Test bitset-filtered search: exclude odd-indexed rows, verify they don't appear.
///
/// Builds an index over a random dataset, marks only the even-indexed rows as
/// allowed in the bitset, queries with a few even-indexed rows, and checks
/// that every returned neighbor id is even.
#[test]
fn test_cagra_search_with_filter() {
    let res = Resources::new().unwrap();
    let build_params = IndexParams::new().unwrap();

    let n_datapoints = 256;
    let n_features = 16;
    let dataset =
        ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

    let index =
        Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");

    // Build a bitset that includes only even-indexed rows: one bit per row,
    // packed 32 rows per `u32` word (bit `i % 32` of word `i / 32`).
    let n_words = (n_datapoints + 31) / 32;
    let mut bitset_host = ndarray::Array::<u32, _>::zeros(ndarray::Ix1(n_words));
    for i in (0..n_datapoints).step_by(2) {
        bitset_host[i / 32] |= 1u32 << (i % 32);
    }
    let bitset = ManagedTensor::from(&bitset_host).to_device(&res).unwrap();

    // Query with the first 4 even-indexed rows
    let n_queries = 4;
    let queries = dataset.slice(s![0..n_queries * 2; 2, ..]); // rows 0, 2, 4, 6
    let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();

    let k = 10;
    let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
    let neighbors = ManagedTensor::from(&neighbors_host)
        .to_device(&res)
        .unwrap();
    // The distances are never copied back to the host in this test, so the
    // host array only serves as a shape/dtype template and need not be `mut`.
    let distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
    let distances = ManagedTensor::from(&distances_host)
        .to_device(&res)
        .unwrap();

    let search_params = SearchParams::new().unwrap();

    index
        .search_with_filter(
            &res,
            &search_params,
            &queries,
            &neighbors,
            &distances,
            &bitset,
        )
        .unwrap();

    neighbors.to_host(&res, &mut neighbors_host).unwrap();

    // All returned neighbors must be even-indexed (odd rows are filtered out).
    for ((q, n), &neighbor_id) in neighbors_host.indexed_iter() {
        assert_eq!(
            neighbor_id % 2,
            0,
            "query {q}, neighbor {n}: got odd index {neighbor_id}, expected only even"
        );
    }

    // First query (row 0) should find itself as the nearest neighbor.
    assert_eq!(neighbors_host[[0, 0]], 0);
}
368+
257369 /// Test that an index can be searched multiple times without rebuilding.
258370 /// This validates that `search()` takes `&self` instead of `self`.
259371 #[ test]
0 commit comments