Skip to content

Commit 8d55015

Browse files
authored
TurboQuant cleanup part 2 (#7326)
## Summary Tracking issue: #7297 Adds named limit constants, removes the old double-quantized similarity path, and updates the docs to better match the current Stage 1 implementation. It also replaces all remaining `to_canonical` usage in `vortex-tensor` with explicit `execute(...)` calls and tightens `cast_possible_truncation` handling by using targeted `expect` annotations or checked conversions instead of broad `allow`s. ## API Changes Added 2 constants. ## Testing NA --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent be65147 commit 8d55015

File tree

16 files changed

+204
-320
lines changed

16 files changed

+204
-320
lines changed

vortex-tensor/public-api.lock

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ impl vortex_tensor::encodings::turboquant::TurboQuant
1010

1111
pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId
1212

13+
pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_BIT_WIDTH: u8
14+
15+
pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_CENTROIDS: usize
16+
1317
pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32
1418

1519
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
@@ -412,7 +416,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt
412416

413417
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
414418

415-
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
419+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
416420

417421
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
418422

vortex-tensor/src/encodings/turboquant/array/centroids.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(Da
3737
/// `dimension`-dimensional space.
3838
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
3939
vortex_ensure!(
40-
(1..=8).contains(&bit_width),
41-
"TurboQuant bit_width must be 1-8, got {bit_width}"
40+
(1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
41+
"TurboQuant bit_width must be 1-{}, got {bit_width}",
42+
TurboQuant::MAX_BIT_WIDTH
4243
);
4344
vortex_ensure!(
4445
dimension >= TurboQuant::MIN_DIMENSION,
@@ -91,7 +92,7 @@ impl HalfIntExponent {
9192
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
9293
/// where `C_d` is the normalizing constant.
9394
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
94-
debug_assert!((1..=8).contains(&bit_width));
95+
debug_assert!((1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width));
9596
let num_centroids = 1usize << bit_width;
9697

9798
// For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
@@ -220,7 +221,6 @@ pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
220221
}
221222

222223
#[cfg(test)]
223-
#[allow(clippy::cast_possible_truncation)]
224224
mod tests {
225225
use rstest::rstest;
226226
use vortex_error::VortexResult;
@@ -311,9 +311,11 @@ mod tests {
311311
let boundaries = compute_centroid_boundaries(&centroids);
312312
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
313313

314+
#[expect(clippy::cast_possible_truncation)]
314315
let last_idx = (centroids.len() - 1) as u8;
315316
assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx);
316317
for (idx, &cv) in centroids.iter().enumerate() {
318+
#[expect(clippy::cast_possible_truncation)]
317319
let expected = idx as u8;
318320
assert_eq!(find_nearest_centroid(cv, &boundaries), expected);
319321
}

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::encodings::turboquant::vtable::TurboQuant;
2020
///
2121
/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector)
2222
/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared
23-
/// codebook centroids and SRHT rotation signs.
23+
/// codebook centroids and the parameters of the current structured rotation.
2424
///
2525
/// See the [module docs](crate::encodings::turboquant) for algorithmic details.
2626
///
@@ -45,16 +45,17 @@ impl TurboQuantData {
4545
///
4646
/// Returns an error if:
4747
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
48-
/// - `bit_width` is greater than 8.
48+
/// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
4949
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
5050
vortex_ensure!(
5151
dimension >= TurboQuant::MIN_DIMENSION,
5252
"TurboQuant requires dimension >= {}, got {dimension}",
5353
TurboQuant::MIN_DIMENSION
5454
);
5555
vortex_ensure!(
56-
bit_width <= 8,
57-
"bit_width is expected to be between 0 and 8, got {bit_width}"
56+
bit_width <= TurboQuant::MAX_BIT_WIDTH,
57+
"bit_width is expected to be between 0 and {}, got {bit_width}",
58+
TurboQuant::MAX_BIT_WIDTH
5859
);
5960

6061
Ok(Self {
@@ -70,7 +71,7 @@ impl TurboQuantData {
7071
/// The caller must ensure:
7172
///
7273
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
73-
/// - `bit_width` is in the range `[0, 8]`.
74+
/// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
7475
///
7576
/// Violating these invariants may produce incorrect results during decompression.
7677
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
@@ -132,16 +133,21 @@ impl TurboQuantData {
132133
// Non-degenerate: derive and validate bit_width from centroids.
133134
let num_centroids = centroids.len();
134135
vortex_ensure!(
135-
num_centroids.is_power_of_two() && (2..=256).contains(&num_centroids),
136-
"centroids length must be a power of 2 in [2, 256], got {num_centroids}"
136+
num_centroids.is_power_of_two()
137+
&& (2..=TurboQuant::MAX_CENTROIDS).contains(&num_centroids),
138+
"centroids length must be a power of 2 in [2, {}], got {num_centroids}",
139+
TurboQuant::MAX_CENTROIDS
137140
);
138141

139-
// Guaranteed to be 1-8 by the preceding power-of-2 and range checks.
140-
#[expect(clippy::cast_possible_truncation)]
142+
#[expect(
143+
clippy::cast_possible_truncation,
144+
reason = "Guaranteed to be [1,8] by the preceding power-of-2 and range checks."
145+
)]
141146
let bit_width = num_centroids.trailing_zeros() as u8;
142147
vortex_ensure!(
143-
(1..=8).contains(&bit_width),
144-
"derived bit_width must be 1-8, got {bit_width}"
148+
(1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
149+
"derived bit_width must be 1-{}, got {bit_width}",
150+
TurboQuant::MAX_BIT_WIDTH
145151
);
146152

147153
// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
@@ -192,15 +198,15 @@ impl TurboQuantData {
192198
self.dimension
193199
}
194200

195-
/// MSE bits per coordinate (1-8 for non-empty arrays, 0 for degenerate empty arrays).
201+
/// MSE bits per coordinate (1-MAX_BIT_WIDTH for non-empty arrays, 0 for degenerate empty arrays).
196202
pub fn bit_width(&self) -> u8 {
197203
self.bit_width
198204
}
199205

200206
/// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
201207
///
202-
/// The SRHT rotation requires power-of-2 input, so non-power-of-2 dimensions are
203-
/// zero-padded to this value.
208+
/// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
209+
/// non-power-of-2 dimensions are zero-padded to this value.
204210
pub fn padded_dim(&self) -> u32 {
205211
self.dimension.next_power_of_two()
206212
}

vortex-tensor/src/encodings/turboquant/array/rotation.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
//! Deterministic random rotation for TurboQuant.
55
//!
6-
//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation
7-
//! instead of a full d×d matrix multiply. The SRHT applies the sequence
8-
//! D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard Transform (WHT) and Dₖ are
9-
//! random diagonal ±1 sign matrices. Three rounds of HD provide sufficient
10-
//! randomness for near-uniform distribution on the sphere.
6+
//! The TurboQuant paper analyzes a full random orthogonal rotation. The current implementation
7+
//! uses a cheaper structured Walsh-Hadamard-based surrogate instead of a dense d x d matrix.
8+
//!
9+
//! Concretely, this applies three rounds of random sign diagonals interleaved with the
10+
//! Walsh-Hadamard Transform: D3 * H * D2 * H * D1 * H, followed by normalization. This is a
11+
//! SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d)
12+
//! encode/decode cost and compact serialized parameters.
1113
//!
1214
//! For dimensions that are not powers of 2, the input is zero-padded to the
1315
//! next power of 2 before the transform and truncated afterward.
@@ -28,7 +30,7 @@ use vortex_error::vortex_ensure;
2830
/// IEEE 754 sign bit mask for f32.
2931
const F32_SIGN_BIT: u32 = 0x8000_0000;
3032

31-
/// A structured random Hadamard transform for O(d log d) pseudo-random rotation.
33+
/// A Walsh-Hadamard-based structured surrogate for a random orthogonal rotation.
3234
pub struct RotationMatrix {
3335
/// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`.
3436
/// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
@@ -40,7 +42,7 @@ pub struct RotationMatrix {
4042
}
4143

4244
impl RotationMatrix {
43-
/// Create a new SRHT rotation from a deterministic seed.
45+
/// Create a new structured Walsh-Hadamard-based rotation from a deterministic seed.
4446
pub fn try_new(seed: u64, dimension: usize) -> VortexResult<Self> {
4547
let padded_dim = dimension.next_power_of_two();
4648
let mut rng = StdRng::seed_from_u64(seed);
@@ -55,7 +57,7 @@ impl RotationMatrix {
5557
})
5658
}
5759

58-
/// Apply forward rotation: `output = SRHT(input)`.
60+
/// Apply forward rotation: `output = R(input)`.
5961
///
6062
/// Both `input` and `output` must have length `padded_dim()`. The caller
6163
/// is responsible for zero-padding input beyond `dim` positions.
@@ -67,7 +69,7 @@ impl RotationMatrix {
6769
self.apply_srht(output);
6870
}
6971

70-
/// Apply inverse rotation: `output = SRHT⁻¹(input)`.
72+
/// Apply inverse rotation: `output = R⁻¹(input)`.
7173
///
7274
/// Both `input` and `output` must have length `padded_dim()`.
7375
pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
@@ -85,7 +87,7 @@ impl RotationMatrix {
8587
self.padded_dim
8688
}
8789

88-
/// Apply the SRHT: D₃ · H · D₂ · H · D₁ · x, with normalization.
90+
/// Apply the structured rotation: `D₃ · H · D₂ · H · D₁ · H · x`, with normalization.
8991
fn apply_srht(&self, buf: &mut [f32]) {
9092
apply_signs_xor(buf, &self.sign_masks[0]);
9193
walsh_hadamard_transform(buf);
@@ -100,7 +102,7 @@ impl RotationMatrix {
100102
buf.iter_mut().for_each(|val| *val *= norm);
101103
}
102104

103-
/// Apply the inverse SRHT.
105+
/// Apply the inverse structured rotation.
104106
///
105107
/// Forward is: norm · H · D₃ · H · D₂ · H · D₁
106108
/// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H

vortex-tensor/src/encodings/turboquant/array/scheme.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ fn estimate_compression_ratio(bits_per_element: u8, dimensions: u32, num_vectors
9898
// Shared overhead: codebook centroids (2^bit_width f32 values) and
9999
// rotation signs (3 * padded_dim bits).
100100
let num_centroids = 1usize << config.bit_width;
101+
debug_assert!(num_centroids <= TurboQuant::MAX_CENTROIDS);
101102
let overhead_bits = num_centroids * 32 // centroids are always f32
102103
+ 3 * padded_dim; // rotation signs, 1 bit each
103104

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ pub struct TurboQuantConfig {
4646
impl Default for TurboQuantConfig {
4747
fn default() -> Self {
4848
Self {
49-
bit_width: 8,
49+
bit_width: TurboQuant::MAX_BIT_WIDTH,
5050
seed: Some(42),
5151
}
5252
}
@@ -55,10 +55,12 @@ impl Default for TurboQuantConfig {
5555
/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray for quantization.
5656
///
5757
/// All quantization (rotation, centroid lookup) happens in f32. f16 is upcast; f64 is truncated.
58-
#[allow(clippy::cast_possible_truncation)]
59-
fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<PrimitiveArray> {
58+
fn extract_f32_elements(
59+
fsl: &FixedSizeListArray,
60+
ctx: &mut ExecutionCtx,
61+
) -> VortexResult<PrimitiveArray> {
6062
let elements = fsl.elements();
61-
let primitive = elements.to_canonical()?.into_primitive();
63+
let primitive = elements.clone().execute::<PrimitiveArray>(ctx)?;
6264
let ptype = primitive.ptype();
6365

6466
match ptype {
@@ -71,7 +73,14 @@ fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<PrimitiveArray
7173
PType::F64 => Ok(primitive
7274
.as_slice::<f64>()
7375
.iter()
74-
.map(|&v| v as f32)
76+
.map(|&v| {
77+
#[expect(
78+
clippy::cast_possible_truncation,
79+
reason = "TurboQuant quantization operates in f32, so f64 inputs are intentionally downcast"
80+
)]
81+
let v = v as f32;
82+
v
83+
})
7584
.collect()),
7685
_ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"),
7786
}
@@ -94,7 +103,6 @@ struct QuantizationResult {
94103
/// Norms are computed in the native element precision via the [`L2Norm`] scalar function.
95104
/// The rotation and centroid lookup happen in f32. Null rows (per the input validity) produce
96105
/// all-zero codes.
97-
#[allow(clippy::cast_possible_truncation)]
98106
fn turboquant_quantize_core(
99107
ext: ArrayView<Extension>,
100108
fsl: &FixedSizeListArray,
@@ -103,7 +111,8 @@ fn turboquant_quantize_core(
103111
validity: &Validity,
104112
ctx: &mut ExecutionCtx,
105113
) -> VortexResult<QuantizationResult> {
106-
let dimension = fsl.list_size() as usize;
114+
let dimension =
115+
usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
107116
let num_rows = fsl.len();
108117

109118
// Compute native-precision norms via the L2Norm scalar fn. L2Norm propagates validity from
@@ -127,10 +136,12 @@ fn turboquant_quantize_core(
127136

128137
let rotation = RotationMatrix::try_new(seed, dimension)?;
129138
let padded_dim = rotation.padded_dim();
139+
let padded_dim_u32 =
140+
u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
130141

131-
let f32_elements = extract_f32_elements(fsl)?;
142+
let f32_elements = extract_f32_elements(fsl, ctx)?;
132143

133-
let centroids = get_centroids(padded_dim as u32, bit_width)?;
144+
let centroids = get_centroids(padded_dim_u32, bit_width)?;
134145
let boundaries = compute_centroid_boundaries(&centroids);
135146

136147
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
@@ -173,19 +184,20 @@ fn turboquant_quantize_core(
173184
}
174185

175186
/// Build a `TurboQuantArray` from quantization results.
176-
#[allow(clippy::cast_possible_truncation)]
177187
fn build_turboquant(
178188
fsl: &FixedSizeListArray,
179189
core: QuantizationResult,
180190
ext_dtype: DType,
181191
) -> VortexResult<TurboQuantArray> {
182192
let num_rows = fsl.len();
183193
let padded_dim = core.padded_dim;
194+
let padded_dim_u32 =
195+
u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
184196
let codes_elements =
185197
PrimitiveArray::new::<u8>(core.all_indices.freeze(), Validity::NonNullable);
186198
let codes = FixedSizeListArray::try_new(
187199
codes_elements.into_array(),
188-
padded_dim as u32,
200+
padded_dim_u32,
189201
Validity::NonNullable,
190202
num_rows,
191203
)?
@@ -220,11 +232,12 @@ pub fn turboquant_encode(
220232
) -> VortexResult<ArrayRef> {
221233
let ext_dtype = ext.dtype().clone();
222234
let storage = ext.storage_array();
223-
let fsl = storage.to_canonical()?.into_fixed_size_list();
235+
let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
224236

225237
vortex_ensure!(
226-
config.bit_width >= 1 && config.bit_width <= 8,
227-
"bit_width must be 1-8, got {}",
238+
config.bit_width >= 1 && config.bit_width <= TurboQuant::MAX_BIT_WIDTH,
239+
"bit_width must be 1-{}, got {}",
240+
TurboQuant::MAX_BIT_WIDTH,
228241
config.bit_width
229242
);
230243
let dimension = fsl.list_size();

0 commit comments

Comments
 (0)