@@ -30,36 +30,30 @@ use vortex_array::ArrayRef;
3030use vortex_array:: ExecutionCtx ;
3131use vortex_array:: IntoArray ;
3232use vortex_array:: VortexSessionExecute ;
33- use vortex_array:: arrays:: ConstantArray ;
3433use vortex_array:: arrays:: Extension ;
3534use vortex_array:: arrays:: ExtensionArray ;
3635use vortex_array:: arrays:: FixedSizeListArray ;
3736use vortex_array:: arrays:: PrimitiveArray ;
3837use vortex_array:: arrays:: extension:: ExtensionArrayExt ;
3938use vortex_array:: arrays:: fixed_size_list:: FixedSizeListArrayExt ;
4039use vortex_array:: arrays:: scalar_fn:: ScalarFnArrayExt ;
41- use vortex_array:: builtins:: ArrayBuiltins ;
42- use vortex_array:: dtype:: DType ;
43- use vortex_array:: dtype:: Nullability ;
44- use vortex_array:: dtype:: PType ;
4540use vortex_array:: dtype:: extension:: ExtDType ;
4641use vortex_array:: extension:: EmptyMetadata ;
47- use vortex_array:: scalar:: Scalar ;
48- use vortex_array:: scalar_fn:: fns:: operators:: Operator ;
4942use vortex_array:: session:: ArraySession ;
5043use vortex_array:: validity:: Validity ;
5144use vortex_btrblocks:: BtrBlocksCompressor ;
5245use vortex_buffer:: BufferMut ;
5346use vortex_error:: VortexExpect ;
5447use vortex_error:: VortexResult ;
48+ use vortex_error:: vortex_bail;
5549use vortex_error:: vortex_panic;
5650use vortex_session:: VortexSession ;
5751use vortex_tensor:: encodings:: turboquant:: TurboQuantConfig ;
5852use vortex_tensor:: encodings:: turboquant:: turboquant_encode_unchecked;
59- use vortex_tensor:: scalar_fns:: cosine_similarity:: CosineSimilarity ;
6053use vortex_tensor:: scalar_fns:: l2_denorm:: L2Denorm ;
6154use vortex_tensor:: scalar_fns:: l2_denorm:: normalize_as_l2_denorm;
6255use vortex_tensor:: vector:: Vector ;
56+ pub use vortex_tensor:: vector_search:: build_similarity_search_tree;
6357
6458/// A shared [`VortexSession`] pre-loaded with the builtin [`ArraySession`] so both bench and
6559/// example can create execution contexts cheaply.
@@ -146,65 +140,16 @@ pub fn extract_row_as_query(vectors: &ArrayRef, row: usize, dim: u32) -> Vec<f32
146140 slice[ start..start + dim_usize] . to_vec ( )
147141}
148142
149- /// Build a `Vector<dim, f32>` extension array whose storage is a [`ConstantArray`] broadcasting a
150- /// single query vector across `num_rows` rows. This is how we hand a single query vector to
151- /// `CosineSimilarity` on the `rhs` side -- `ScalarFnArray` requires both children to have the
152- /// same length, so we broadcast the query instead of hand-rolling a 1-row input.
153- fn build_constant_query_vector ( query : & [ f32 ] , num_rows : usize ) -> VortexResult < ArrayRef > {
154- let element_dtype = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
155-
156- let children: Vec < Scalar > = query
157- . iter ( )
158- . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
159- . collect ( ) ;
160- let storage_scalar = Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
161-
162- let storage = ConstantArray :: new ( storage_scalar, num_rows) . into_array ( ) ;
163-
164- let ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
165- Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
166- }
167-
168- /// Compresses a raw `Vector<dim, f32>` array with the default BtrBlocks pipeline.
169- ///
170- /// [`BtrBlocksCompressor`] walks into the extension array and recursively compresses the
171- /// underlying FSL storage child. TurboQuant is *not* exercised by this path -- it is not
172- /// registered in the default scheme set -- so this measures "generic" lossless compression
173- /// applied to float vectors.
174- pub fn compress_default ( data : ArrayRef ) -> VortexResult < ArrayRef > {
175- BtrBlocksCompressor :: default ( ) . compress ( & data)
176- }
177-
178- /// Compresses a raw `Vector<dim, f32>` array with the TurboQuant pipeline by hand, producing the
179- /// same tree shape that
180- /// [`vortex_tensor::encodings::turboquant::TurboQuantScheme`] would:
181- ///
182- /// ```text
183- /// L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)
184- /// ```
185- ///
186- /// Calling the encode helpers directly (instead of going through
187- /// `BtrBlocksCompressorBuilder::with_turboquant()`) lets this example avoid depending on the
188- /// `unstable_encodings` feature flag.
189- ///
190- /// See `vortex-tensor/src/encodings/turboquant/tests/mod.rs::normalize_and_encode` for the same
191- /// canonical recipe.
192- pub fn compress_turboquant ( data : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ArrayRef > {
143+ fn normalize_vectors (
144+ data : ArrayRef ,
145+ ctx : & mut ExecutionCtx ,
146+ ) -> VortexResult < ( ArrayRef , ArrayRef , usize ) > {
193147 let l2_denorm = normalize_as_l2_denorm ( data, ctx) ?;
194148 let normalized = l2_denorm. child_at ( 0 ) . clone ( ) ;
195149 let norms = l2_denorm. child_at ( 1 ) . clone ( ) ;
196150 let num_rows = l2_denorm. len ( ) ;
197151
198- let normalized_ext = normalized
199- . as_opt :: < Extension > ( )
200- . vortex_expect ( "normalized child should be an Extension array" ) ;
201-
202- let config = TurboQuantConfig :: default ( ) ;
203- // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero), which is the
204- // invariant `turboquant_encode_unchecked` expects.
205- let tq = unsafe { turboquant_encode_unchecked ( normalized_ext, & config, ctx) } ?;
206-
207- Ok ( unsafe { L2Denorm :: new_array_unchecked ( tq, norms, num_rows) } ?. into_array ( ) )
152+ Ok ( ( normalized, norms, num_rows) )
208153}
209154
210155/// Dispatch helper that builds the data array for the requested [`Variant`], starting from a
@@ -220,37 +165,24 @@ pub fn build_variant(
220165 let raw = generate_random_vectors ( num_rows, dim, seed) ;
221166 match variant {
222167 Variant :: Uncompressed => Ok ( raw) ,
223- Variant :: DefaultCompression => compress_default ( raw) ,
224- Variant :: TurboQuant => compress_turboquant ( raw, ctx) ,
225- }
226- }
227-
228- /// Build the lazy similarity-search array tree for a prepared data array and a single query
229- /// vector. The returned tree is a boolean array of length `data.len()` where position `i` is
230- /// `true` iff `cosine_similarity(data[i], query) > threshold`.
231- ///
232- /// The tree shape is:
233- ///
234- /// ```text
235- /// Binary(Gt, [
236- /// CosineSimilarity([data, ConstantArray(query_vec, n)]),
237- /// ConstantArray(threshold, n),
238- /// ])
239- /// ```
240- ///
241- /// This function does no execution; it is safe to call inside a benchmark setup closure.
242- pub fn build_similarity_search_tree (
243- data : ArrayRef ,
244- query : & [ f32 ] ,
245- threshold : f32 ,
246- ) -> VortexResult < ArrayRef > {
247- let num_rows = data. len ( ) ;
248- let query_vec = build_constant_query_vector ( query, num_rows) ?;
249-
250- let cosine = CosineSimilarity :: try_new_array ( data, query_vec, num_rows) ?. into_array ( ) ;
251-
252- let threshold_scalar = Scalar :: primitive ( threshold, Nullability :: NonNullable ) ;
253- let threshold_array = ConstantArray :: new ( threshold_scalar, num_rows) . into_array ( ) ;
168+ Variant :: DefaultCompression => {
169+ let ( normalized, norms, num_rows) = normalize_vectors ( raw, ctx) ?;
170+ let compressed = BtrBlocksCompressor :: default ( ) . compress ( & normalized) ?;
254171
255- cosine. binary ( threshold_array, Operator :: Gt )
172+ Ok ( unsafe { L2Denorm :: new_array_unchecked ( compressed, norms, num_rows) } ?. into_array ( ) )
173+ }
174+ Variant :: TurboQuant => {
175+ let ( normalized, norms, num_rows) = normalize_vectors ( raw, ctx) ?;
176+ let Some ( normalized_ext) = normalized. as_opt :: < Extension > ( ) else {
177+ vortex_bail ! ( "normalize_as_l2_denorm must produce an Extension array child" ) ;
178+ } ;
179+
180+ let config = TurboQuantConfig :: default ( ) ;
181+ // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero),
182+ // which is the invariant `turboquant_encode_unchecked` expects.
183+ let tq = unsafe { turboquant_encode_unchecked ( normalized_ext, & config, ctx) } ?;
184+
185+ Ok ( unsafe { L2Denorm :: new_array_unchecked ( tq, norms, num_rows) } ?. into_array ( ) )
186+ }
187+ }
256188}
0 commit comments