Skip to content

Commit ab4f0c6

Browse files
committed
allow dynamic WHT rounds
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent ec28916 commit ab4f0c6

11 files changed

Lines changed: 299 additions & 145 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 6 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,8 @@ pub struct vortex_tensor::encodings::turboquant::TurboQuantConfig
7878

7979
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8
8080

81+
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::num_rounds: u8
82+
8183
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option<u64>
8284

8385
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig
@@ -100,11 +102,13 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::bit_width(&self) ->
100102

101103
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32
102104

103-
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8) -> Self
105+
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self
106+
107+
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::num_rounds(&self) -> u8
104108

105109
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32
106110

107-
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8) -> vortex_error::VortexResult<Self>
111+
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> vortex_error::VortexResult<Self>
108112

109113
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()>
110114

@@ -156,34 +160,22 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self
156160

157161
pub trait vortex_tensor::encodings::turboquant::TurboQuantArrayExt: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>
158162

159-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::bit_width(&self) -> u8
160-
161163
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::centroids(&self) -> &vortex_array::array::erased::ArrayRef
162164

163165
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::codes(&self) -> &vortex_array::array::erased::ArrayRef
164166

165-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::dimension(&self) -> u32
166-
167167
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::norms(&self) -> &vortex_array::array::erased::ArrayRef
168168

169-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::padded_dim(&self) -> u32
170-
171169
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
172170

173171
impl<T: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>> vortex_tensor::encodings::turboquant::TurboQuantArrayExt for T
174172

175-
pub fn T::bit_width(&self) -> u8
176-
177173
pub fn T::centroids(&self) -> &vortex_array::array::erased::ArrayRef
178174

179175
pub fn T::codes(&self) -> &vortex_array::array::erased::ArrayRef
180176

181-
pub fn T::dimension(&self) -> u32
182-
183177
pub fn T::norms(&self) -> &vortex_array::array::erased::ArrayRef
184178

185-
pub fn T::padded_dim(&self) -> u32
186-
187179
pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
188180

189181
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 29 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,11 @@ pub struct TurboQuantData {
3636
///
3737
/// This is 0 for degenerate empty arrays.
3838
pub(crate) bit_width: u8,
39+
40+
/// The number of sign-diagonal + WHT rounds in the structured rotation.
41+
///
42+
/// This is 0 for degenerate empty arrays.
43+
pub(crate) num_rounds: u8,
3944
}
4045

4146
impl TurboQuantData {
@@ -46,7 +51,7 @@ impl TurboQuantData {
4651
/// Returns an error if:
4752
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
4853
/// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
49-
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
54+
pub fn try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> VortexResult<Self> {
5055
vortex_ensure!(
5156
dimension >= TurboQuant::MIN_DIMENSION,
5257
"TurboQuant requires dimension >= {}, got {dimension}",
@@ -61,6 +66,7 @@ impl TurboQuantData {
6166
Ok(Self {
6267
dimension,
6368
bit_width,
69+
num_rounds,
6470
})
6571
}
6672

@@ -72,12 +78,14 @@ impl TurboQuantData {
7278
///
7379
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
7480
/// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
81+
/// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays).
7582
///
7683
/// Violating these invariants may produce incorrect results during decompression.
77-
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
84+
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self {
7885
Self {
7986
dimension,
8087
bit_width,
88+
num_rounds,
8189
}
8290
}
8391

@@ -168,11 +176,21 @@ impl TurboQuantData {
168176
"centroids dtype must be non-nullable f32",
169177
);
170178

171-
// Rotation signs count must be 3 * padded_dim.
179+
// Rotation signs must be a FixedSizeList<u8> with list_size == padded_dim. The FSL length
180+
// is the number of rotation rounds.
181+
let expected_signs_dtype = DType::FixedSizeList(
182+
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
183+
padded_dim,
184+
Nullability::NonNullable,
185+
);
172186
vortex_ensure_eq!(
173-
rotation_signs.len(),
174-
3 * padded_dim as usize,
175-
"rotation_signs length does not match expected 3 * {padded_dim}",
187+
*rotation_signs.dtype(),
188+
expected_signs_dtype,
189+
"rotation_signs dtype does not match expected {expected_signs_dtype}",
190+
);
191+
vortex_ensure!(
192+
!rotation_signs.is_empty(),
193+
"rotation_signs must have at least 1 round"
176194
);
177195

178196
Ok(())
@@ -203,6 +221,11 @@ impl TurboQuantData {
203221
self.bit_width
204222
}
205223

224+
/// The number of sign-diagonal + WHT rounds in the structured rotation.
225+
pub fn num_rounds(&self) -> u8 {
226+
self.num_rounds
227+
}
228+
206229
/// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
207230
///
208231
/// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
@@ -213,18 +236,6 @@ impl TurboQuantData {
213236
}
214237

215238
pub trait TurboQuantArrayExt: TypedArrayRef<TurboQuant> {
216-
fn dimension(&self) -> u32 {
217-
std::ops::Deref::deref(self).dimension()
218-
}
219-
220-
fn bit_width(&self) -> u8 {
221-
std::ops::Deref::deref(self).bit_width()
222-
}
223-
224-
fn padded_dim(&self) -> u32 {
225-
std::ops::Deref::deref(self).padded_dim()
226-
}
227-
228239
fn codes(&self) -> &ArrayRef {
229240
self.as_ref().slots()[Slot::Codes as usize]
230241
.as_ref()

0 commit comments

Comments
 (0)