@@ -36,6 +36,11 @@ pub struct TurboQuantData {
3636 ///
3737 /// This is 0 for degenerate empty arrays.
3838 pub ( crate ) bit_width : u8 ,
39+
40+ /// The number of sign-diagonal + WHT rounds in the structured rotation.
41+ ///
42+ /// This is 0 for degenerate empty arrays.
43+ pub ( crate ) num_rounds : u8 ,
3944}
4045
4146impl TurboQuantData {
@@ -46,7 +51,7 @@ impl TurboQuantData {
4651 /// Returns an error if:
4752 /// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
4853 /// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
49- pub fn try_new ( dimension : u32 , bit_width : u8 ) -> VortexResult < Self > {
54+ pub fn try_new ( dimension : u32 , bit_width : u8 , num_rounds : u8 ) -> VortexResult < Self > {
5055 vortex_ensure ! (
5156 dimension >= TurboQuant :: MIN_DIMENSION ,
5257 "TurboQuant requires dimension >= {}, got {dimension}" ,
@@ -61,6 +66,7 @@ impl TurboQuantData {
6166 Ok ( Self {
6267 dimension,
6368 bit_width,
69+ num_rounds,
6470 } )
6571 }
6672
@@ -72,12 +78,14 @@ impl TurboQuantData {
7278 ///
7379 /// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
7480 /// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
81+ /// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays).
7582 ///
7683 /// Violating these invariants may produce incorrect results during decompression.
77- pub unsafe fn new_unchecked ( dimension : u32 , bit_width : u8 ) -> Self {
84+ pub unsafe fn new_unchecked ( dimension : u32 , bit_width : u8 , num_rounds : u8 ) -> Self {
7885 Self {
7986 dimension,
8087 bit_width,
88+ num_rounds,
8189 }
8290 }
8391
@@ -168,11 +176,21 @@ impl TurboQuantData {
168176 "centroids dtype must be non-nullable f32" ,
169177 ) ;
170178
171- // Rotation signs count must be 3 * padded_dim.
179+ // Rotation signs must be a FixedSizeList<u8> with list_size == padded_dim. The FSL length
180+ // is the number of rotation rounds.
181+ let expected_signs_dtype = DType :: FixedSizeList (
182+ Arc :: new ( DType :: Primitive ( PType :: U8 , Nullability :: NonNullable ) ) ,
183+ padded_dim,
184+ Nullability :: NonNullable ,
185+ ) ;
172186 vortex_ensure_eq ! (
173- rotation_signs. len( ) ,
174- 3 * padded_dim as usize ,
175- "rotation_signs length does not match expected 3 * {padded_dim}" ,
187+ * rotation_signs. dtype( ) ,
188+ expected_signs_dtype,
189+ "rotation_signs dtype does not match expected {expected_signs_dtype}" ,
190+ ) ;
191+ vortex_ensure ! (
192+ !rotation_signs. is_empty( ) ,
193+ "rotation_signs must have at least 1 round"
176194 ) ;
177195
178196 Ok ( ( ) )
@@ -203,6 +221,11 @@ impl TurboQuantData {
203221 self . bit_width
204222 }
205223
224+ /// The number of sign-diagonal + WHT rounds in the structured rotation.
225+ pub fn num_rounds ( & self ) -> u8 {
226+ self . num_rounds
227+ }
228+
206229 /// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
207230 ///
208231 /// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
@@ -213,18 +236,6 @@ impl TurboQuantData {
213236}
214237
215238pub trait TurboQuantArrayExt : TypedArrayRef < TurboQuant > {
216- fn dimension ( & self ) -> u32 {
217- std:: ops:: Deref :: deref ( self ) . dimension ( )
218- }
219-
220- fn bit_width ( & self ) -> u8 {
221- std:: ops:: Deref :: deref ( self ) . bit_width ( )
222- }
223-
224- fn padded_dim ( & self ) -> u32 {
225- std:: ops:: Deref :: deref ( self ) . padded_dim ( )
226- }
227-
228239 fn codes ( & self ) -> & ArrayRef {
229240 self . as_ref ( ) . slots ( ) [ Slot :: Codes as usize ]
230241 . as_ref ( )
0 commit comments