2727//! distortion: at 4 bits the error is typically < 0.1, at 8 bits < 0.001.
2828//!
2929//! For approximate nearest neighbor (ANN) search, biased-but-accurate ranking is
30- //! usually sufficient — the relative ordering of cosine similarities is preserved
30+ //! usually sufficient -- the relative ordering of cosine similarities is preserved
3131//! even if the absolute values have bounded error.
3232
33+ use num_traits:: FromPrimitive ;
34+ use num_traits:: Zero ;
3335use vortex_array:: ArrayRef ;
3436use vortex_array:: ArrayView ;
3537use vortex_array:: ExecutionCtx ;
3638use vortex_array:: IntoArray ;
3739use vortex_array:: arrays:: FixedSizeListArray ;
3840use vortex_array:: arrays:: PrimitiveArray ;
41+ use vortex_array:: match_each_float_ptype;
3942use vortex_array:: validity:: Validity ;
4043use vortex_buffer:: BufferMut ;
4144use vortex_error:: VortexResult ;
42- use vortex_error:: vortex_ensure ;
45+ use vortex_error:: vortex_ensure_eq ;
4346
4447use crate :: encodings:: turboquant:: TurboQuant ;
48+ use crate :: utils:: extension_element_ptype;
4549
46- /// Shared helper: read codes, norms, and centroids from two TurboQuant arrays,
47- /// then compute per-row quantized unit-norm dot products.
50+ /// Convert an f32 value to `T`, returning `T::zero()` if the conversion fails.
4851///
49- /// Both arrays must have the same dimension (vector length) and row count.
50- /// They may have different codebooks (e.g., different bit widths), in which
51- /// case each array's own centroids are used for its code lookups.
52+ /// This helper exists because `half::f16` has an inherent `from_f32` method that shadows
53+ /// the [`FromPrimitive`] trait method, causing compilation errors when used inside
54+ /// [`match_each_float_ptype!`].
55+ #[ inline]
56+ fn f32_to_t < T : FromPrimitive + Zero > ( v : f32 ) -> T {
57+ FromPrimitive :: from_f32 ( v) . unwrap_or_else ( T :: zero)
58+ }
59+
60+ /// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
5261///
53- /// Returns `(norms_a, norms_b, unit_dots)` where `unit_dots[i]` is the dot product
54- /// of the unit-norm quantized vectors for row i.
55- fn quantized_unit_dots (
56- lhs : ArrayView < TurboQuant > ,
57- rhs : ArrayView < TurboQuant > ,
62+ /// Returns a `Vec<f32>` of length `num_rows`.
63+ fn compute_unit_dots (
64+ lhs : & ArrayView < TurboQuant > ,
65+ rhs : & ArrayView < TurboQuant > ,
5866 ctx : & mut ExecutionCtx ,
59- ) -> VortexResult < ( Vec < f32 > , Vec < f32 > , Vec < f32 > ) > {
60- vortex_ensure ! (
61- lhs. dimension( ) == rhs. dimension( ) ,
62- "TurboQuant quantized dot product requires matching dimensions, got {} and {}" ,
63- lhs. dimension( ) ,
64- rhs. dimension( )
65- ) ;
66-
67+ ) -> VortexResult < Vec < f32 > > {
6768 let pd = lhs. padded_dim ( ) as usize ;
6869 let num_rows = lhs. norms ( ) . len ( ) ;
6970
70- let lhs_norms: PrimitiveArray = lhs. norms ( ) . clone ( ) . execute ( ctx) ?;
71- let rhs_norms: PrimitiveArray = rhs. norms ( ) . clone ( ) . execute ( ctx) ?;
72- let na = lhs_norms. as_slice :: < f32 > ( ) ;
73- let nb = rhs_norms. as_slice :: < f32 > ( ) ;
74-
7571 let lhs_codes_fsl: FixedSizeListArray = lhs. codes ( ) . clone ( ) . execute ( ctx) ?;
7672 let rhs_codes_fsl: FixedSizeListArray = rhs. codes ( ) . clone ( ) . execute ( ctx) ?;
7773 let lhs_codes = lhs_codes_fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
7874 let rhs_codes = rhs_codes_fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
7975 let ca = lhs_codes. as_slice :: < u8 > ( ) ;
8076 let cb = rhs_codes. as_slice :: < u8 > ( ) ;
8177
82- // Read centroids from both arrays — they may have different codebooks
83- // (e.g., different bit widths).
78+ // Read centroids from both arrays. They may have different codebooks (e.g., different bit
79+ // widths).
8480 let lhs_centroids: PrimitiveArray = lhs. centroids ( ) . clone ( ) . execute ( ctx) ?;
8581 let rhs_centroids: PrimitiveArray = rhs. centroids ( ) . clone ( ) . execute ( ctx) ?;
8682 let cl = lhs_centroids. as_slice :: < f32 > ( ) ;
@@ -98,49 +94,75 @@ fn quantized_unit_dots(
9894 dots. push ( dot) ;
9995 }
10096
101- Ok ( ( na . to_vec ( ) , nb . to_vec ( ) , dots) )
97+ Ok ( dots)
10298}
10399
104100/// Compute approximate cosine similarity for all rows between two TurboQuant
105101/// arrays (same rotation matrix and codebook) without full decompression.
102+ ///
103+ /// Since TurboQuant stores unit-normalized rotated vectors, the dot product of the quantized
104+ /// codes directly approximates cosine similarity without needing the stored norms.
105+ ///
106+ /// The output dtype matches the Vector's element type (f16, f32, or f64).
106107pub fn cosine_similarity_quantized_column (
107108 lhs : ArrayView < TurboQuant > ,
108109 rhs : ArrayView < TurboQuant > ,
109110 ctx : & mut ExecutionCtx ,
110111) -> VortexResult < ArrayRef > {
111- let num_rows = lhs. norms ( ) . len ( ) ;
112- let ( na, nb, dots) = quantized_unit_dots ( lhs, rhs, ctx) ?;
112+ vortex_ensure_eq ! (
113+ lhs. dimension( ) ,
114+ rhs. dimension( ) ,
115+ "TurboQuant quantized dot product requires matching dimensions" ,
116+ ) ;
113117
114- let mut result = BufferMut :: < f32 > :: with_capacity ( num_rows) ;
115- for row in 0 ..num_rows {
116- if na[ row] == 0.0 || nb[ row] == 0.0 {
117- result. push ( 0.0 ) ;
118- } else {
119- // Unit-norm dot product IS the cosine similarity.
120- result. push ( dots[ row] ) ;
121- }
122- }
118+ let element_ptype = extension_element_ptype ( lhs. dtype ( ) . as_extension ( ) ) ?;
119+ let dots = compute_unit_dots ( & lhs, & rhs, ctx) ?;
123120
124- Ok ( PrimitiveArray :: new :: < f32 > ( result. freeze ( ) , Validity :: NonNullable ) . into_array ( ) )
121+ // The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type.
122+ match_each_float_ptype ! ( element_ptype, |T | {
123+ let mut result = BufferMut :: <T >:: with_capacity( dots. len( ) ) ;
124+ for & dot in & dots {
125+ result. push( f32_to_t( dot) ) ;
126+ }
127+ Ok ( PrimitiveArray :: new:: <T >( result. freeze( ) , Validity :: NonNullable ) . into_array( ) )
128+ } )
125129}
126130
127131/// Compute approximate dot product for all rows between two TurboQuant
128132/// arrays (same rotation matrix and codebook) without full decompression.
129133///
130- /// `dot_product(a, b) ≈ ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])`
134+ /// `dot_product(a, b) = ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])`
135+ ///
136+ /// The output dtype matches the Vector's element type (f16, f32, or f64).
131137pub fn dot_product_quantized_column (
132138 lhs : ArrayView < TurboQuant > ,
133139 rhs : ArrayView < TurboQuant > ,
134140 ctx : & mut ExecutionCtx ,
135141) -> VortexResult < ArrayRef > {
142+ vortex_ensure_eq ! (
143+ lhs. dimension( ) ,
144+ rhs. dimension( ) ,
145+ "TurboQuant quantized dot product requires matching dimensions" ,
146+ ) ;
147+
148+ let element_ptype = extension_element_ptype ( lhs. dtype ( ) . as_extension ( ) ) ?;
149+ let dots = compute_unit_dots ( & lhs, & rhs, ctx) ?;
136150 let num_rows = lhs. norms ( ) . len ( ) ;
137- let ( na, nb, dots) = quantized_unit_dots ( lhs, rhs, ctx) ?;
138151
139- let mut result = BufferMut :: < f32 > :: with_capacity ( num_rows) ;
140- for row in 0 ..num_rows {
141- // Scale the unit-norm dot product by both norms to get the actual dot product.
142- result. push ( na[ row] * nb[ row] * dots[ row] ) ;
143- }
152+ let lhs_norms: PrimitiveArray = lhs. norms ( ) . clone ( ) . execute ( ctx) ?;
153+ let rhs_norms: PrimitiveArray = rhs. norms ( ) . clone ( ) . execute ( ctx) ?;
154+
155+ // Scale the f32 unit-norm dot product by native-precision norms.
156+ match_each_float_ptype ! ( element_ptype, |T | {
157+ let na = lhs_norms. as_slice:: <T >( ) ;
158+ let nb = rhs_norms. as_slice:: <T >( ) ;
159+
160+ let mut result = BufferMut :: <T >:: with_capacity( num_rows) ;
161+ for row in 0 ..num_rows {
162+ let dot_t: T = f32_to_t( dots[ row] ) ;
163+ result. push( na[ row] * nb[ row] * dot_t) ;
164+ }
144165
145- Ok ( PrimitiveArray :: new :: < f32 > ( result. freeze ( ) , Validity :: NonNullable ) . into_array ( ) )
166+ Ok ( PrimitiveArray :: new:: <T >( result. freeze( ) , Validity :: NonNullable ) . into_array( ) )
167+ } )
146168}
0 commit comments