@@ -9,6 +9,7 @@ use vortex_array::stats::ArrayStats;
99use vortex_error:: VortexExpect ;
1010use vortex_error:: VortexResult ;
1111use vortex_error:: vortex_ensure;
12+ use vortex_error:: vortex_ensure_eq;
1213
1314use crate :: encodings:: turboquant:: array:: slots:: Slot ;
1415use crate :: encodings:: turboquant:: vtable:: TurboQuant ;
@@ -17,22 +18,22 @@ use crate::utils::extension_list_size;
1718
1819/// TurboQuant array data.
1920///
20- /// TurboQuant is a lossy vector quantization encoding for [`Vector`] extension arrays.
21- /// It stores quantized coordinate codes and per-vector norms, along with shared codebook
22- /// centroids and SRHT rotation signs. See the [module docs](super) for algorithmic details .
21+ /// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector)
22+ /// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared
23+ /// codebook centroids and SRHT rotation signs.
2324///
24- /// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty .
25+ /// See the [module docs](super) for algorithmic details .
2526///
26- /// [`Vector`]: crate::vector::Vector
27+ /// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty.
2728#[ derive( Clone , Debug ) ]
2829pub struct TurboQuantData {
29- /// The [`Vector`] extension dtype that this array encodes. The storage dtype within the
30- /// extension determines the element type (f16, f32, or f64) and the list size (dimension).
30+ /// The [`Vector`](crate::vector::Vector) extension dtype that this array encodes.
3131 ///
32- /// [`Vector`]: crate::vector::Vector
32+ /// The storage dtype within the extension determines the element type (f16, f32, or f64) and
33+ /// the list size (dimension).
3334 pub ( crate ) dtype : DType ,
3435
35- /// Child arrays stored as optional slots. See [`Slot`] for positions:
36+ /// Child arrays stored as slots. See [`Slot`] for positions:
3637 ///
3738 /// - [`Codes`](Slot::Codes): `FixedSizeListArray<u8>` with `list_size == padded_dim`. Each row
3839 /// holds one u8 centroid index per padded coordinate. The cascade compressor handles packing
@@ -53,13 +54,13 @@ pub struct TurboQuantData {
5354 pub ( crate ) slots : Vec < Option < ArrayRef > > ,
5455
5556 /// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size.
57+ ///
5658 /// Stored as a convenience field to avoid repeatedly extracting it from `dtype`.
57- /// Non-power-of-2 dimensions are zero-padded to [`padded_dim`](Self::padded_dim) for the
58- /// Walsh-Hadamard transform.
5959 pub ( crate ) dimension : u32 ,
6060
6161 /// The number of bits per coordinate (1-8), derived from `log2(centroids.len())`.
62- /// Zero for degenerate empty arrays.
62+ ///
63+ /// This is 0 for degenerate empty arrays.
6364 pub ( crate ) bit_width : u8 ,
6465
6566 /// The stats for this array.
@@ -100,8 +101,8 @@ impl TurboQuantData {
100101 /// is >= 3.
101102 /// - `codes` is a `FixedSizeListArray<u8>` with `list_size == padded_dim` and
102103 /// `codes.len() == norms.len()`.
103- /// - `norms` is a non-nullable primitive array whose ptype matches the element type of the
104- /// Vector's storage dtype .
104+ /// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage
105+ /// dtype. This must match the validity of the `codes` array .
105106 /// - `centroids` is a non-nullable `PrimitiveArray<f32>` whose length is a power of 2 in
106107 /// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays.
107108 /// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays.
@@ -124,13 +125,22 @@ impl TurboQuantData {
124125 . and_then ( |ext| extension_list_size ( ext) . ok ( ) )
125126 . vortex_expect ( "dtype must be a Vector extension type with FixedSizeList storage" ) ;
126127
127- let bit_width = derive_bit_width ( & centroids) ;
128+ let bit_width = if centroids. is_empty ( ) {
129+ 0
130+ } else {
131+ // Guaranteed to be 0-8 by validate().
132+ #[ expect( clippy:: cast_possible_truncation) ]
133+ {
134+ centroids. len ( ) . trailing_zeros ( ) as u8
135+ }
136+ } ;
128137
129138 let mut slots = vec ! [ None ; Slot :: COUNT ] ;
130139 slots[ Slot :: Codes as usize ] = Some ( codes) ;
131140 slots[ Slot :: Norms as usize ] = Some ( norms) ;
132141 slots[ Slot :: Centroids as usize ] = Some ( centroids) ;
133142 slots[ Slot :: RotationSigns as usize ] = Some ( rotation_signs) ;
143+
134144 Self {
135145 dtype,
136146 slots,
@@ -153,15 +163,19 @@ impl TurboQuantData {
153163 let ext = TurboQuant :: validate_dtype ( dtype) ?;
154164 let dimension = extension_list_size ( ext) ?;
155165
156- let num_rows = norms. len ( ) ;
166+ let num_rows = codes. len ( ) ;
167+ vortex_ensure_eq ! (
168+ norms. len( ) ,
169+ num_rows,
170+ "norms length must match codes length" ,
171+ ) ;
172+
173+ // TODO(connor): Should we check that the codes and norms have the same validity? We could
174+ // also make it so that norms holds the validity and any null vectors encoded as codes is
175+ // just 0...
157176
158- // Degenerate (empty) case: all children must be empty, bit_width is 0.
177+ // Degenerate (empty) case: all children must be empty, and bit_width is 0.
159178 if num_rows == 0 {
160- vortex_ensure ! (
161- codes. is_empty( ) ,
162- "degenerate TurboQuant must have empty codes, got length {}" ,
163- codes. len( )
164- ) ;
165179 vortex_ensure ! (
166180 centroids. is_empty( ) ,
167181 "degenerate TurboQuant must have empty centroids, got length {}" ,
@@ -183,7 +197,7 @@ impl TurboQuantData {
183197 ) ;
184198
185199 // Guaranteed to be 1-8 by the preceding power-of-2 and range checks.
186- #[ allow ( clippy:: cast_possible_truncation) ]
200+ #[ expect ( clippy:: cast_possible_truncation) ]
187201 let bit_width = num_centroids. trailing_zeros ( ) as u8 ;
188202 vortex_ensure ! (
189203 ( 1 ..=8 ) . contains( & bit_width) ,
@@ -193,44 +207,34 @@ impl TurboQuantData {
193207 // Norms dtype must match the element ptype of the Vector.
194208 let element_ptype = extension_element_ptype ( ext) ?;
195209 let expected_norms_dtype = DType :: Primitive ( element_ptype, Nullability :: NonNullable ) ;
196- vortex_ensure ! (
197- * norms. dtype( ) == expected_norms_dtype,
198- "norms dtype {} does not match expected {expected_norms_dtype} \
210+ vortex_ensure_eq ! (
211+ * norms. dtype( ) ,
212+ expected_norms_dtype,
213+ "norms dtype does not match expected {expected_norms_dtype} \
199214 (must match Vector element type)",
200- norms. dtype( )
201215 ) ;
202216
203217 // Centroids are always f32 regardless of element type.
204- let f32_nn = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
205- vortex_ensure ! (
206- * centroids. dtype( ) == f32_nn,
207- "centroids dtype {} must be non-nullable f32" ,
208- centroids. dtype( )
209- ) ;
210-
211- // Row count consistency.
212- vortex_ensure ! (
213- codes. len( ) == num_rows,
214- "codes length {} does not match norms length {num_rows}" ,
215- codes. len( )
218+ let centroids_dtype = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
219+ vortex_ensure_eq ! (
220+ * centroids. dtype( ) ,
221+ centroids_dtype,
222+ "centroids dtype must be non-nullable f32" ,
216223 ) ;
217224
218225 // Rotation signs count must be 3 * padded_dim.
219226 let padded_dim = dimension. next_power_of_two ( ) as usize ;
220- vortex_ensure ! (
221- rotation_signs. len( ) == 3 * padded_dim,
222- "rotation_signs length {} does not match expected 3 * {padded_dim} = {}" ,
227+ vortex_ensure_eq ! (
223228 rotation_signs. len( ) ,
224- 3 * padded_dim
229+ 3 * padded_dim,
230+ "rotation_signs length does not match expected 3 * {padded_dim}" ,
225231 ) ;
226232
227233 Ok ( ( ) )
228234 }
229235
230- /// The vector dimension `d`, as stored in the [`Vector`] extension dtype's
231- /// `FixedSizeList` storage.
232- ///
233- /// [`Vector`]: crate::vector::Vector
236+ /// The vector dimension `d`, as stored in the [`Vector`](crate::vector::Vector) extension
237+ /// dtype's `FixedSizeList` storage.
234238 pub fn dimension ( & self ) -> u32 {
235239 self . dimension
236240 }
@@ -248,12 +252,6 @@ impl TurboQuantData {
248252 self . dimension . next_power_of_two ( )
249253 }
250254
251- fn slot ( & self , idx : usize ) -> & ArrayRef {
252- self . slots [ idx]
253- . as_ref ( )
254- . vortex_expect ( "required slot is None" )
255- }
256-
257255 /// The quantized codes child (`FixedSizeListArray<u8>`, one row per vector).
258256 pub fn codes ( & self ) -> & ArrayRef {
259257 self . slot ( Slot :: Codes as usize )
@@ -278,19 +276,10 @@ impl TurboQuantData {
278276 pub fn rotation_signs ( & self ) -> & ArrayRef {
279277 self . slot ( Slot :: RotationSigns as usize )
280278 }
281- }
282279
283- /// Derive `bit_width` from the centroids array length.
284- ///
285- /// Returns 0 for empty centroids (degenerate array), otherwise `log2(centroids.len())`.
286- fn derive_bit_width ( centroids : & ArrayRef ) -> u8 {
287- if centroids. is_empty ( ) {
288- 0
289- } else {
290- // Guaranteed to be 0-8 by validate().
291- #[ allow( clippy:: cast_possible_truncation) ]
292- {
293- centroids. len ( ) . trailing_zeros ( ) as u8
294- }
280+ fn slot ( & self , idx : usize ) -> & ArrayRef {
281+ self . slots [ idx]
282+ . as_ref ( )
283+ . vortex_expect ( "required slot is None" )
295284 }
296285}
0 commit comments