Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ pub struct vortex_tensor::encodings::turboquant::TurboQuantConfig

pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8

pub vortex_tensor::encodings::turboquant::TurboQuantConfig::num_rounds: u8

pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option<u64>

impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig
Expand All @@ -100,11 +102,13 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::bit_width(&self) ->

pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32

pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8) -> Self
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self

pub fn vortex_tensor::encodings::turboquant::TurboQuantData::num_rounds(&self) -> u8

pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32

pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8) -> vortex_error::VortexResult<Self>
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> vortex_error::VortexResult<Self>

pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()>

Expand Down Expand Up @@ -156,34 +160,22 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self

pub trait vortex_tensor::encodings::turboquant::TurboQuantArrayExt: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::bit_width(&self) -> u8

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::centroids(&self) -> &vortex_array::array::erased::ArrayRef

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::codes(&self) -> &vortex_array::array::erased::ArrayRef

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::dimension(&self) -> u32

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::norms(&self) -> &vortex_array::array::erased::ArrayRef

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::padded_dim(&self) -> u32

pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef

impl<T: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>> vortex_tensor::encodings::turboquant::TurboQuantArrayExt for T

pub fn T::bit_width(&self) -> u8

pub fn T::centroids(&self) -> &vortex_array::array::erased::ArrayRef

pub fn T::codes(&self) -> &vortex_array::array::erased::ArrayRef

pub fn T::dimension(&self) -> u32

pub fn T::norms(&self) -> &vortex_array::array::erased::ArrayRef

pub fn T::padded_dim(&self) -> u32

pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef

pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
Expand Down
89 changes: 50 additions & 39 deletions vortex-tensor/src/encodings/turboquant/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ pub struct TurboQuantData {
///
/// This is 0 for degenerate empty arrays.
pub(crate) bit_width: u8,

/// The number of sign-diagonal + WHT rounds in the structured rotation.
///
/// This is 0 for degenerate empty arrays.
pub(crate) num_rounds: u8,
}

impl TurboQuantData {
Expand All @@ -46,7 +51,7 @@ impl TurboQuantData {
/// Returns an error if:
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
/// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
pub fn try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> VortexResult<Self> {
vortex_ensure!(
dimension >= TurboQuant::MIN_DIMENSION,
"TurboQuant requires dimension >= {}, got {dimension}",
Expand All @@ -61,6 +66,7 @@ impl TurboQuantData {
Ok(Self {
dimension,
bit_width,
num_rounds,
})
}

Expand All @@ -72,12 +78,14 @@ impl TurboQuantData {
///
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
/// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
/// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays).
///
/// Violating these invariants may produce incorrect results during decompression.
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self {
Self {
dimension,
bit_width,
num_rounds,
}
}

Expand Down Expand Up @@ -115,6 +123,36 @@ impl TurboQuantData {
"norms length must match codes length",
);

// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
// Norms carry the validity of the entire TurboQuant array.
let element_ptype = vector_metadata.element_ptype();
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noooo, we want these to be the same or wider. E.g., I think for f16, norms should be f32.

Copy link
Copy Markdown
Contributor Author

@connortsui20 connortsui20 Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want that, as that is an implicit cast. If you want to store f16 with f32 precision, you should upcast it first (not inside this quantization encoding).

vortex_ensure_eq!(
*norms.dtype(),
expected_norms_dtype,
"norms dtype does not match expected {expected_norms_dtype}",
);

// Centroids are always f32 regardless of element type.
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
vortex_ensure_eq!(
*centroids.dtype(),
centroids_dtype,
"centroids dtype must be non-nullable f32",
);

// Rotation signs must be a FixedSizeList<u8> with list_size == padded_dim. The FSL length
// is the number of rotation rounds.
let expected_signs_dtype = DType::FixedSizeList(
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
padded_dim,
Nullability::NonNullable,
);
vortex_ensure_eq!(
*rotation_signs.dtype(),
expected_signs_dtype,
"rotation_signs dtype does not match expected {expected_signs_dtype}",
);
// Degenerate (empty) case: all children must be empty, and bit_width is 0.
if num_rows == 0 {
vortex_ensure!(
Expand All @@ -130,6 +168,11 @@ impl TurboQuantData {
return Ok(());
}

vortex_ensure!(
!rotation_signs.is_empty(),
"rotation_signs must have at least 1 round"
);

// Non-degenerate: derive and validate bit_width from centroids.
let num_centroids = centroids.len();
vortex_ensure!(
Expand All @@ -150,31 +193,6 @@ impl TurboQuantData {
TurboQuant::MAX_BIT_WIDTH
);

// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
// Norms carry the validity of the entire TurboQuant array.
let element_ptype = vector_metadata.element_ptype();
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
vortex_ensure_eq!(
*norms.dtype(),
expected_norms_dtype,
"norms dtype does not match expected {expected_norms_dtype}",
);

// Centroids are always f32 regardless of element type.
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
vortex_ensure_eq!(
*centroids.dtype(),
centroids_dtype,
"centroids dtype must be non-nullable f32",
);

// Rotation signs count must be 3 * padded_dim.
vortex_ensure_eq!(
rotation_signs.len(),
3 * padded_dim as usize,
"rotation_signs length does not match expected 3 * {padded_dim}",
);

Ok(())
}

Expand Down Expand Up @@ -203,6 +221,11 @@ impl TurboQuantData {
self.bit_width
}

/// Returns the number of sign-diagonal + WHT rounds in the structured rotation.
///
/// This is 0 for degenerate empty arrays.
pub fn num_rounds(&self) -> u8 {
self.num_rounds
}

/// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
///
/// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
Expand All @@ -213,18 +236,6 @@ impl TurboQuantData {
}

pub trait TurboQuantArrayExt: TypedArrayRef<TurboQuant> {
fn dimension(&self) -> u32 {
std::ops::Deref::deref(self).dimension()
}

fn bit_width(&self) -> u8 {
std::ops::Deref::deref(self).bit_width()
}

fn padded_dim(&self) -> u32 {
std::ops::Deref::deref(self).padded_dim()
}

fn codes(&self) -> &ArrayRef {
self.as_ref().slots()[Slot::Codes as usize]
.as_ref()
Expand Down
Loading