From 5e7ba0f14bf4bbf4484fe285415c6699847d318b Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Mon, 27 Oct 2025 17:46:36 +0100 Subject: [PATCH 1/7] DRAFT: error handling, safe Rust methods --- rust/lib.rs | 431 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 371 insertions(+), 60 deletions(-) diff --git a/rust/lib.rs b/rust/lib.rs index 3b97b830d..e16f9e7f7 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -527,6 +527,7 @@ pub enum MetricFunction { /// refer to the individual method documentation. pub struct Index { inner: cxx::UniquePtr, + scalar_kind: ScalarKind, metric_fn: Option, } @@ -585,10 +586,65 @@ impl Clone for ffi::IndexOptions { } } +/// Data types are cast on the C++ side, but only some conversions are valid. +/// TODO: think about this more. I guess you can cast int8 to double, +/// but it's more likely an error. Maybe all casts except float to smaller float should be errors? +fn is_kind_convertible_to(a: ScalarKind, b: ScalarKind) -> bool { + match a { + ScalarKind::F16 | ScalarKind::F32 | ScalarKind::F64 | ScalarKind::BF16 => [ + ScalarKind::F16, + ScalarKind::F32, + ScalarKind::F64, + ScalarKind::BF16, + ] + .contains(&b), + ScalarKind::B1 | ScalarKind::I8 => a == b, + ScalarKind::Unknown => false, + ScalarKind { repr: _ } => unreachable!("Invalid Enum representation"), + } +} + +#[non_exhaustive] +pub enum IndexOperationError { + TypeError(ScalarKind, ScalarKind), + CXXException(cxx::Exception), +} + +impl std::error::Error for IndexOperationError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + IndexOperationError::CXXException(exception) => Some(exception), + _ => None, + } + } +} + +impl std::fmt::Display for IndexOperationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IndexOperationError::TypeError(twant, tgot) => write!( + f, + "Type Error: Attempted to use an Index storing type {twant:?} with type {tgot:?}.", + ), + IndexOperationError::CXXException(exception) => { + write!(f, "C++ Exception: {}", exception) + } + } + } +} + +impl std::fmt::Debug for IndexOperationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} + /// The `VectorType` trait defines operations for managing and querying vectors /// in an index. It supports generic operations on vectors of different types, /// allowing for the addition, retrieval, and search of vectors within an index. pub trait VectorType { + const KIND: ScalarKind; + /// Adds a vector to the index under the specified key. /// /// # Parameters @@ -598,8 +654,33 @@ pub trait VectorType { /// /// # Returns /// - `Ok(())` if the vector was successfully added to the index. - /// - `Err(cxx::Exception)` if an error occurred during the operation. - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> + /// - `Err(IndexOperationError)` if an error occurred during the operation. + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), IndexOperationError> + where + Self: Sized, + { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are compatible + // Check that `vector.len()` matches `dimensionality` happens on the C++ side. + unsafe { + Self::add_unchecked(index, key, vector).map_err(IndexOperationError::CXXException) + } + } + + /// Adds a vector to the index under the specified key. + /// Refer to [VectorType::add] for usage. + /// + /// # Safety + /// + /// - The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `vector`. + unsafe fn add_unchecked(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> where Self: Sized; @@ -611,10 +692,39 @@ pub trait VectorType { /// - `buffer`: A mutable slice where the retrieved vector will be stored. The size of the /// buffer determines the maximum number of elements that can be retrieved. /// - /// # Returns + /// # Retuns /// - `Ok(usize)` indicating the number of elements actually written into the `buffer`. - /// - `Err(cxx::Exception)` if an error occurred during the operation. - fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result + /// - `Err(IndexOperationError)` if an error occurred during the operation. + fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result + where + Self: Sized, + { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + // Check that `buffer` is large enough happens on the C++ side. + unsafe { + Self::get_unchecked(index, key, buffer).map_err(IndexOperationError::CXXException) + } + } + + /// Retrieves a vector from the index by its key. + /// Refer to [Index::get] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `buffer`. + unsafe fn get_unchecked( + index: &Index, + key: Key, + buffer: &mut [Self], + ) -> Result where Self: Sized; @@ -628,8 +738,41 @@ pub trait VectorType { /// /// # Returns /// - `Ok(ffi::Matches)` containing the matches found. - /// - `Err(cxx::Exception)` if an error occurred during the search operation. - fn search(index: &Index, query: &[Self], count: usize) -> Result + /// - `Err(IndexOperationError)` if an error occurred during the search operation. + fn search( + index: &Index, + query: &[Self], + count: usize, + ) -> Result + where + Self: Sized, + { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + unsafe { + Self::search_unchecked(index, query, count).map_err(IndexOperationError::CXXException) + } + } + + /// Performs a search in the index using the given query vector, returning + /// up to `count` closest matches. + /// Refer to [Index::search] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `query`. + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result where Self: Sized; @@ -644,11 +787,40 @@ pub trait VectorType { /// /// # Returns /// - `Ok(ffi::Matches)` containing the matches found. - /// - `Err(cxx::Exception)` if an error occurred during the search operation. + /// - `Err(IndexOperationError)` if an error occurred during the search operation. fn exact_search( index: &Index, query: &[Self], count: usize, + ) -> Result + where + Self: Sized, + { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + unsafe { + Self::exact_search_unchecked(index, query, count) + .map_err(IndexOperationError::CXXException) + } + } + + /// Performs an exact (brute force) search in the index using the given query vector. + /// Refer to [Index::exact_search] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `query`. + unsafe fn exact_search_unchecked( + index: &Index, + query: &[Self], + count: usize, ) -> Result where Self: Sized; @@ -671,6 +843,38 @@ pub trait VectorType { query: &[Self], count: usize, filter: F, + ) -> Result + where + Self: Sized, + F: Fn(Key) -> bool, + { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + unsafe { + Self::filtered_search_unchecked(index, query, count, filter) + .map_err(IndexOperationError::CXXException) + } + } + + /// Performs a filtered search in the index using a query vector and a custom + /// filter function. + /// Refer to [Index::filtered_search] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `query`. + unsafe fn filtered_search_unchecked( + index: &Index, + query: &[Self], + count: usize, + filter: F, ) -> Result where Self: Sized, @@ -685,21 +889,55 @@ pub trait VectorType { /// /// # Returns /// - `Ok(())` if the metric was successfully changed. - /// - `Err(cxx::Exception)` if an error occurred during the operation. + /// - `Err(IndexOperationError)` if an error occurred during the operation. fn change_metric( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, + ) -> Result<(), IndexOperationError> + where + Self: Sized, + { + // TODO: same question as higher up. what kind of casts are allowed and sensible here? + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and metric types are the same. + unsafe { + Self::change_metric_unchecked(index, metric).map_err(IndexOperationError::CXXException) + } + } + + /// Changes the metric used for distance calculations within the index. + /// Refer to [Index::change_metric] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as the arguments to `metric`. + unsafe fn change_metric_unchecked( + index: &mut Index, + metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> where Self: Sized; } impl VectorType for f32 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { + const KIND: ScalarKind = ScalarKind::F32; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_f32(query, count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -707,15 +945,23 @@ impl VectorType for f32 { index.inner.exact_search_f32(query, count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_f32(key, vector) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_f32(key, vector) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -739,7 +985,7 @@ impl VectorType for f32 { .filtered_search_f32(query, count, trampoline_fn, closure_address) } - fn change_metric( + unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> { @@ -770,11 +1016,17 @@ impl VectorType for f32 { } impl VectorType for i8 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { + const KIND: ScalarKind = ScalarKind::I8; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_i8(query, count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -782,15 +1034,23 @@ impl VectorType for i8 { index.inner.exact_search_i8(query, count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_i8(key, vector) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_i8(key, vector) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -813,7 +1073,7 @@ impl VectorType for i8 { .inner .filtered_search_i8(query, count, trampoline_fn, closure_address) } - fn change_metric( + unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> { @@ -844,11 +1104,17 @@ impl VectorType for i8 { } impl VectorType for f64 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { + const KIND: ScalarKind = ScalarKind::F64; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_f64(query, count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -856,15 +1122,23 @@ impl VectorType for f64 { index.inner.exact_search_f64(query, count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_f64(key, vector) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_f64(key, vector) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -887,7 +1161,7 @@ impl VectorType for f64 { .inner .filtered_search_f64(query, count, trampoline_fn, closure_address) } - fn change_metric( + unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> { @@ -918,11 +1192,17 @@ impl VectorType for f64 { } impl VectorType for f16 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { + const KIND: ScalarKind = ScalarKind::F16; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_f16(f16::to_i16s(query), count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -930,15 +1210,23 @@ impl VectorType for f16 { index.inner.exact_search_f16(f16::to_i16s(query), count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_f16(key, f16::to_mut_i16s(vector)) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_f16(key, f16::to_i16s(vector)) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -957,15 +1245,12 @@ impl VectorType for f16 { // Temporarily cast the closure to a raw pointer for passing. let trampoline_fn: usize = trampoline:: as *const () as usize; let closure_address: usize = &filter as *const F as usize; - index.inner.filtered_search_f16( - f16::to_i16s(query), - count, - trampoline_fn, - closure_address, - ) + index + .inner + .filtered_search_f16(f16::to_i16s(query), count, trampoline_fn, closure_address) } - fn change_metric( + unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> { @@ -996,11 +1281,17 @@ impl VectorType for f16 { } impl VectorType for b1x8 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { + const KIND: ScalarKind = ScalarKind::B1; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_b1x8(b1x8::to_u8s(query), count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -1008,15 +1299,23 @@ impl VectorType for b1x8 { index.inner.exact_search_b1x8(b1x8::to_u8s(query), count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_b1x8(key, b1x8::to_mut_u8s(vector)) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_b1x8(key, b1x8::to_u8s(vector)) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -1035,15 +1334,12 @@ impl VectorType for b1x8 { // Temporarily cast the closure to a raw pointer for passing. let trampoline_fn: usize = trampoline:: as *const () as usize; let closure_address: usize = &filter as *const F as usize; - index.inner.filtered_search_b1x8( - b1x8::to_u8s(query), - count, - trampoline_fn, - closure_address, - ) + index + .inner + .filtered_search_b1x8(b1x8::to_u8s(query), count, trampoline_fn, closure_address) } - fn change_metric( + unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> { @@ -1078,6 +1374,7 @@ impl Index { match ffi::new_native_index(options) { Ok(inner) => Result::Ok(Self { inner, + scalar_kind: options.quantization, metric_fn: None, }), Err(err) => Err(err), @@ -1113,8 +1410,8 @@ impl Index { pub fn change_metric( self: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, - ) { - T::change_metric(self, metric).unwrap(); + ) -> Result<(), IndexOperationError> { + T::change_metric(self, metric) } /// Retrieves the hardware acceleration information. @@ -1140,7 +1437,7 @@ impl Index { self: &Index, query: &[T], count: usize, - ) -> Result { + ) -> Result { T::search(self, query, count) } @@ -1160,7 +1457,7 @@ impl Index { self: &Index, query: &[T], count: usize, - ) -> Result { + ) -> Result { T::exact_search(self, query, count) } @@ -1181,7 +1478,7 @@ impl Index { query: &[T], count: usize, filter: F, - ) -> Result + ) -> Result where F: Fn(Key) -> bool, { @@ -1194,7 +1491,11 @@ impl Index { /// /// * `key` - The key associated with the vector. /// * `vector` - A slice containing the vector data. - pub fn add(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> { + pub fn add( + self: &Index, + key: Key, + vector: &[T], + ) -> Result<(), IndexOperationError> { T::add(self, key, vector) } @@ -1213,7 +1514,7 @@ impl Index { self: &Index, key: Key, vector: &mut [T], - ) -> Result { + ) -> Result { T::get(self, key, vector) } @@ -1228,7 +1529,7 @@ impl Index { self: &Index, key: Key, vector: &mut Vec, - ) -> Result { + ) -> Result { let dim = self.dimensions(); let max_matches = self.count(key); vector.resize(dim * max_matches, T::default()); @@ -1531,10 +1832,12 @@ mod tests { let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1]; let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1]; let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0]; + let wrong_type: [i8; 5] = [1, 2, 3, 4, 5]; assert!(index.add(1, &first).is_ok()); assert!(index.add(2, &second).is_ok()); assert!(index.add(3, &too_long).is_err()); assert!(index.add(4, &too_short).is_err()); + assert!(index.add(5, &wrong_type).is_err()); assert_eq!(index.size(), 2); // Test using Vec @@ -1567,6 +1870,7 @@ mod tests { let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1]; let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1]; let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0]; + let wrong_type: [i8; 5] = [1, 2, 3, 4, 5]; assert!(index.add(1, &first).is_ok()); assert!(index.add(2, &second).is_ok()); assert_eq!(index.size(), 2); @@ -1574,7 +1878,11 @@ mod tests { //assert!(index.add(4, &too_short).is_err()); assert!(index.search(&too_long, 1).is_err()); + assert!(index.exact_search(&too_long, 1).is_err()); assert!(index.search(&too_short, 1).is_err()); + assert!(index.exact_search(&too_short, 1).is_err()); + assert!(index.search(&wrong_type, 1).is_err()); + assert!(index.exact_search(&wrong_type, 1).is_err()); } #[test] @@ -1887,7 +2195,10 @@ mod tests { (a_slice[0] - b_slice[0]).abs() * first_factor + (a_slice[1] - b_slice[1]).abs() * second_factor }); - index.change_metric(stateful_distance); + assert!(index.change_metric(stateful_distance).is_ok()); + + let wrong_type = Box::new(move |_: *const b1x8, _: *const b1x8| 0.0); + assert!(index.change_metric(wrong_type).is_err()); let another_vector: [f32; 2] = [0.0, 1.0]; index.add(2, &another_vector).unwrap(); From 0b0f9b5e6c5d6122e003f02810106c2a3e96e993 Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Tue, 28 Oct 2025 12:26:32 +0100 Subject: [PATCH 2/7] Add Sized bound to VectorType trait, mark trait unsafe Every method has a Sized bound anyway, so might as well put it a level higher. Also mark VectorType unsafe, since implementors must follow safety invariants, or else the entire program becomes unsound. --- rust/lib.rs | 53 ++++++++++++++--------------------------------------- 1 file changed, 14 insertions(+), 39 deletions(-) diff --git a/rust/lib.rs b/rust/lib.rs index e16f9e7f7..a49ca95ef 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -642,7 +642,7 @@ impl std::fmt::Debug for IndexOperationError { /// The `VectorType` trait defines operations for managing and querying vectors /// in an index. It supports generic operations on vectors of different types, /// allowing for the addition, retrieval, and search of vectors within an index. -pub trait VectorType { +pub unsafe trait VectorType: Sized { const KIND: ScalarKind; /// Adds a vector to the index under the specified key. @@ -655,10 +655,7 @@ pub trait VectorType { /// # Returns /// - `Ok(())` if the vector was successfully added to the index. /// - `Err(IndexOperationError)` if an error occurred during the operation. - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), IndexOperationError> - where - Self: Sized, - { + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), IndexOperationError> { if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { return Err(IndexOperationError::TypeError( Self::KIND, @@ -695,10 +692,7 @@ pub trait VectorType { /// # Retuns /// - `Ok(usize)` indicating the number of elements actually written into the `buffer`. /// - `Err(IndexOperationError)` if an error occurred during the operation. - fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result - where - Self: Sized, - { + fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result { if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { return Err(IndexOperationError::TypeError( Self::KIND, @@ -724,9 +718,7 @@ pub trait VectorType { index: &Index, key: Key, buffer: &mut [Self], - ) -> Result - where - Self: Sized; + ) -> Result; /// Performs a search in the index using the given query vector, returning /// up to `count` closest matches. @@ -743,10 +735,7 @@ pub trait VectorType { index: &Index, query: &[Self], count: usize, - ) -> Result - where - Self: Sized, - { + ) -> Result { if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { return Err(IndexOperationError::TypeError( Self::KIND, @@ -772,9 +761,7 @@ pub trait VectorType { index: &Index, query: &[Self], count: usize, - ) -> Result - where - Self: Sized; + ) -> Result; /// Performs an exact (brute force) search in the index using the given query vector, returning /// up to `count` closest matches. This search checks all vectors in the index, guaranteeing to find @@ -792,10 +779,7 @@ pub trait VectorType { index: &Index, query: &[Self], count: usize, - ) -> Result - where - Self: Sized, - { + ) -> Result { if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { return Err(IndexOperationError::TypeError( Self::KIND, @@ -821,9 +805,7 @@ pub trait VectorType { index: &Index, query: &[Self], count: usize, - ) -> Result - where - Self: Sized; + ) -> Result; /// Performs a filtered search in the index using a query vector and a custom /// filter function, returning up to `count` matches that satisfy the filter. @@ -845,7 +827,6 @@ pub trait VectorType { filter: F, ) -> Result where - Self: Sized, F: Fn(Key) -> bool, { if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { @@ -877,7 +858,6 @@ pub trait VectorType { filter: F, ) -> Result where - Self: Sized, F: Fn(Key) -> bool; /// Changes the metric used for distance calculations within the index. @@ -893,10 +873,7 @@ pub trait VectorType { fn change_metric( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), IndexOperationError> - where - Self: Sized, - { + ) -> Result<(), IndexOperationError> { // TODO: same question as higher up. what kind of casts are allowed and sensible here? if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { return Err(IndexOperationError::TypeError( @@ -922,11 +899,9 @@ pub trait VectorType { index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> - where - Self: Sized; } -impl VectorType for f32 { +unsafe impl VectorType for f32 { const KIND: ScalarKind = ScalarKind::F32; unsafe fn search_unchecked( @@ -1015,7 +990,7 @@ impl VectorType for f32 { } } -impl VectorType for i8 { +unsafe impl VectorType for i8 { const KIND: ScalarKind = ScalarKind::I8; unsafe fn search_unchecked( @@ -1103,7 +1078,7 @@ impl VectorType for i8 { } } -impl VectorType for f64 { +unsafe impl VectorType for f64 { const KIND: ScalarKind = ScalarKind::F64; unsafe fn search_unchecked( @@ -1191,7 +1166,7 @@ impl VectorType for f64 { } } -impl VectorType for f16 { +unsafe impl VectorType for f16 { const KIND: ScalarKind = ScalarKind::F16; unsafe fn search_unchecked( @@ -1280,7 +1255,7 @@ impl VectorType for f16 { } } -impl VectorType for b1x8 { +unsafe impl VectorType for b1x8 { const KIND: ScalarKind = ScalarKind::B1; unsafe fn search_unchecked( From 7612efa2c89e7c2278d22bd2a0dfebe1c4dbfaaf Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Tue, 28 Oct 2025 12:27:50 +0100 Subject: [PATCH 3/7] Use generic change_metric_unchecked method, remove copy-paste impls --- .envrc | 1 + rust/lib.rs | 172 +++++++++++++++++----------------------------------- 2 files changed, 58 insertions(+), 115 deletions(-) create mode 100644 .envrc diff --git a/.envrc b/.envrc new file mode 100644 index 000000000..3550a30f2 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/rust/lib.rs b/rust/lib.rs index a49ca95ef..904e8ad5f 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -644,6 +644,9 @@ impl std::fmt::Debug for IndexOperationError { /// allowing for the addition, retrieval, and search of vectors within an index. pub unsafe trait VectorType: Sized { const KIND: ScalarKind; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunction; /// Adds a vector to the index under the specified key. /// @@ -898,11 +901,52 @@ pub unsafe trait VectorType: Sized { unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> + ) -> Result<(), cxx::Exception> { + if let Some(metric) = index.metric_fn.take() { + // SAFETY: We have an exclusive &mut to Index, so no one can be using the + // pointed-to closure. + unsafe { + drop(metric.into_owned()); + } + } + + index.metric_fn = Some(Self::METRIC_FN(Box::into_raw(Box::new(metric)))); + + // Trampoline is the function that knows how to call the Rust closure. + // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, + // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function + // and the number of dimensions. + extern "C" fn trampoline( + first: usize, + second: usize, + closure_address: usize, + ) -> Distance { + let first_ptr = first as *const T; + let second_ptr = second as *const T; + let closure: *mut _ = + closure_address as *mut Box Distance>; + unsafe { (*closure)(first_ptr, second_ptr) } + } + + let trampoline_fn: usize = trampoline:: as *const () as usize; + let closure_address = match index.metric_fn.as_ref().expect("Was just set to Some") { + MetricFunction::F32Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunction::B1X8Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunction::I8Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunction::F16Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunction::F64Metric(metric) => (*metric as *mut _) as *mut () as usize, + }; + index.inner.change_metric(trampoline_fn, closure_address); + + Ok(()) + } } unsafe impl VectorType for f32 { const KIND: ScalarKind = ScalarKind::F32; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunction = MetricFunction::F32Metric; unsafe fn search_unchecked( index: &Index, @@ -959,39 +1003,13 @@ unsafe impl VectorType for f32 { .inner .filtered_search_f32(query, count, trampoline_fn, closure_address) } - - unsafe fn change_metric_unchecked( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::F32Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. - extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const f32; - let second_ptr = second as *const f32; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::F32Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected F32Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } unsafe impl VectorType for i8 { const KIND: ScalarKind = ScalarKind::I8; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunction = MetricFunction::I8Metric; unsafe fn search_unchecked( index: &Index, @@ -1048,38 +1066,13 @@ unsafe impl VectorType for i8 { .inner .filtered_search_i8(query, count, trampoline_fn, closure_address) } - unsafe fn change_metric_unchecked( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::I8Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. - extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const i8; - let second_ptr = second as *const i8; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::I8Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected I8Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } unsafe impl VectorType for f64 { const KIND: ScalarKind = ScalarKind::F64; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunction = MetricFunction::F64Metric; unsafe fn search_unchecked( index: &Index, @@ -1136,38 +1129,13 @@ unsafe impl VectorType for f64 { .inner .filtered_search_f64(query, count, trampoline_fn, closure_address) } - unsafe fn change_metric_unchecked( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::F64Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. - extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const f64; - let second_ptr = second as *const f64; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::F64Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected F64Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } unsafe impl VectorType for f16 { const KIND: ScalarKind = ScalarKind::F16; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunction = MetricFunction::F16Metric; unsafe fn search_unchecked( index: &Index, @@ -1224,39 +1192,13 @@ unsafe impl VectorType for f16 { .inner .filtered_search_f16(f16::to_i16s(query), count, trampoline_fn, closure_address) } - - unsafe fn change_metric_unchecked( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::F16Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. - extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const f16; - let second_ptr = second as *const f16; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::F16Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected F16Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } unsafe impl VectorType for b1x8 { const KIND: ScalarKind = ScalarKind::B1; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunction = MetricFunction::B1X8Metric; unsafe fn search_unchecked( index: &Index, From 3f585d03d6c5d6f67c0509474177c6cb70e36761 Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Tue, 28 Oct 2025 12:28:12 +0100 Subject: [PATCH 4/7] Add (failing) test for memory leak in Index::change_metric --- rust/lib.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/rust/lib.rs b/rust/lib.rs index 904e8ad5f..f7aed7126 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -2121,6 +2121,34 @@ mod tests { index.add(2, &another_vector).unwrap(); } + #[test] + fn test_change_metric_leak() { + let options = IndexOptions { + dimensions: 2, + quantization: ScalarKind::F32, + ..Default::default() + }; + let mut index = Index::new(&options).unwrap(); + index.reserve(10).unwrap(); + + let vector: [f32; 2] = [1.0, 0.0]; + index.add(1, &vector).unwrap(); + + let counter = std::sync::Arc::new(std::sync::Mutex::new(0f32)); + + let n: i32 = 100; + for _ in 0..n { + let counter_copy = counter.clone(); + let metric = + Box::new(move |_: *const f32, _: *const f32| *counter_copy.lock().unwrap()); + index.change_metric(metric).unwrap(); + } + drop(index); + // Only one reference to counter (the one held in this scope, not the closure) + // should be left. + assert_eq!(std::sync::Arc::strong_count(&counter), 1); + } + #[test] fn test_binary_vectors_and_hamming_distance() { let index = Index::new(&IndexOptions { From e8eb7932715c587a57a7c0177c2f2f0559952c81 Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Tue, 28 Oct 2025 12:46:52 +0100 Subject: [PATCH 5/7] Fix leaked metric closure in change_metric --- rust/lib.rs | 92 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 32 deletions(-) diff --git a/rust/lib.rs b/rust/lib.rs index f7aed7126..9e035c695 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -481,7 +481,7 @@ pub use ffi::{IndexOptions, MetricKind, ScalarKind}; /// ``` /// /// In this example, `dimensions` should be defined and valid for the vectors `a` and `b`. -pub enum MetricFunction { +pub enum MetricFunctionPtr { B1X8Metric(*mut std::boxed::Box Distance + Send + Sync>), I8Metric(*mut std::boxed::Box Distance + Send + Sync>), F16Metric(*mut std::boxed::Box Distance + Send + Sync>), @@ -489,6 +489,45 @@ pub enum MetricFunction { F64Metric(*mut std::boxed::Box Distance + Send + Sync>), } +impl MetricFunctionPtr { + /// Cast inner non-owning raw pointer to boxed closure into owned Box. + /// + /// # Safety + /// + /// Pointer must be valid, and the returned value must not be dropped + /// as long as the C++ side is still using the pointer to the closure. + unsafe fn into_owned(self) -> MetricFunctionOwned { + unsafe { + match self { + MetricFunctionPtr::B1X8Metric(pointer) => { + MetricFunctionOwned::B1X8Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::I8Metric(pointer) => { + MetricFunctionOwned::I8Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::F16Metric(pointer) => { + MetricFunctionOwned::F16Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::F32Metric(pointer) => { + MetricFunctionOwned::F32Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::F64Metric(pointer) => { + MetricFunctionOwned::F64Metric(Box::from_raw(pointer)) + } + } + } + } +} + +enum MetricFunctionOwned { + // Double boxed because Box is a wide pointer, and the C++ side needs a regular pointer + B1X8Metric(Box Distance + Send + Sync>>), + I8Metric(Box Distance + Send + Sync>>), + F16Metric(Box Distance + Send + Sync>>), + F32Metric(Box Distance + Send + Sync>>), + F64Metric(Box Distance + Send + Sync>>), +} + /// Approximate Nearest Neighbors search index for dense vectors. /// /// The `Index` struct provides an abstraction over a dense vector space, allowing @@ -528,7 +567,7 @@ pub enum MetricFunction { pub struct Index { inner: cxx::UniquePtr, scalar_kind: ScalarKind, - metric_fn: Option, + metric_fn: Option, } unsafe impl Send for Index {} @@ -536,23 +575,10 @@ unsafe impl Sync for Index {} impl Drop for Index { fn drop(&mut self) { - if let Some(metric) = &self.metric_fn { - match metric { - MetricFunction::B1X8Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::I8Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::F16Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::F32Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::F64Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, + if let Some(metric) = self.metric_fn.take() { + // SAFETY: the pointed-to closure is never used again after Index is dropped. + unsafe { + drop(metric.into_owned()); } } } @@ -646,7 +672,7 @@ pub unsafe trait VectorType: Sized { const KIND: ScalarKind; const METRIC_FN: fn( *mut std::boxed::Box Distance + Send + Sync>, - ) -> MetricFunction; + ) -> MetricFunctionPtr; /// Adds a vector to the index under the specified key. /// @@ -930,11 +956,11 @@ pub unsafe trait VectorType: Sized { let trampoline_fn: usize = trampoline:: as *const () as usize; let closure_address = match index.metric_fn.as_ref().expect("Was just set to Some") { - MetricFunction::F32Metric(metric) => (*metric as *mut _) as *mut () as usize, - MetricFunction::B1X8Metric(metric) => (*metric as *mut _) as *mut () as usize, - MetricFunction::I8Metric(metric) => (*metric as *mut _) as *mut () as usize, - MetricFunction::F16Metric(metric) => (*metric as *mut _) as *mut () as usize, - MetricFunction::F64Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::F32Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::B1X8Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::I8Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::F16Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::F64Metric(metric) => (*metric as *mut _) as *mut () as usize, }; index.inner.change_metric(trampoline_fn, closure_address); @@ -946,7 +972,7 @@ unsafe impl VectorType for f32 { const KIND: ScalarKind = ScalarKind::F32; const METRIC_FN: fn( *mut std::boxed::Box Distance + Send + Sync>, - ) -> MetricFunction = MetricFunction::F32Metric; + ) -> MetricFunctionPtr = MetricFunctionPtr::F32Metric; unsafe fn search_unchecked( index: &Index, @@ -1009,7 +1035,7 @@ unsafe impl VectorType for i8 { const KIND: ScalarKind = ScalarKind::I8; const METRIC_FN: fn( *mut std::boxed::Box Distance + Send + Sync>, - ) -> MetricFunction = MetricFunction::I8Metric; + ) -> MetricFunctionPtr = MetricFunctionPtr::I8Metric; unsafe fn search_unchecked( index: &Index, @@ -1072,7 +1098,7 @@ unsafe impl VectorType for f64 { const KIND: ScalarKind = ScalarKind::F64; const METRIC_FN: fn( *mut std::boxed::Box Distance + Send + Sync>, - ) -> MetricFunction = MetricFunction::F64Metric; + ) -> MetricFunctionPtr = MetricFunctionPtr::F64Metric; unsafe fn search_unchecked( index: &Index, @@ -1135,7 +1161,7 @@ unsafe impl VectorType for f16 { const KIND: ScalarKind = ScalarKind::F16; const METRIC_FN: fn( *mut std::boxed::Box Distance + Send + Sync>, - ) -> MetricFunction = MetricFunction::F16Metric; + ) -> MetricFunctionPtr = MetricFunctionPtr::F16Metric; unsafe fn search_unchecked( index: &Index, @@ -1198,7 +1224,7 @@ unsafe impl VectorType for b1x8 { const KIND: ScalarKind = ScalarKind::B1; const METRIC_FN: fn( *mut std::boxed::Box Distance + Send + Sync>, - ) -> MetricFunction = MetricFunction::B1X8Metric; + ) -> MetricFunctionPtr = MetricFunctionPtr::B1X8Metric; unsafe fn search_unchecked( index: &Index, @@ -1262,7 +1288,9 @@ unsafe impl VectorType for b1x8 { ) -> Result<(), cxx::Exception> { // Store the metric function in the Index. type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::B1X8Metric(Box::into_raw(Box::new(metric)))); + index.metric_fn = Some(MetricFunctionPtr::B1X8Metric(Box::into_raw(Box::new( + metric, + )))); // Trampoline is the function that knows how to call the Rust closure. // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, @@ -1277,7 +1305,7 @@ unsafe impl VectorType for b1x8 { let trampoline_fn: usize = trampoline as *const () as usize; let closure_address = match index.metric_fn { - Some(MetricFunction::B1X8Metric(metric)) => metric as *mut () as usize, + Some(MetricFunctionPtr::B1X8Metric(metric)) => metric as *mut () as usize, _ => panic!("Expected F1X8Metric"), }; index.inner.change_metric(trampoline_fn, closure_address); From 5532ef153bc5f3c7a0382011c124b370caf97069 Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Tue, 28 Oct 2025 12:47:00 +0100 Subject: [PATCH 6/7] Mark MetricFunctionPtr as doc(hidden) It has to be pub since it's exposed through the public VectorType trait, but details of how FFI calls, trampolines etc are handled probably shouldn't be part of the public API. --- rust/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/lib.rs b/rust/lib.rs index 9e035c695..c5470d1e3 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -481,6 +481,7 @@ pub use ffi::{IndexOptions, MetricKind, ScalarKind}; /// ``` /// /// In this example, `dimensions` should be defined and valid for the vectors `a` and `b`. +#[doc(hidden)] pub enum MetricFunctionPtr { B1X8Metric(*mut std::boxed::Box Distance + Send + Sync>), I8Metric(*mut std::boxed::Box Distance + Send + Sync>), From 18c786b4441881fdc7d12e2e7b099398624102a7 Mon Sep 17 00:00:00 2001 From: Thomas Nibler Date: Tue, 28 Oct 2025 14:29:29 +0100 Subject: [PATCH 7/7] Move mutating/readonly Index methods into separate traits This allows us to remove the unsafe `view_from_buffer` function by introducing a new IndexView type that is associated with the buffer's lifetime. It also prevents invalid usage of immutable Index instances, which would throw C++ exceptions if e.g., add() was called on an immutable view. --- rust/lib.rs | 623 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 411 insertions(+), 212 deletions(-) diff --git a/rust/lib.rs b/rust/lib.rs index c5470d1e3..135eaa70b 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -431,6 +431,8 @@ pub mod ffi { } } +use std::marker::PhantomData; + // Re-export the FFI structs and enums at the crate root for easy access pub use ffi::{IndexOptions, MetricKind, ScalarKind}; @@ -520,6 +522,7 @@ impl MetricFunctionPtr { } } +#[allow(unused)] enum MetricFunctionOwned { // Double boxed because Box is a wide pointer, and the C++ side needs a regular pointer B1X8Metric(Box Distance + Send + Sync>>), @@ -540,7 +543,7 @@ enum MetricFunctionOwned { /// Basic usage: /// /// ```rust -/// use usearch::{Index, IndexOptions, MetricKind, ScalarKind}; +/// use usearch::{Index, IndexMethods, IndexViewMethods, IndexOptions, MetricKind, ScalarKind}; /// /// let mut options = IndexOptions::default(); /// options.dimensions = 4; // Set the number of dimensions for vectors @@ -574,6 +577,15 @@ pub struct Index { unsafe impl Send for Index {} unsafe impl Sync for Index {} +/// A read-only view into an index, read from a file or in-memory buffer. +pub struct IndexView<'buf> { + inner: Index, + _phantom_data: PhantomData<(&'buf [u8], *const ())>, +} + +unsafe impl Send for IndexView<'static> {} +unsafe impl Sync for IndexView<'static> {} + impl Drop for Index { fn drop(&mut self) { if let Some(metric) = self.metric_fn.take() { @@ -1315,59 +1327,15 @@ unsafe impl VectorType for b1x8 { } } -impl Index { - pub fn new(options: &ffi::IndexOptions) -> Result { - match ffi::new_native_index(options) { - Ok(inner) => Result::Ok(Self { - inner, - scalar_kind: options.quantization, - metric_fn: None, - }), - Err(err) => Err(err), - } - } - +pub trait IndexViewMethods { /// Retrieves the expansion value used during index creation. - pub fn expansion_add(self: &Index) -> usize { - self.inner.expansion_add() - } + fn expansion_add(&self) -> usize; /// Retrieves the expansion value used during search. - pub fn expansion_search(self: &Index) -> usize { - self.inner.expansion_search() - } - - /// Updates the expansion value used during index creation. Rarely used. - pub fn change_expansion_add(self: &Index, n: usize) { - self.inner.change_expansion_add(n) - } - - /// Updates the expansion value used during search operations. - pub fn change_expansion_search(self: &Index, n: usize) { - self.inner.change_expansion_search(n) - } - - /// Changes the metric kind used to calculate the distance between vectors. - pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) { - self.inner.change_metric_kind(metric) - } - - /// Overrides the metric function used to calculate the distance between vectors. - pub fn change_metric( - self: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), IndexOperationError> { - T::change_metric(self, metric) - } + fn expansion_search(&self) -> usize; /// Retrieves the hardware acceleration information. - pub fn hardware_acceleration(&self) -> String { - use core::ffi::CStr; - unsafe { - let c_str = CStr::from_ptr(self.inner.hardware_acceleration()); - c_str.to_string_lossy().into_owned() - } - } + fn hardware_acceleration(&self) -> String; /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query. /// @@ -1379,13 +1347,11 @@ impl Index { /// # Returns /// /// A `Result` containing the matches found. - pub fn search( - self: &Index, + fn search( + &self, query: &[T], count: usize, - ) -> Result { - T::search(self, query, count) - } + ) -> Result; /// Performs exact (brute force) Nearest Neighbors Search for closest vectors to the provided query. /// This search checks all vectors in the index, guaranteeing to find the true nearest neighbors, @@ -1399,13 +1365,11 @@ impl Index { /// # Returns /// /// A `Result` containing the matches found. - pub fn exact_search( - self: &Index, + fn exact_search( + &self, query: &[T], count: usize, - ) -> Result { - T::exact_search(self, query, count) - } + ) -> Result; /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query /// satisfying a custom filter function. @@ -1419,31 +1383,14 @@ impl Index { /// # Returns /// /// A `Result` containing the matches found. - pub fn filtered_search( - self: &Index, + fn filtered_search( + &self, query: &[T], count: usize, filter: F, ) -> Result where - F: Fn(Key) -> bool, - { - T::filtered_search(self, query, count, filter) - } - - /// Adds a vector with a specified key to the index. - /// - /// # Arguments - /// - /// * `key` - The key associated with the vector. - /// * `vector` - A slice containing the vector data. - pub fn add( - self: &Index, - key: Key, - vector: &[T], - ) -> Result<(), IndexOperationError> { - T::add(self, key, vector) - } + F: Fn(Key) -> bool; /// Extracts one or more vectors matching the specified key. /// The `vector` slice must be a multiple of the number of dimensions in the index. @@ -1456,13 +1403,7 @@ impl Index { /// /// * `key` - The key associated with the vector. /// * `vector` - A slice containing the vector data. - pub fn get( - self: &Index, - key: Key, - vector: &mut [T], - ) -> Result { - T::get(self, key, vector) - } + fn get(&self, key: Key, vector: &mut [T]) -> Result; /// Extracts one or more vectors matching specified key into supplied resizable vector. /// The `vector` is resized to a multiple of the number of dimensions in the index. @@ -1471,206 +1412,446 @@ impl Index { /// /// * `key` - The key associated with the vector. /// * `vector` - A mutable vector containing the vector data. - pub fn export( - self: &Index, + fn export( + &self, key: Key, vector: &mut Vec, - ) -> Result { - let dim = self.dimensions(); - let max_matches = self.count(key); - vector.resize(dim * max_matches, T::default()); - let matches = T::get(self, key, &mut vector[..])?; - vector.resize(dim * matches, T::default()); - Ok(matches) - } - - /// Reserves memory for a specified number of incoming vectors. - /// - /// # Arguments - /// - /// * `capacity` - The desired total capacity, including the current size. - pub fn reserve(self: &Index, capacity: usize) -> Result<(), cxx::Exception> { - self.inner.reserve(capacity) - } - - /// Reserves memory for a specified number of incoming vectors & active threads. - /// - /// # Arguments - /// - /// * `capacity` - The desired total capacity, including the current size. - /// * `threads` - The number of threads to use for the operation. - pub fn reserve_capacity_and_threads( - self: &Index, - capacity: usize, - threads: usize, - ) -> Result<(), cxx::Exception> { - self.inner.reserve_capacity_and_threads(capacity, threads) - } + ) -> Result; /// Retrieves the number of dimensions in the vectors indexed. - pub fn dimensions(self: &Index) -> usize { - self.inner.dimensions() - } + fn dimensions(&self) -> usize; /// Retrieves the connectivity parameter that limits connections-per-node in the graph. - pub fn connectivity(self: &Index) -> usize { - self.inner.connectivity() - } + fn connectivity(&self) -> usize; /// Retrieves the current number of vectors in the index. - pub fn size(self: &Index) -> usize { - self.inner.size() - } + fn size(&self) -> usize; /// Retrieves the total capacity of the index, including reserved space. - pub fn capacity(self: &Index) -> usize { - self.inner.capacity() - } + fn capacity(&self) -> usize; /// Reports expected file size after serialization. - pub fn serialized_length(self: &Index) -> usize { - self.inner.serialized_length() - } + fn serialized_length(&self) -> usize; - /// Removes the vector associated with the given key from the index. + /// Checks if the index contains a vector with a specified key. /// /// # Arguments /// - /// * `key` - The key of the vector to be removed. + /// * `key` - The key to be checked. /// /// # Returns /// - /// `true` if the vector is successfully removed, `false` otherwise. - pub fn remove(self: &Index, key: Key) -> Result { - self.inner.remove(key) - } + /// `true` if the index contains the vector with the given key, `false` otherwise. + fn contains(&self, key: Key) -> bool; - /// Renames the vector under a specific key. + /// Count the count of vectors with the same specified key. /// /// # Arguments /// - /// * `from` - The key of the vector to be renamed. - /// * `to` - The new name. + /// * `key` - The key to be checked. /// /// # Returns /// - /// `true` if the vector is renamed, `false` otherwise. - pub fn rename(self: &Index, from: Key, to: Key) -> Result { - self.inner.rename(from, to) - } + /// Number of vectors found. + fn count(&self, key: Key) -> usize; - /// Checks if the index contains a vector with a specified key. + /// Saves the index to a specified file. /// /// # Arguments /// - /// * `key` - The key to be checked. + /// * `path` - The file path where the index will be saved. + fn save(&self, path: &str) -> Result<(), cxx::Exception>; + + /// A relatively accurate lower bound on the amount of memory consumed by the system. + /// In practice, its error will be below 10%. + fn memory_usage(&self) -> usize; + + /// Saves the index to a specified file. /// - /// # Returns + /// # Arguments /// - /// `true` if the index contains the vector with the given key, `false` otherwise. - pub fn contains(self: &Index, key: Key) -> bool { - self.inner.contains(key) - } + /// * `buffer` - The buffer where the index will be saved. + fn save_to_buffer(&self, buffer: &mut [u8]) -> Result<(), cxx::Exception>; +} - /// Count the count of vectors with the same specified key. +pub trait IndexMethods: IndexViewMethods { + /// Updates the expansion value used during index creation. Rarely used. + fn change_expansion_add(&self, n: usize); + + /// Updates the expansion value used during search operations. + fn change_expansion_search(&self, n: usize); + + /// Changes the metric kind used to calculate the distance between vectors. + fn change_metric_kind(self: &Self, metric: ffi::MetricKind); + + /// Overrides the metric function used to calculate the distance between vectors. + fn change_metric( + &mut self, + metric: std::boxed::Box Distance + Send + Sync>, + ) -> Result<(), IndexOperationError>; + + /// Adds a vector with a specified key to the index. /// /// # Arguments /// - /// * `key` - The key to be checked. + /// * `key` - The key associated with the vector. + /// * `vector` - A slice containing the vector data. + fn add(&self, key: Key, vector: &[T]) -> Result<(), IndexOperationError>; + + /// Reserves memory for a specified number of incoming vectors. + /// + /// # Arguments + /// + /// * `capacity` - The desired total capacity, including the current size. + fn reserve(&self, capacity: usize) -> Result<(), cxx::Exception>; + + /// Reserves memory for a specified number of incoming vectors & active threads. + /// + /// # Arguments + /// + /// * `capacity` - The desired total capacity, including the current size. + /// * `threads` - The number of threads to use for the operation. + fn reserve_capacity_and_threads( + &self, + capacity: usize, + threads: usize, + ) -> Result<(), cxx::Exception>; + + /// Removes the vector associated with the given key from the index. + /// + /// # Arguments + /// + /// * `key` - The key of the vector to be removed. /// /// # Returns /// - /// Number of vectors found. - pub fn count(self: &Index, key: Key) -> usize { - self.inner.count(key) - } + /// `true` if the vector is successfully removed, `false` otherwise. + fn remove(&self, key: Key) -> Result; - /// Saves the index to a specified file. + /// Renames the vector under a specific key. /// /// # Arguments /// - /// * `path` - The file path where the index will be saved. - pub fn save(self: &Index, path: &str) -> Result<(), cxx::Exception> { - self.inner.save(path) - } + /// * `from` - The key of the vector to be renamed. + /// * `to` - The new name. + /// + /// # Returns + /// + /// `true` if the vector is renamed, `false` otherwise. + fn rename(&self, from: Key, to: Key) -> Result; /// Loads the index from a specified file. /// /// # Arguments /// /// * `path` - The file path from where the index will be loaded. - pub fn load(self: &Index, path: &str) -> Result<(), cxx::Exception> { - self.inner.load(path) + fn load(&self, path: &str) -> Result<(), cxx::Exception>; + + /// Erases all members from the index, closes files, and returns RAM to OS. + fn reset(&self) -> Result<(), cxx::Exception>; +} + +impl IndexViewMethods for Index { + fn expansion_add(&self) -> usize { + self.inner.expansion_add() } - /// Creates a view of the index from a file without loading it into memory. - /// - /// # Arguments - /// - /// * `path` - The file path from where the view will be created. - pub fn view(self: &Index, path: &str) -> Result<(), cxx::Exception> { - self.inner.view(path) + fn expansion_search(&self) -> usize { + self.inner.expansion_search() } - /// Erases all members from the index, closes files, and returns RAM to OS. - pub fn reset(self: &Index) -> Result<(), cxx::Exception> { - self.inner.reset() + fn hardware_acceleration(&self) -> String { + use core::ffi::CStr; + unsafe { + let c_str = CStr::from_ptr(self.inner.hardware_acceleration()); + c_str.to_string_lossy().into_owned() + } } - /// A relatively accurate lower bound on the amount of memory consumed by the system. - /// In practice, its error will be below 10%. - pub fn memory_usage(self: &Index) -> usize { + fn search( + &self, + query: &[T], + count: usize, + ) -> Result { + T::search(self, query, count) + } + + fn exact_search( + &self, + query: &[T], + count: usize, + ) -> Result { + T::exact_search(self, query, count) + } + + fn filtered_search( + &self, + query: &[T], + count: usize, + filter: F, + ) -> Result + where + F: Fn(Key) -> bool, + { + T::filtered_search(self, query, count, filter) + } + + fn get(&self, key: Key, vector: &mut [T]) -> Result { + T::get(self, key, vector) + } + + fn export( + &self, + key: Key, + vector: &mut Vec, + ) -> Result { + let dim = self.dimensions(); + let max_matches = self.count(key); + vector.resize(dim * max_matches, T::default()); + let matches = T::get(self, key, &mut vector[..])?; + vector.resize(dim * matches, T::default()); + Ok(matches) + } + + fn dimensions(&self) -> usize { + self.inner.dimensions() + } + + fn connectivity(&self) -> usize { + self.inner.connectivity() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn capacity(&self) -> usize { + self.inner.capacity() + } + + fn serialized_length(&self) -> usize { + self.inner.serialized_length() + } + + fn contains(&self, key: Key) -> bool { + self.inner.contains(key) + } + + fn count(&self, key: Key) -> usize { + self.inner.count(key) + } + + fn save(&self, path: &str) -> Result<(), cxx::Exception> { + self.inner.save(path) + } + + fn memory_usage(&self) -> usize { self.inner.memory_usage() } - /// Saves the index to a specified file. - /// - /// # Arguments - /// - /// * `buffer` - The buffer where the index will be saved. - pub fn save_to_buffer(self: &Index, buffer: &mut [u8]) -> Result<(), cxx::Exception> { + fn save_to_buffer(&self, buffer: &mut [u8]) -> Result<(), cxx::Exception> { self.inner.save_to_buffer(buffer) } +} + +impl IndexMethods for Index { + fn change_expansion_add(&self, n: usize) { + self.inner.change_expansion_add(n) + } + + fn change_expansion_search(&self, n: usize) { + self.inner.change_expansion_search(n) + } + + fn change_metric_kind(&self, metric: ffi::MetricKind) { + self.inner.change_metric_kind(metric) + } + + fn change_metric( + &mut self, + metric: std::boxed::Box Distance + Send + Sync>, + ) -> Result<(), IndexOperationError> { + T::change_metric(self, metric) + } + + fn add(&self, key: Key, vector: &[T]) -> Result<(), IndexOperationError> { + T::add(self, key, vector) + } + + fn reserve(&self, capacity: usize) -> Result<(), cxx::Exception> { + self.inner.reserve(capacity) + } + + fn reserve_capacity_and_threads( + &self, + capacity: usize, + threads: usize, + ) -> Result<(), cxx::Exception> { + self.inner.reserve_capacity_and_threads(capacity, threads) + } + + fn remove(&self, key: Key) -> Result { + self.inner.remove(key) + } + + fn rename(&self, from: Key, to: Key) -> Result { + self.inner.rename(from, to) + } + + fn load(&self, path: &str) -> Result<(), cxx::Exception> { + self.inner.load(path) + } + + fn reset(&self) -> Result<(), cxx::Exception> { + self.inner.reset() + } +} + +impl Index { + pub fn new(options: &ffi::IndexOptions) -> Result { + match ffi::new_native_index(options) { + Ok(inner) => Result::Ok(Self { + inner, + scalar_kind: options.quantization, + metric_fn: None, + }), + Err(err) => Err(err), + } + } /// Loads the index from a specified file. /// /// # Arguments /// /// * `buffer` - The buffer from where the index will be loaded. - pub fn load_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> { + pub fn load_from_buffer(&self, buffer: &[u8]) -> Result<(), cxx::Exception> { self.inner.load_from_buffer(buffer) } +} + +impl<'buf> IndexViewMethods for IndexView<'buf> { + fn expansion_add(&self) -> usize { + self.inner.expansion_add() + } + + fn expansion_search(&self) -> usize { + self.inner.expansion_search() + } + + fn hardware_acceleration(&self) -> String { + self.inner.hardware_acceleration() + } + + fn search( + &self, + query: &[T], + count: usize, + ) -> Result { + self.inner.search(query, count) + } + + fn exact_search( + &self, + query: &[T], + count: usize, + ) -> Result { + self.inner.exact_search(query, count) + } + + fn filtered_search( + &self, + query: &[T], + count: usize, + filter: F, + ) -> Result + where + F: Fn(Key) -> bool, + { + self.inner.filtered_search(query, count, filter) + } + + fn get(&self, key: Key, vector: &mut [T]) -> Result { + self.inner.get(key, vector) + } + + fn export( + &self, + key: Key, + vector: &mut Vec, + ) -> Result { + self.inner.export(key, vector) + } + fn dimensions(&self) -> usize { + self.inner.dimensions() + } + + fn connectivity(&self) -> usize { + self.inner.connectivity() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn capacity(&self) -> usize { + self.inner.capacity() + } + + fn serialized_length(&self) -> usize { + self.inner.serialized_length() + } + + fn contains(&self, key: Key) -> bool { + self.inner.contains(key) + } + + fn count(&self, key: Key) -> usize { + self.inner.count(key) + } + + fn save(&self, path: &str) -> Result<(), cxx::Exception> { + self.inner.save(path) + } + + fn memory_usage(&self) -> usize { + self.inner.memory_usage() + } + + fn save_to_buffer(&self, buffer: &mut [u8]) -> Result<(), cxx::Exception> { + self.inner.save_to_buffer(buffer) + } +} + +impl<'buf> IndexView<'buf> { /// Creates a view of the index from a file without loading it into memory. /// /// # Arguments /// /// * `buffer` - The buffer from where the view will be created. + pub fn new_from_buffer(buffer: &'buf [u8]) -> Result { + let inner = Index::new(&IndexOptions::default())?; + inner.inner.view_from_buffer(buffer)?; + Ok(IndexView { + inner, + _phantom_data: PhantomData, + }) + } +} + +impl IndexView<'static> { + /// Creates a view of the index from a file without loading it into memory. /// - /// # Safety - /// - /// This function is marked as `unsafe` because it stores a pointer to the input buffer. - /// The caller must ensure that the buffer outlives the index and is not dropped - /// or modified for the duration of the index's use. Dereferencing a pointer to a - /// temporary buffer after it has been dropped can lead to undefined behavior, - /// which violates Rust's memory safety guarantees. - /// - /// Example of misuse: - /// - /// ```rust,ignore - /// let index: usearch::Index = usearch::new_index(&usearch::IndexOptions::default()).unwrap(); - /// - /// let temporary = vec![0u8; 100]; - /// index.view_from_buffer(&temporary); - /// std::mem::drop(temporary); - /// - /// let query = vec![0.0; 256]; - /// let results = index.search(&query, 5).unwrap(); - /// ``` + /// # Arguments /// - /// The above example would result in use-after-free and undefined behavior. - pub unsafe fn view_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> { - self.inner.view_from_buffer(buffer) + /// * `path` - The file path from where the view will be created. + pub fn new_from_file(path: &str) -> Result { + let inner = Index::new(&IndexOptions::default())?; + inner.inner.view(path)?; + Ok(IndexView { + inner, + _phantom_data: PhantomData, + }) } } @@ -1686,8 +1867,9 @@ mod tests { use crate::b1x8; use crate::new_index; - use crate::Index; use crate::Key; + use crate::{Index, IndexView}; + use crate::{IndexMethods, IndexViewMethods}; use std::env; @@ -1939,7 +2121,13 @@ mod tests { // Validate serialization assert!(index.save("index.rust.usearch").is_ok()); assert!(index.load("index.rust.usearch").is_ok()); - assert!(index.view("index.rust.usearch").is_ok()); + + let index_view = IndexView::new_from_file("index.rust.usearch").unwrap(); + let results = index_view.search(&first, 10).unwrap(); + println!("{:?}", results); + assert_eq!(results.keys.len(), 2); + let mut out = [0f32; 5]; + assert!(index_view.get(43, &mut out).is_ok()); // Make sure every function is called at least once assert!(new_index(&options).is_ok()); @@ -2320,4 +2508,15 @@ mod tests { "All searches should find exact matches" ); } + + #[test] + fn test_index_file_view_is_sync() { + #[allow(unused)] + fn assert_sync() {} + #[allow(unused)] + fn assert_send() {} + + assert_sync::>(); + assert_send::>(); + } }