Skip to content

Commit 01c3950

Browse files
committed
feat(rust): add search_with_filter to CAGRA Index
Add `Index::search_with_filter()` that accepts a bitset filter via DLPack ManagedTensor. The C API `cuvsCagraSearch()` already supports `cuvsFilter` with BITSET type, but the Rust bindings hardcoded `NO_FILTER`. This exposes the existing capability. The bitset is a 1-D uint32 device tensor with ceil(n_rows / 32) elements. Bit = 1 includes the row, bit = 0 excludes it. Filtering happens during graph traversal, not post-retrieval. Includes test: builds a 256-point index, filters to even-indexed rows, verifies all returned neighbors pass the filter.
1 parent c95625b commit 01c3950

1 file changed

Lines changed: 112 additions & 0 deletions

File tree

rust/cuvs/src/cagra/index.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,48 @@ impl Index {
8383
))
8484
}
8585
}
86+
87+
/// Perform a filtered Approximate Nearest Neighbors search on the Index
88+
///
89+
/// Like [`search`](Self::search), but accepts a bitset filter to exclude
90+
/// vectors during graph traversal. Filtered vectors are never visited,
91+
/// giving better recall than post-filtering.
92+
///
93+
/// # Arguments
94+
///
95+
/// * `res` - Resources to use
96+
/// * `params` - Parameters to use in searching the index
97+
/// * `queries` - A matrix in device memory to query for
98+
/// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
99+
/// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
100+
/// * `bitset` - A 1-D `uint32` device tensor with `ceil(n_rows / 32)` elements.
101+
/// Each bit corresponds to a dataset row: bit 1 = include, bit 0 = exclude.
102+
pub fn search_with_filter(
103+
&self,
104+
res: &Resources,
105+
params: &SearchParams,
106+
queries: &ManagedTensor,
107+
neighbors: &ManagedTensor,
108+
distances: &ManagedTensor,
109+
bitset: &ManagedTensor,
110+
) -> Result<()> {
111+
unsafe {
112+
let prefilter = ffi::cuvsFilter {
113+
addr: bitset.as_ptr() as usize,
114+
type_: ffi::cuvsFilterType::BITSET,
115+
};
116+
117+
check_cuvs(ffi::cuvsCagraSearch(
118+
res.0,
119+
params.0,
120+
self.0,
121+
queries.as_ptr(),
122+
neighbors.as_ptr(),
123+
distances.as_ptr(),
124+
prefilter,
125+
))
126+
}
127+
}
86128
}
87129

88130
impl Drop for Index {
@@ -168,6 +210,76 @@ mod tests {
168210
test_cagra(build_params);
169211
}
170212

213+
/// Test bitset-filtered search: exclude odd-indexed rows, verify they don't appear.
214+
#[test]
215+
fn test_cagra_search_with_filter() {
216+
let res = Resources::new().unwrap();
217+
let build_params = IndexParams::new().unwrap();
218+
219+
let n_datapoints = 256;
220+
let n_features = 16;
221+
let dataset =
222+
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
223+
224+
let index =
225+
Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");
226+
227+
// Build a bitset that includes only even-indexed rows
228+
let n_words = (n_datapoints + 31) / 32;
229+
let mut bitset_host = ndarray::Array::<u32, _>::zeros(ndarray::Ix1(n_words));
230+
for i in 0..n_datapoints {
231+
if i % 2 == 0 {
232+
bitset_host[i / 32] |= 1u32 << (i % 32);
233+
}
234+
}
235+
let bitset = ManagedTensor::from(&bitset_host).to_device(&res).unwrap();
236+
237+
// Query with the first 4 even-indexed rows
238+
let n_queries = 4;
239+
let queries = dataset.slice(s![0..n_queries * 2;2, ..]); // rows 0, 2, 4, 6
240+
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
241+
242+
let k = 10;
243+
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
244+
let neighbors = ManagedTensor::from(&neighbors_host)
245+
.to_device(&res)
246+
.unwrap();
247+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
248+
let distances = ManagedTensor::from(&distances_host)
249+
.to_device(&res)
250+
.unwrap();
251+
252+
let search_params = SearchParams::new().unwrap();
253+
254+
index
255+
.search_with_filter(
256+
&res,
257+
&search_params,
258+
&queries,
259+
&neighbors,
260+
&distances,
261+
&bitset,
262+
)
263+
.unwrap();
264+
265+
neighbors.to_host(&res, &mut neighbors_host).unwrap();
266+
267+
// All returned neighbors must be even-indexed (odd rows are filtered out).
268+
for q in 0..n_queries {
269+
for n in 0..k {
270+
let neighbor_id = neighbors_host[[q, n]];
271+
assert_eq!(
272+
neighbor_id % 2,
273+
0,
274+
"query {q}, neighbor {n}: got odd index {neighbor_id}, expected only even"
275+
);
276+
}
277+
}
278+
279+
// First query (row 0) should find itself as the nearest neighbor.
280+
assert_eq!(neighbors_host[[0, 0]], 0);
281+
}
282+
171283
/// Test that an index can be searched multiple times without rebuilding.
172284
/// This validates that search() takes &self instead of self.
173285
#[test]

0 commit comments

Comments
 (0)