@@ -36,6 +36,11 @@ pub struct TurboQuantData {
3636 ///
3737 /// This is 0 for degenerate empty arrays.
3838 pub ( crate ) bit_width : u8 ,
39+
40+ /// The number of sign-diagonal + WHT rounds in the structured rotation.
41+ ///
42+ /// This is 0 for degenerate empty arrays.
43+ pub ( crate ) num_rounds : u8 ,
3944}
4045
4146impl TurboQuantData {
@@ -46,7 +51,7 @@ impl TurboQuantData {
4651 /// Returns an error if:
4752 /// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
4853 /// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
49- pub fn try_new ( dimension : u32 , bit_width : u8 ) -> VortexResult < Self > {
54+ pub fn try_new ( dimension : u32 , bit_width : u8 , num_rounds : u8 ) -> VortexResult < Self > {
5055 vortex_ensure ! (
5156 dimension >= TurboQuant :: MIN_DIMENSION ,
5257 "TurboQuant requires dimension >= {}, got {dimension}" ,
@@ -61,6 +66,7 @@ impl TurboQuantData {
6166 Ok ( Self {
6267 dimension,
6368 bit_width,
69+ num_rounds,
6470 } )
6571 }
6672
@@ -72,12 +78,14 @@ impl TurboQuantData {
7278 ///
7379 /// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
7480 /// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
81+ /// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays).
7582 ///
7683 /// Violating these invariants may produce incorrect results during decompression.
77- pub unsafe fn new_unchecked ( dimension : u32 , bit_width : u8 ) -> Self {
84+ pub unsafe fn new_unchecked ( dimension : u32 , bit_width : u8 , num_rounds : u8 ) -> Self {
7885 Self {
7986 dimension,
8087 bit_width,
88+ num_rounds,
8189 }
8290 }
8391
@@ -115,6 +123,36 @@ impl TurboQuantData {
115123 "norms length must match codes length" ,
116124 ) ;
117125
126+ // Norms dtype must match the element ptype of the Vector, with the parent's nullability.
127+ // Norms carry the validity of the entire TurboQuant array.
128+ let element_ptype = vector_metadata. element_ptype ( ) ;
129+ let expected_norms_dtype = DType :: Primitive ( element_ptype, dtype. nullability ( ) ) ;
130+ vortex_ensure_eq ! (
131+ * norms. dtype( ) ,
132+ expected_norms_dtype,
133+ "norms dtype does not match expected {expected_norms_dtype}" ,
134+ ) ;
135+
136+ // Centroids are always f32 regardless of element type.
137+ let centroids_dtype = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
138+ vortex_ensure_eq ! (
139+ * centroids. dtype( ) ,
140+ centroids_dtype,
141+ "centroids dtype must be non-nullable f32" ,
142+ ) ;
143+
144+ // Rotation signs must be a FixedSizeList<u8> with list_size == padded_dim. The FSL length
145+ // is the number of rotation rounds.
146+ let expected_signs_dtype = DType :: FixedSizeList (
147+ Arc :: new ( DType :: Primitive ( PType :: U8 , Nullability :: NonNullable ) ) ,
148+ padded_dim,
149+ Nullability :: NonNullable ,
150+ ) ;
151+ vortex_ensure_eq ! (
152+ * rotation_signs. dtype( ) ,
153+ expected_signs_dtype,
154+ "rotation_signs dtype does not match expected {expected_signs_dtype}" ,
155+ ) ;
118156 // Degenerate (empty) case: all children must be empty, and bit_width is 0.
119157 if num_rows == 0 {
120158 vortex_ensure ! (
@@ -130,6 +168,11 @@ impl TurboQuantData {
130168 return Ok ( ( ) ) ;
131169 }
132170
171+ vortex_ensure ! (
172+ !rotation_signs. is_empty( ) ,
173+ "rotation_signs must have at least 1 round"
174+ ) ;
175+
133176 // Non-degenerate: derive and validate bit_width from centroids.
134177 let num_centroids = centroids. len ( ) ;
135178 vortex_ensure ! (
@@ -150,31 +193,6 @@ impl TurboQuantData {
150193 TurboQuant :: MAX_BIT_WIDTH
151194 ) ;
152195
153- // Norms dtype must match the element ptype of the Vector, with the parent's nullability.
154- // Norms carry the validity of the entire TurboQuant array.
155- let element_ptype = vector_metadata. element_ptype ( ) ;
156- let expected_norms_dtype = DType :: Primitive ( element_ptype, dtype. nullability ( ) ) ;
157- vortex_ensure_eq ! (
158- * norms. dtype( ) ,
159- expected_norms_dtype,
160- "norms dtype does not match expected {expected_norms_dtype}" ,
161- ) ;
162-
163- // Centroids are always f32 regardless of element type.
164- let centroids_dtype = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
165- vortex_ensure_eq ! (
166- * centroids. dtype( ) ,
167- centroids_dtype,
168- "centroids dtype must be non-nullable f32" ,
169- ) ;
170-
171- // Rotation signs count must be 3 * padded_dim.
172- vortex_ensure_eq ! (
173- rotation_signs. len( ) ,
174- 3 * padded_dim as usize ,
175- "rotation_signs length does not match expected 3 * {padded_dim}" ,
176- ) ;
177-
178196 Ok ( ( ) )
179197 }
180198
@@ -203,6 +221,11 @@ impl TurboQuantData {
203221 self . bit_width
204222 }
205223
224+ /// The number of sign-diagonal + WHT rounds in the structured rotation.
225+ pub fn num_rounds ( & self ) -> u8 {
226+ self . num_rounds
227+ }
228+
206229 /// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
207230 ///
208231 /// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
@@ -213,18 +236,6 @@ impl TurboQuantData {
213236}
214237
215238pub trait TurboQuantArrayExt : TypedArrayRef < TurboQuant > {
216- fn dimension ( & self ) -> u32 {
217- std:: ops:: Deref :: deref ( self ) . dimension ( )
218- }
219-
220- fn bit_width ( & self ) -> u8 {
221- std:: ops:: Deref :: deref ( self ) . bit_width ( )
222- }
223-
224- fn padded_dim ( & self ) -> u32 {
225- std:: ops:: Deref :: deref ( self ) . padded_dim ( )
226- }
227-
228239 fn codes ( & self ) -> & ArrayRef {
229240 self . as_ref ( ) . slots ( ) [ Slot :: Codes as usize ]
230241 . as_ref ( )
0 commit comments