@@ -16,82 +16,24 @@ use vortex_error::vortex_bail;
1616use vortex_error:: vortex_ensure;
1717use vortex_fastlanes:: bitpack_compress:: bitpack_encode;
1818
19- use crate :: array:: TurboQuantArray ;
20- use crate :: array:: TurboQuantVariant ;
2119use crate :: centroids:: find_nearest_centroid;
2220use crate :: centroids:: get_centroids;
23- use crate :: mse_array :: TurboQuantMSEArray ;
24- use crate :: qjl_array :: TurboQuantQJLArray ;
21+ use crate :: mse :: array :: TurboQuantMSEArray ;
22+ use crate :: qjl :: array :: TurboQuantQJLArray ;
2523use crate :: rotation:: RotationMatrix ;
2624
2725/// Configuration for TurboQuant encoding.
2826#[ derive( Clone , Debug ) ]
2927pub struct TurboQuantConfig {
30- /// Bits per coordinate (1-4).
28+ /// Bits per coordinate.
29+ ///
30+ /// For MSE encoding: 1-8.
31+ /// For QJL encoding: 2-9 (the MSE inner uses `bit_width - 1`).
3132 pub bit_width : u8 ,
32- /// Which variant to use.
33- pub variant : TurboQuantVariant ,
3433 /// Optional seed for the rotation matrix. If None, a random seed is generated.
3534 pub seed : Option < u64 > ,
3635}
3736
38- /// Encode a FixedSizeListArray of floats into a TurboQuantArray.
39- ///
40- /// The input should be the storage array of a Vector or FixedShapeTensor extension type.
41- /// Each row (fixed-size-list element) is treated as a d-dimensional vector to quantize.
42- pub fn turboquant_encode (
43- fsl : & FixedSizeListArray ,
44- config : & TurboQuantConfig ,
45- ) -> VortexResult < TurboQuantArray > {
46- match config. variant {
47- TurboQuantVariant :: Mse => vortex_ensure ! (
48- config. bit_width >= 1 && config. bit_width <= 8 ,
49- "MSE variant bit_width must be 1-8, got {}" ,
50- config. bit_width
51- ) ,
52- TurboQuantVariant :: Prod => vortex_ensure ! (
53- config. bit_width >= 2 && config. bit_width <= 9 ,
54- "Prod variant bit_width must be 2-9, got {}" ,
55- config. bit_width
56- ) ,
57- }
58-
59- let dimension = fsl. list_size ( ) ;
60- vortex_ensure ! (
61- dimension >= 2 ,
62- "TurboQuant requires dimension >= 2, got {dimension}"
63- ) ;
64- let num_rows = fsl. len ( ) ;
65-
66- if num_rows == 0 {
67- return encode_empty ( fsl, config, dimension) ;
68- }
69-
70- let seed = config. seed . unwrap_or_else ( rand:: random) ;
71-
72- // Extract flat f32 elements from the FixedSizeListArray.
73- let f32_elements = extract_f32_elements ( fsl) ?;
74-
75- match config. variant {
76- TurboQuantVariant :: Mse => encode_mse (
77- & f32_elements,
78- num_rows,
79- dimension,
80- config. bit_width ,
81- seed,
82- fsl,
83- ) ,
84- TurboQuantVariant :: Prod => encode_prod (
85- & f32_elements,
86- num_rows,
87- dimension,
88- config. bit_width ,
89- seed,
90- fsl,
91- ) ,
92- }
93- }
94-
9537/// Extract elements from a FixedSizeListArray as a flat f32 vec.
9638#[ allow( clippy:: cast_possible_truncation) ]
9739fn extract_f32_elements ( fsl : & FixedSizeListArray ) -> VortexResult < Vec < f32 > > {
@@ -110,231 +52,12 @@ fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<Vec<f32>> {
11052 }
11153}
11254
113- fn encode_empty (
114- fsl : & FixedSizeListArray ,
115- config : & TurboQuantConfig ,
116- dimension : u32 ,
117- ) -> VortexResult < TurboQuantArray > {
118- let seed = config. seed . unwrap_or ( 0 ) ;
119- let codes = PrimitiveArray :: empty :: < u8 > ( fsl. dtype ( ) . nullability ( ) ) ;
120- let norms = PrimitiveArray :: empty :: < f32 > ( fsl. dtype ( ) . nullability ( ) ) ;
121-
122- match config. variant {
123- TurboQuantVariant :: Mse => TurboQuantArray :: try_new_mse (
124- fsl. dtype ( ) . clone ( ) ,
125- codes. into_array ( ) ,
126- norms. into_array ( ) ,
127- dimension,
128- config. bit_width ,
129- seed,
130- ) ,
131- TurboQuantVariant :: Prod => {
132- let qjl_signs = PrimitiveArray :: empty :: < u8 > ( fsl. dtype ( ) . nullability ( ) ) ;
133- let residual_norms = PrimitiveArray :: empty :: < f32 > ( fsl. dtype ( ) . nullability ( ) ) ;
134- TurboQuantArray :: try_new_prod (
135- fsl. dtype ( ) . clone ( ) ,
136- codes. into_array ( ) ,
137- norms. into_array ( ) ,
138- qjl_signs. into_array ( ) ,
139- residual_norms. into_array ( ) ,
140- dimension,
141- config. bit_width ,
142- seed,
143- )
144- }
145- }
146- }
147-
148- fn encode_mse (
149- elements : & [ f32 ] ,
150- num_rows : usize ,
151- dimension : u32 ,
152- bit_width : u8 ,
153- seed : u64 ,
154- fsl : & FixedSizeListArray ,
155- ) -> VortexResult < TurboQuantArray > {
156- let dim = dimension as usize ;
157- let rotation = RotationMatrix :: try_new ( seed, dim) ?;
158- let padded_dim = rotation. padded_dim ( ) ;
159- #[ allow( clippy:: cast_possible_truncation) ]
160- let centroids = get_centroids ( padded_dim as u32 , bit_width) ?;
161-
162- let mut all_indices = BufferMut :: < u8 > :: with_capacity ( num_rows * padded_dim) ;
163- let mut norms_buf = BufferMut :: < f32 > :: with_capacity ( num_rows) ;
164-
165- let mut padded = vec ! [ 0.0f32 ; padded_dim] ;
166- let mut rotated = vec ! [ 0.0f32 ; padded_dim] ;
167-
168- for row in 0 ..num_rows {
169- let x = & elements[ row * dim..( row + 1 ) * dim] ;
170-
171- let norm = l2_norm ( x) ;
172- norms_buf. push ( norm) ;
173-
174- // Normalize, zero-pad to padded_dim, and rotate.
175- padded. fill ( 0.0 ) ;
176- if norm > 0.0 {
177- let inv_norm = 1.0 / norm;
178- for ( dst, & src) in padded[ ..dim] . iter_mut ( ) . zip ( x. iter ( ) ) {
179- * dst = src * inv_norm;
180- }
181- }
182- rotation. rotate ( & padded, & mut rotated) ;
183-
184- // Quantize all padded_dim coordinates.
185- for j in 0 ..padded_dim {
186- all_indices. push ( find_nearest_centroid ( rotated[ j] , & centroids) ) ;
187- }
188- }
189-
190- // Pack indices: bitpack for 1-7 bits, store raw u8 for 8 bits.
191- let indices_array = PrimitiveArray :: new :: < u8 > ( all_indices. freeze ( ) , Validity :: NonNullable ) ;
192- let codes = if bit_width < 8 {
193- bitpack_encode ( & indices_array, bit_width, None ) ?. into_array ( )
194- } else {
195- indices_array. into_array ( )
196- } ;
197-
198- let norms_array = PrimitiveArray :: new :: < f32 > ( norms_buf. freeze ( ) , Validity :: NonNullable ) ;
199-
200- TurboQuantArray :: try_new_mse (
201- fsl. dtype ( ) . clone ( ) ,
202- codes,
203- norms_array. into_array ( ) ,
204- dimension,
205- bit_width,
206- seed,
207- )
208- }
209-
210- fn encode_prod (
211- elements : & [ f32 ] ,
212- num_rows : usize ,
213- dimension : u32 ,
214- bit_width : u8 ,
215- seed : u64 ,
216- fsl : & FixedSizeListArray ,
217- ) -> VortexResult < TurboQuantArray > {
218- let dim = dimension as usize ;
219- let mse_bit_width = bit_width - 1 ;
220-
221- let rotation = RotationMatrix :: try_new ( seed, dim) ?;
222- let padded_dim = rotation. padded_dim ( ) ;
223- #[ allow( clippy:: cast_possible_truncation) ]
224- let centroids = get_centroids ( padded_dim as u32 , mse_bit_width) ?;
225-
226- let mut all_indices = BufferMut :: < u8 > :: with_capacity ( num_rows * padded_dim) ;
227- let mut norms_buf = BufferMut :: < f32 > :: with_capacity ( num_rows) ;
228- let mut residual_norms_buf = BufferMut :: < f32 > :: with_capacity ( num_rows) ;
229-
230- // QJL sign bits: num_rows * padded_dim bits, packed into bytes.
231- let total_sign_bits = num_rows * padded_dim;
232- let sign_byte_count = total_sign_bits. div_ceil ( 8 ) ;
233- let mut sign_buf = BufferMut :: < u8 > :: with_capacity ( sign_byte_count) ;
234- sign_buf. extend ( std:: iter:: repeat_n ( 0u8 , sign_byte_count) ) ;
235- let sign_slice = sign_buf. as_mut_slice ( ) ;
236-
237- let mut padded = vec ! [ 0.0f32 ; padded_dim] ;
238- let mut rotated = vec ! [ 0.0f32 ; padded_dim] ;
239- let mut dequantized_rotated = vec ! [ 0.0f32 ; padded_dim] ;
240- let mut dequantized = vec ! [ 0.0f32 ; padded_dim] ;
241- let mut residual = vec ! [ 0.0f32 ; padded_dim] ;
242- let mut projected = vec ! [ 0.0f32 ; padded_dim] ;
243-
244- // QJL random sign matrix generator (using seed + 1).
245- let qjl_rotation = RotationMatrix :: try_new ( seed. wrapping_add ( 1 ) , dim) ?;
246-
247- for row in 0 ..num_rows {
248- let x = & elements[ row * dim..( row + 1 ) * dim] ;
249-
250- let norm = l2_norm ( x) ;
251- norms_buf. push ( norm) ;
252-
253- // Normalize, zero-pad, and rotate.
254- padded. fill ( 0.0 ) ;
255- if norm > 0.0 {
256- let inv_norm = 1.0 / norm;
257- for ( dst, & src) in padded[ ..dim] . iter_mut ( ) . zip ( x. iter ( ) ) {
258- * dst = src * inv_norm;
259- }
260- }
261- rotation. rotate ( & padded, & mut rotated) ;
262-
263- // MSE quantize at (bit_width - 1) bits over padded_dim coordinates.
264- for j in 0 ..padded_dim {
265- let idx = find_nearest_centroid ( rotated[ j] , & centroids) ;
266- all_indices. push ( idx) ;
267- dequantized_rotated[ j] = centroids[ idx as usize ] ;
268- }
269-
270- // Dequantize MSE result (inverse rotate to full padded space, take first dim).
271- rotation. inverse_rotate ( & dequantized_rotated, & mut dequantized) ;
272- if norm > 0.0 {
273- for val in & mut dequantized {
274- * val *= norm;
275- }
276- }
277-
278- // Compute residual r = x - x_hat_mse (only first dim elements matter).
279- residual. fill ( 0.0 ) ;
280- for j in 0 ..dim {
281- residual[ j] = x[ j] - dequantized[ j] ;
282- }
283- let residual_norm = l2_norm ( & residual[ ..dim] ) ;
284- residual_norms_buf. push ( residual_norm) ;
285-
286- // QJL: sign(S * r).
287- projected. fill ( 0.0 ) ;
288- if residual_norm > 0.0 {
289- qjl_rotation. rotate ( & residual, & mut projected) ;
290- }
291-
292- // Store sign bits for padded_dim positions.
293- let bit_offset = row * padded_dim;
294- for j in 0 ..padded_dim {
295- if projected[ j] >= 0.0 {
296- let bit_idx = bit_offset + j;
297- sign_slice[ bit_idx / 8 ] |= 1 << ( bit_idx % 8 ) ;
298- }
299- }
300- }
301-
302- // Pack MSE indices: bitpack for 1-7 bits, store raw u8 for 8 bits.
303- let indices_array = PrimitiveArray :: new :: < u8 > ( all_indices. freeze ( ) , Validity :: NonNullable ) ;
304- let codes = if mse_bit_width < 8 {
305- bitpack_encode ( & indices_array, mse_bit_width, None ) ?. into_array ( )
306- } else {
307- indices_array. into_array ( )
308- } ;
309-
310- let norms_array = PrimitiveArray :: new :: < f32 > ( norms_buf. freeze ( ) , Validity :: NonNullable ) ;
311- let residual_norms_array =
312- PrimitiveArray :: new :: < f32 > ( residual_norms_buf. freeze ( ) , Validity :: NonNullable ) ;
313-
314- let qjl_signs = PrimitiveArray :: new :: < u8 > ( sign_buf. freeze ( ) , Validity :: NonNullable ) ;
315-
316- TurboQuantArray :: try_new_prod (
317- fsl. dtype ( ) . clone ( ) ,
318- codes,
319- norms_array. into_array ( ) ,
320- qjl_signs. into_array ( ) ,
321- residual_norms_array. into_array ( ) ,
322- dimension,
323- bit_width,
324- seed,
325- )
326- }
327-
/// Compute the L2 (Euclidean) norm of a vector.
///
/// Squares every component of `x`, accumulates the squares in `f32`, and
/// returns the square root of that sum. An empty slice has a norm of zero.
#[inline]
fn l2_norm(x: &[f32]) -> f32 {
    x.iter().map(|&v| v * v).sum::<f32>().sqrt()
}
33360
334- // ---------------------------------------------------------------------------
335- // New encoding producing cascaded MSE/QJL arrays
336- // ---------------------------------------------------------------------------
337-
33861/// Encode a FixedSizeListArray into a `TurboQuantMSEArray`.
33962pub fn turboquant_encode_mse (
34063 fsl : & FixedSizeListArray ,
@@ -390,7 +113,7 @@ pub fn turboquant_encode_mse(
390113 }
391114 }
392115
393- // Pack indices.
116+ // Pack indices: bitpack for 1-7 bits, store raw u8 for 8 bits.
394117 let indices_array = PrimitiveArray :: new :: < u8 > ( all_indices. freeze ( ) , Validity :: NonNullable ) ;
395118 let codes = if config. bit_width < 8 {
396119 bitpack_encode ( & indices_array, config. bit_width , None ) ?. into_array ( )
@@ -448,7 +171,6 @@ pub fn turboquant_encode_qjl(
448171 // First, encode the MSE inner at (bit_width - 1).
449172 let mse_config = TurboQuantConfig {
450173 bit_width : mse_bit_width,
451- variant : TurboQuantVariant :: Mse , // legacy field, not used in new path
452174 seed : Some ( seed) ,
453175 } ;
454176 let mse_inner = turboquant_encode_mse ( fsl, & mse_config) ?;
@@ -581,7 +303,6 @@ fn build_empty_qjl_array(
581303) -> VortexResult < TurboQuantQJLArray > {
582304 let mse_config = TurboQuantConfig {
583305 bit_width : bit_width - 1 ,
584- variant : TurboQuantVariant :: Mse ,
585306 seed : Some ( seed) ,
586307 } ;
587308 let mse_inner = turboquant_encode_mse ( fsl, & mse_config) ?;
0 commit comments