Skip to content

Commit 006aa84

Browse files
committed
feat(rust): add search_with_filter to CAGRA Index
Add `Index::search_with_filter()` that accepts a bitset filter via DLPack ManagedTensor. The C API `cuvsCagraSearch()` already supports `cuvsFilter` with BITSET type, but the Rust bindings hardcoded `NO_FILTER`. This exposes the existing capability. The bitset is a 1-D uint32 device tensor with ceil(n_rows / 32) elements. Bit = 1 includes the row, bit = 0 excludes it. Filtering happens during graph traversal, not post-retrieval. Includes test: builds a 256-point index, filters to even-indexed rows, verifies all returned neighbors pass the filter.
1 parent a626f60 commit 006aa84

1 file changed

Lines changed: 112 additions & 0 deletions

File tree

rust/cuvs/src/cagra/index.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,48 @@ impl Index {
9797
}
9898
}
9999

100+
/// Perform a filtered Approximate Nearest Neighbors search on the Index
101+
///
102+
/// Like [`search`](Self::search), but accepts a bitset filter to exclude
103+
/// vectors during graph traversal. Filtered vectors are never visited,
104+
/// giving better recall than post-filtering.
105+
///
106+
/// # Arguments
107+
///
108+
/// * `res` - Resources to use
109+
/// * `params` - Parameters to use in searching the index
110+
/// * `queries` - A matrix in device memory to query for
111+
/// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
112+
/// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
113+
/// * `bitset` - A 1-D `uint32` device tensor with `ceil(n_rows / 32)` elements.
114+
/// Each bit corresponds to a dataset row: bit 1 = include, bit 0 = exclude.
115+
pub fn search_with_filter(
116+
&self,
117+
res: &Resources,
118+
params: &SearchParams,
119+
queries: &ManagedTensor,
120+
neighbors: &ManagedTensor,
121+
distances: &ManagedTensor,
122+
bitset: &ManagedTensor,
123+
) -> Result<()> {
124+
unsafe {
125+
let prefilter = ffi::cuvsFilter {
126+
addr: bitset.as_ptr() as usize,
127+
type_: ffi::cuvsFilterType::BITSET,
128+
};
129+
130+
check_cuvs(ffi::cuvsCagraSearch(
131+
res.0,
132+
params.0,
133+
self.0,
134+
queries.as_ptr(),
135+
neighbors.as_ptr(),
136+
distances.as_ptr(),
137+
prefilter,
138+
))
139+
}
140+
}
141+
100142
/// Save the CAGRA index to file.
101143
///
102144
/// Experimental, both the API and the serialization format are subject to change.
@@ -254,6 +296,76 @@ mod tests {
254296
test_cagra(build_params);
255297
}
256298

299+
/// Test bitset-filtered search: exclude odd-indexed rows, verify they don't appear.
300+
#[test]
301+
fn test_cagra_search_with_filter() {
302+
let res = Resources::new().unwrap();
303+
let build_params = IndexParams::new().unwrap();
304+
305+
let n_datapoints = 256;
306+
let n_features = 16;
307+
let dataset =
308+
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
309+
310+
let index =
311+
Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");
312+
313+
// Build a bitset that includes only even-indexed rows
314+
let n_words = (n_datapoints + 31) / 32;
315+
let mut bitset_host = ndarray::Array::<u32, _>::zeros(ndarray::Ix1(n_words));
316+
for i in 0..n_datapoints {
317+
if i % 2 == 0 {
318+
bitset_host[i / 32] |= 1u32 << (i % 32);
319+
}
320+
}
321+
let bitset = ManagedTensor::from(&bitset_host).to_device(&res).unwrap();
322+
323+
// Query with the first 4 even-indexed rows
324+
let n_queries = 4;
325+
let queries = dataset.slice(s![0..n_queries * 2;2, ..]); // rows 0, 2, 4, 6
326+
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
327+
328+
let k = 10;
329+
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
330+
let neighbors = ManagedTensor::from(&neighbors_host)
331+
.to_device(&res)
332+
.unwrap();
333+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
334+
let distances = ManagedTensor::from(&distances_host)
335+
.to_device(&res)
336+
.unwrap();
337+
338+
let search_params = SearchParams::new().unwrap();
339+
340+
index
341+
.search_with_filter(
342+
&res,
343+
&search_params,
344+
&queries,
345+
&neighbors,
346+
&distances,
347+
&bitset,
348+
)
349+
.unwrap();
350+
351+
neighbors.to_host(&res, &mut neighbors_host).unwrap();
352+
353+
// All returned neighbors must be even-indexed (odd rows are filtered out).
354+
for q in 0..n_queries {
355+
for n in 0..k {
356+
let neighbor_id = neighbors_host[[q, n]];
357+
assert_eq!(
358+
neighbor_id % 2,
359+
0,
360+
"query {q}, neighbor {n}: got odd index {neighbor_id}, expected only even"
361+
);
362+
}
363+
}
364+
365+
// First query (row 0) should find itself as the nearest neighbor.
366+
assert_eq!(neighbors_host[[0, 0]], 0);
367+
}
368+
257369
/// Test that an index can be searched multiple times without rebuilding.
258370
/// This validates that `search()` takes `&self` instead of `self`.
259371
#[test]

0 commit comments

Comments
 (0)