Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ impl vortex_tensor::encodings::turboquant::TurboQuant

pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId

pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_BIT_WIDTH: u8

pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_CENTROIDS: usize

pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32

pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
Expand Down Expand Up @@ -412,7 +416,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

Expand Down
10 changes: 6 additions & 4 deletions vortex-tensor/src/encodings/turboquant/array/centroids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(Da
/// `dimension`-dimensional space.
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
vortex_ensure!(
(1..=8).contains(&bit_width),
"TurboQuant bit_width must be 1-8, got {bit_width}"
(1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
"TurboQuant bit_width must be 1-{}, got {bit_width}",
TurboQuant::MAX_BIT_WIDTH
);
vortex_ensure!(
dimension >= TurboQuant::MIN_DIMENSION,
Expand Down Expand Up @@ -91,7 +92,7 @@ impl HalfIntExponent {
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
/// where `C_d` is the normalizing constant.
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
debug_assert!((1..=8).contains(&bit_width));
debug_assert!((1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width));
let num_centroids = 1usize << bit_width;

// For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
Expand Down Expand Up @@ -220,7 +221,6 @@ pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
}

#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
use rstest::rstest;
use vortex_error::VortexResult;
Expand Down Expand Up @@ -311,9 +311,11 @@ mod tests {
let boundaries = compute_centroid_boundaries(&centroids);
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);

#[expect(clippy::cast_possible_truncation)]
let last_idx = (centroids.len() - 1) as u8;
assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx);
for (idx, &cv) in centroids.iter().enumerate() {
#[expect(clippy::cast_possible_truncation)]
let expected = idx as u8;
assert_eq!(find_nearest_centroid(cv, &boundaries), expected);
}
Expand Down
34 changes: 20 additions & 14 deletions vortex-tensor/src/encodings/turboquant/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::encodings::turboquant::vtable::TurboQuant;
///
/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector)
/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared
/// codebook centroids and SRHT rotation signs.
/// codebook centroids and the parameters of the current structured rotation.
///
/// See the [module docs](crate::encodings::turboquant) for algorithmic details.
///
Expand All @@ -45,16 +45,17 @@ impl TurboQuantData {
///
/// Returns an error if:
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
/// - `bit_width` is greater than 8.
/// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
vortex_ensure!(
dimension >= TurboQuant::MIN_DIMENSION,
"TurboQuant requires dimension >= {}, got {dimension}",
TurboQuant::MIN_DIMENSION
);
vortex_ensure!(
bit_width <= 8,
"bit_width is expected to be between 0 and 8, got {bit_width}"
bit_width <= TurboQuant::MAX_BIT_WIDTH,
"bit_width is expected to be between 0 and {}, got {bit_width}",
TurboQuant::MAX_BIT_WIDTH
);

Ok(Self {
Expand All @@ -70,7 +71,7 @@ impl TurboQuantData {
/// The caller must ensure:
///
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
/// - `bit_width` is in the range `[0, 8]`.
/// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
///
/// Violating these invariants may produce incorrect results during decompression.
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
Expand Down Expand Up @@ -132,16 +133,21 @@ impl TurboQuantData {
// Non-degenerate: derive and validate bit_width from centroids.
let num_centroids = centroids.len();
vortex_ensure!(
num_centroids.is_power_of_two() && (2..=256).contains(&num_centroids),
"centroids length must be a power of 2 in [2, 256], got {num_centroids}"
num_centroids.is_power_of_two()
&& (2..=TurboQuant::MAX_CENTROIDS).contains(&num_centroids),
"centroids length must be a power of 2 in [2, {}], got {num_centroids}",
TurboQuant::MAX_CENTROIDS
);

// Guaranteed to be 1-8 by the preceding power-of-2 and range checks.
#[expect(clippy::cast_possible_truncation)]
#[expect(
clippy::cast_possible_truncation,
reason = "Guaranteed to be [1,8] by the preceding power-of-2 and range checks."
)]
let bit_width = num_centroids.trailing_zeros() as u8;
vortex_ensure!(
(1..=8).contains(&bit_width),
"derived bit_width must be 1-8, got {bit_width}"
(1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
"derived bit_width must be 1-{}, got {bit_width}",
TurboQuant::MAX_BIT_WIDTH
);

// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
Expand Down Expand Up @@ -192,15 +198,15 @@ impl TurboQuantData {
self.dimension
}

/// MSE bits per coordinate (1-8 for non-empty arrays, 0 for degenerate empty arrays).
/// MSE bits per coordinate (1-MAX_BIT_WIDTH for non-empty arrays, 0 for degenerate empty arrays).
pub fn bit_width(&self) -> u8 {
self.bit_width
}

/// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
///
/// The SRHT rotation requires power-of-2 input, so non-power-of-2 dimensions are
/// zero-padded to this value.
/// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
/// non-power-of-2 dimensions are zero-padded to this value.
pub fn padded_dim(&self) -> u32 {
self.dimension.next_power_of_two()
}
Expand Down
24 changes: 13 additions & 11 deletions vortex-tensor/src/encodings/turboquant/array/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

//! Deterministic random rotation for TurboQuant.
//!
//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation
//! instead of a full d×d matrix multiply. The SRHT applies the sequence
//! D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard Transform (WHT) and Dₖ are
//! random diagonal ±1 sign matrices. Three rounds of HD provide sufficient
//! randomness for near-uniform distribution on the sphere.
//! The TurboQuant paper analyzes a full random orthogonal rotation. The current implementation
//! uses a cheaper structured Walsh-Hadamard-based surrogate instead of a dense d × d matrix.
//!
//! Concretely, this applies three rounds of random sign diagonals interleaved with the
//! Walsh-Hadamard Transform: H * D3 * H * D2 * H * D1 (D1 applied first), followed by normalization. This is a
//! SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d)
//! encode/decode cost and compact serialized parameters.
//!
//! For dimensions that are not powers of 2, the input is zero-padded to the
//! next power of 2 before the transform and truncated afterward.
Expand All @@ -28,7 +30,7 @@ use vortex_error::vortex_ensure;
/// IEEE 754 sign bit mask for f32.
const F32_SIGN_BIT: u32 = 0x8000_0000;

/// A structured random Hadamard transform for O(d log d) pseudo-random rotation.
/// A Walsh-Hadamard-based structured surrogate for a random orthogonal rotation.
pub struct RotationMatrix {
/// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`.
/// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
Expand All @@ -40,7 +42,7 @@ pub struct RotationMatrix {
}

impl RotationMatrix {
/// Create a new SRHT rotation from a deterministic seed.
/// Create a new structured Walsh-Hadamard-based rotation from a deterministic seed.
pub fn try_new(seed: u64, dimension: usize) -> VortexResult<Self> {
let padded_dim = dimension.next_power_of_two();
let mut rng = StdRng::seed_from_u64(seed);
Expand All @@ -55,7 +57,7 @@ impl RotationMatrix {
})
}

/// Apply forward rotation: `output = SRHT(input)`.
/// Apply forward rotation: `output = R(input)`.
///
/// Both `input` and `output` must have length `padded_dim()`. The caller
/// is responsible for zero-padding input beyond `dim` positions.
Expand All @@ -67,7 +69,7 @@ impl RotationMatrix {
self.apply_srht(output);
}

/// Apply inverse rotation: `output = SRHT⁻¹(input)`.
/// Apply inverse rotation: `output = R⁻¹(input)`.
///
/// Both `input` and `output` must have length `padded_dim()`.
pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
Expand All @@ -85,7 +87,7 @@ impl RotationMatrix {
self.padded_dim
}

/// Apply the SRHT: D₃ · H · D₂ · H · D₁ · x, with normalization.
/// Apply the structured rotation: `D₃ · H · D₂ · H · D₁ · H · x`, with normalization.
fn apply_srht(&self, buf: &mut [f32]) {
apply_signs_xor(buf, &self.sign_masks[0]);
walsh_hadamard_transform(buf);
Expand All @@ -100,7 +102,7 @@ impl RotationMatrix {
buf.iter_mut().for_each(|val| *val *= norm);
}

/// Apply the inverse SRHT.
/// Apply the inverse structured rotation.
///
/// Forward is: norm · H · D₃ · H · D₂ · H · D₁
/// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H
Expand Down
1 change: 1 addition & 0 deletions vortex-tensor/src/encodings/turboquant/array/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ fn estimate_compression_ratio(bits_per_element: u8, dimensions: u32, num_vectors
// Shared overhead: codebook centroids (2^bit_width f32 values) and
// rotation signs (3 * padded_dim bits).
let num_centroids = 1usize << config.bit_width;
debug_assert!(num_centroids <= TurboQuant::MAX_CENTROIDS);
let overhead_bits = num_centroids * 32 // centroids are always f32
+ 3 * padded_dim; // rotation signs, 1 bit each

Expand Down
41 changes: 27 additions & 14 deletions vortex-tensor/src/encodings/turboquant/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct TurboQuantConfig {
impl Default for TurboQuantConfig {
fn default() -> Self {
Self {
bit_width: 8,
bit_width: TurboQuant::MAX_BIT_WIDTH,
seed: Some(42),
}
}
Expand All @@ -55,10 +55,12 @@ impl Default for TurboQuantConfig {
/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray for quantization.
///
/// All quantization (rotation, centroid lookup) happens in f32. f16 is upcast; f64 is truncated.
#[allow(clippy::cast_possible_truncation)]
fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<PrimitiveArray> {
fn extract_f32_elements(
fsl: &FixedSizeListArray,
ctx: &mut ExecutionCtx,
) -> VortexResult<PrimitiveArray> {
let elements = fsl.elements();
let primitive = elements.to_canonical()?.into_primitive();
let primitive = elements.clone().execute::<PrimitiveArray>(ctx)?;
let ptype = primitive.ptype();

match ptype {
Expand All @@ -71,7 +73,14 @@ fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<PrimitiveArray
PType::F64 => Ok(primitive
.as_slice::<f64>()
.iter()
.map(|&v| v as f32)
.map(|&v| {
#[expect(
clippy::cast_possible_truncation,
reason = "TurboQuant quantization operates in f32, so f64 inputs are intentionally downcast"
)]
let v = v as f32;
v
})
.collect()),
_ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"),
}
Expand All @@ -94,7 +103,6 @@ struct QuantizationResult {
/// Norms are computed in the native element precision via the [`L2Norm`] scalar function.
/// The rotation and centroid lookup happen in f32. Null rows (per the input validity) produce
/// all-zero codes.
#[allow(clippy::cast_possible_truncation)]
fn turboquant_quantize_core(
ext: ArrayView<Extension>,
fsl: &FixedSizeListArray,
Expand All @@ -103,7 +111,8 @@ fn turboquant_quantize_core(
validity: &Validity,
ctx: &mut ExecutionCtx,
) -> VortexResult<QuantizationResult> {
let dimension = fsl.list_size() as usize;
let dimension =
usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
let num_rows = fsl.len();

// Compute native-precision norms via the L2Norm scalar fn. L2Norm propagates validity from
Expand All @@ -127,10 +136,12 @@ fn turboquant_quantize_core(

let rotation = RotationMatrix::try_new(seed, dimension)?;
let padded_dim = rotation.padded_dim();
let padded_dim_u32 =
u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");

let f32_elements = extract_f32_elements(fsl)?;
let f32_elements = extract_f32_elements(fsl, ctx)?;

let centroids = get_centroids(padded_dim as u32, bit_width)?;
let centroids = get_centroids(padded_dim_u32, bit_width)?;
let boundaries = compute_centroid_boundaries(&centroids);

let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
Expand Down Expand Up @@ -173,19 +184,20 @@ fn turboquant_quantize_core(
}

/// Build a `TurboQuantArray` from quantization results.
#[allow(clippy::cast_possible_truncation)]
fn build_turboquant(
fsl: &FixedSizeListArray,
core: QuantizationResult,
ext_dtype: DType,
) -> VortexResult<TurboQuantArray> {
let num_rows = fsl.len();
let padded_dim = core.padded_dim;
let padded_dim_u32 =
u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
let codes_elements =
PrimitiveArray::new::<u8>(core.all_indices.freeze(), Validity::NonNullable);
let codes = FixedSizeListArray::try_new(
codes_elements.into_array(),
padded_dim as u32,
padded_dim_u32,
Validity::NonNullable,
num_rows,
)?
Expand Down Expand Up @@ -220,11 +232,12 @@ pub fn turboquant_encode(
) -> VortexResult<ArrayRef> {
let ext_dtype = ext.dtype().clone();
let storage = ext.storage_array();
let fsl = storage.to_canonical()?.into_fixed_size_list();
let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;

vortex_ensure!(
config.bit_width >= 1 && config.bit_width <= 8,
"bit_width must be 1-8, got {}",
config.bit_width >= 1 && config.bit_width <= TurboQuant::MAX_BIT_WIDTH,
"bit_width must be 1-{}, got {}",
TurboQuant::MAX_BIT_WIDTH,
config.bit_width
);
let dimension = fsl.list_size();
Expand Down
Loading
Loading