Skip to content

Commit 2e00050

Browse files
authored
Dynamic WHT rounds in TurboQuant (#7330)
## Summary Tracking issue: #7297 Adds support for a dynamic number of FWHT rounds in the SORF algorithm. Previously, this was hardcoded to 3. Also fixes a small validation bug. ## Testing Existing tests suffice; this change also adds some regression tests for the validation bug. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent f87eefa commit 2e00050

File tree

11 files changed

+427
-190
lines changed

11 files changed

+427
-190
lines changed

vortex-tensor/public-api.lock

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ pub struct vortex_tensor::encodings::turboquant::TurboQuantConfig
7878

7979
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8
8080

81+
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::num_rounds: u8
82+
8183
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option<u64>
8284

8385
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig
@@ -100,11 +102,13 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::bit_width(&self) ->
100102

101103
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32
102104

103-
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8) -> Self
105+
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self
106+
107+
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::num_rounds(&self) -> u8
104108

105109
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32
106110

107-
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8) -> vortex_error::VortexResult<Self>
111+
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> vortex_error::VortexResult<Self>
108112

109113
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()>
110114

@@ -156,34 +160,22 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self
156160

157161
pub trait vortex_tensor::encodings::turboquant::TurboQuantArrayExt: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>
158162

159-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::bit_width(&self) -> u8
160-
161163
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::centroids(&self) -> &vortex_array::array::erased::ArrayRef
162164

163165
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::codes(&self) -> &vortex_array::array::erased::ArrayRef
164166

165-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::dimension(&self) -> u32
166-
167167
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::norms(&self) -> &vortex_array::array::erased::ArrayRef
168168

169-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::padded_dim(&self) -> u32
170-
171169
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
172170

173171
impl<T: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>> vortex_tensor::encodings::turboquant::TurboQuantArrayExt for T
174172

175-
pub fn T::bit_width(&self) -> u8
176-
177173
pub fn T::centroids(&self) -> &vortex_array::array::erased::ArrayRef
178174

179175
pub fn T::codes(&self) -> &vortex_array::array::erased::ArrayRef
180176

181-
pub fn T::dimension(&self) -> u32
182-
183177
pub fn T::norms(&self) -> &vortex_array::array::erased::ArrayRef
184178

185-
pub fn T::padded_dim(&self) -> u32
186-
187179
pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
188180

189181
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ pub struct TurboQuantData {
3636
///
3737
/// This is 0 for degenerate empty arrays.
3838
pub(crate) bit_width: u8,
39+
40+
/// The number of sign-diagonal + WHT rounds in the structured rotation.
41+
///
42+
/// This is 0 for degenerate empty arrays.
43+
pub(crate) num_rounds: u8,
3944
}
4045

4146
impl TurboQuantData {
@@ -46,7 +51,7 @@ impl TurboQuantData {
4651
/// Returns an error if:
4752
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
4853
/// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
49-
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
54+
pub fn try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> VortexResult<Self> {
5055
vortex_ensure!(
5156
dimension >= TurboQuant::MIN_DIMENSION,
5257
"TurboQuant requires dimension >= {}, got {dimension}",
@@ -61,6 +66,7 @@ impl TurboQuantData {
6166
Ok(Self {
6267
dimension,
6368
bit_width,
69+
num_rounds,
6470
})
6571
}
6672

@@ -72,12 +78,14 @@ impl TurboQuantData {
7278
///
7379
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
7480
/// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
81+
/// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays).
7582
///
7683
/// Violating these invariants may produce incorrect results during decompression.
77-
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
84+
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self {
7885
Self {
7986
dimension,
8087
bit_width,
88+
num_rounds,
8189
}
8290
}
8391

@@ -115,6 +123,36 @@ impl TurboQuantData {
115123
"norms length must match codes length",
116124
);
117125

126+
// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
127+
// Norms carry the validity of the entire TurboQuant array.
128+
let element_ptype = vector_metadata.element_ptype();
129+
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
130+
vortex_ensure_eq!(
131+
*norms.dtype(),
132+
expected_norms_dtype,
133+
"norms dtype does not match expected {expected_norms_dtype}",
134+
);
135+
136+
// Centroids are always f32 regardless of element type.
137+
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
138+
vortex_ensure_eq!(
139+
*centroids.dtype(),
140+
centroids_dtype,
141+
"centroids dtype must be non-nullable f32",
142+
);
143+
144+
// Rotation signs must be a FixedSizeList<u8> with list_size == padded_dim. The FSL length
145+
// is the number of rotation rounds.
146+
let expected_signs_dtype = DType::FixedSizeList(
147+
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
148+
padded_dim,
149+
Nullability::NonNullable,
150+
);
151+
vortex_ensure_eq!(
152+
*rotation_signs.dtype(),
153+
expected_signs_dtype,
154+
"rotation_signs dtype does not match expected {expected_signs_dtype}",
155+
);
118156
// Degenerate (empty) case: all children must be empty, and bit_width is 0.
119157
if num_rows == 0 {
120158
vortex_ensure!(
@@ -130,6 +168,11 @@ impl TurboQuantData {
130168
return Ok(());
131169
}
132170

171+
vortex_ensure!(
172+
!rotation_signs.is_empty(),
173+
"rotation_signs must have at least 1 round"
174+
);
175+
133176
// Non-degenerate: derive and validate bit_width from centroids.
134177
let num_centroids = centroids.len();
135178
vortex_ensure!(
@@ -150,31 +193,6 @@ impl TurboQuantData {
150193
TurboQuant::MAX_BIT_WIDTH
151194
);
152195

153-
// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
154-
// Norms carry the validity of the entire TurboQuant array.
155-
let element_ptype = vector_metadata.element_ptype();
156-
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
157-
vortex_ensure_eq!(
158-
*norms.dtype(),
159-
expected_norms_dtype,
160-
"norms dtype does not match expected {expected_norms_dtype}",
161-
);
162-
163-
// Centroids are always f32 regardless of element type.
164-
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
165-
vortex_ensure_eq!(
166-
*centroids.dtype(),
167-
centroids_dtype,
168-
"centroids dtype must be non-nullable f32",
169-
);
170-
171-
// Rotation signs count must be 3 * padded_dim.
172-
vortex_ensure_eq!(
173-
rotation_signs.len(),
174-
3 * padded_dim as usize,
175-
"rotation_signs length does not match expected 3 * {padded_dim}",
176-
);
177-
178196
Ok(())
179197
}
180198

@@ -203,6 +221,11 @@ impl TurboQuantData {
203221
self.bit_width
204222
}
205223

224+
/// The number of sign-diagonal + WHT rounds in the structured rotation.
225+
pub fn num_rounds(&self) -> u8 {
226+
self.num_rounds
227+
}
228+
206229
/// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
207230
///
208231
/// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
@@ -213,18 +236,6 @@ impl TurboQuantData {
213236
}
214237

215238
pub trait TurboQuantArrayExt: TypedArrayRef<TurboQuant> {
216-
fn dimension(&self) -> u32 {
217-
std::ops::Deref::deref(self).dimension()
218-
}
219-
220-
fn bit_width(&self) -> u8 {
221-
std::ops::Deref::deref(self).bit_width()
222-
}
223-
224-
fn padded_dim(&self) -> u32 {
225-
std::ops::Deref::deref(self).padded_dim()
226-
}
227-
228239
fn codes(&self) -> &ArrayRef {
229240
self.as_ref().slots()[Slot::Codes as usize]
230241
.as_ref()

0 commit comments

Comments
 (0)