Skip to content

Commit 8f6c1ef

Browse files
committed
pull out norms from turboquant
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d72bf9b commit 8f6c1ef

13 files changed

Lines changed: 296 additions & 303 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_CENTROIDS: usize
1616

1717
pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32
1818

19-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
19+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
2020

2121
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_tensor::vector::VectorMatcherMetadata>
2222

@@ -34,7 +34,7 @@ pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_te
3434

3535
pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant
3636

37-
pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromChild
37+
pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_tensor::encodings::turboquant::TurboQuant
3838

3939
pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> vortex_array::buffer::BufferHandle
4040

@@ -62,9 +62,9 @@ impl vortex_array::array::vtable::operations::OperationsVTable<vortex_tensor::en
6262

6363
pub fn vortex_tensor::encodings::turboquant::TurboQuant::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
6464

65-
impl vortex_array::array::vtable::validity::ValidityChild<vortex_tensor::encodings::turboquant::TurboQuant> for vortex_tensor::encodings::turboquant::TurboQuant
65+
impl vortex_array::array::vtable::validity::ValidityVTable<vortex_tensor::encodings::turboquant::TurboQuant> for vortex_tensor::encodings::turboquant::TurboQuant
6666

67-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validity_child(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>) -> vortex_array::array::erased::ArrayRef
67+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validity(_array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>) -> vortex_error::VortexResult<vortex_array::validity::Validity>
6868

6969
impl vortex_array::arrays::dict::take::TakeExecute for vortex_tensor::encodings::turboquant::TurboQuant
7070

@@ -110,7 +110,7 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -
110110

111111
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> vortex_error::VortexResult<Self>
112112

113-
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()>
113+
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()>
114114

115115
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData
116116

@@ -164,8 +164,6 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::centroids(&self
164164

165165
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::codes(&self) -> &vortex_array::array::erased::ArrayRef
166166

167-
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::norms(&self) -> &vortex_array::array::erased::ArrayRef
168-
169167
pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
170168

171169
impl<T: vortex_array::array::typed::TypedArrayRef<vortex_tensor::encodings::turboquant::TurboQuant>> vortex_tensor::encodings::turboquant::TurboQuantArrayExt for T
@@ -174,8 +172,6 @@ pub fn T::centroids(&self) -> &vortex_array::array::erased::ArrayRef
174172

175173
pub fn T::codes(&self) -> &vortex_array::array::erased::ArrayRef
176174

177-
pub fn T::norms(&self) -> &vortex_array::array::erased::ArrayRef
178-
179175
pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
180176

181177
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@ use crate::encodings::turboquant::vtable::TurboQuant;
1919
/// TurboQuant array data.
2020
///
2121
/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector)
22-
/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared
22+
/// extension arrays. It stores quantized coordinate codes for unit-norm vectors, along with shared
2323
/// codebook centroids and the parameters of the current structured rotation.
2424
///
25+
/// Norms should be stored externally in the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)
26+
/// `ScalarFnArray` wrapper.
27+
///
2528
/// See the [module docs](crate::encodings::turboquant) for algorithmic details.
2629
///
27-
/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty.
30+
/// Note that degenerate TurboQuant arrays have zero rows and `bit_width == 0`, with all slots
31+
/// empty.
2832
#[derive(Clone, Debug)]
2933
pub struct TurboQuantData {
3034
/// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size.
@@ -95,16 +99,21 @@ impl TurboQuantData {
9599
pub fn validate(
96100
dtype: &DType,
97101
codes: &ArrayRef,
98-
norms: &ArrayRef,
99102
centroids: &ArrayRef,
100103
rotation_signs: &ArrayRef,
101104
) -> VortexResult<()> {
102105
let vector_metadata = TurboQuant::validate_dtype(dtype)?;
103106
let dimension = vector_metadata.dimensions();
104107
let padded_dim = dimension.next_power_of_two();
105108

109+
// TurboQuant arrays are always non-nullable. Nullability should be handled by the external
110+
// L2Denorm ScalarFnArray wrapper.
111+
vortex_ensure!(
112+
!dtype.is_nullable(),
113+
"TurboQuant dtype must be non-nullable, got {dtype}",
114+
);
115+
106116
// Codes must be a non-nullable FixedSizeList<u8> with list_size == padded_dim.
107-
// Null vectors are represented by all-zero codes since validity lives in the norms array.
108117
let expected_codes_dtype = DType::FixedSizeList(
109118
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
110119
padded_dim,
@@ -116,23 +125,6 @@ impl TurboQuantData {
116125
"codes dtype does not match expected {expected_codes_dtype}",
117126
);
118127

119-
let num_rows = codes.len();
120-
vortex_ensure_eq!(
121-
norms.len(),
122-
num_rows,
123-
"norms length must match codes length",
124-
);
125-
126-
// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
127-
// Norms carry the validity of the entire TurboQuant array.
128-
let element_ptype = vector_metadata.element_ptype();
129-
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
130-
vortex_ensure_eq!(
131-
*norms.dtype(),
132-
expected_norms_dtype,
133-
"norms dtype does not match expected {expected_norms_dtype}",
134-
);
135-
136128
// Centroids are always f32 regardless of element type.
137129
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
138130
vortex_ensure_eq!(
@@ -154,6 +146,7 @@ impl TurboQuantData {
154146
"rotation_signs dtype does not match expected {expected_signs_dtype}",
155147
);
156148
// Degenerate (empty) case: all children must be empty, and bit_width is 0.
149+
let num_rows = codes.len();
157150
if num_rows == 0 {
158151
vortex_ensure!(
159152
centroids.is_empty(),
@@ -198,13 +191,11 @@ impl TurboQuantData {
198191

199192
pub(crate) fn make_slots(
200193
codes: ArrayRef,
201-
norms: ArrayRef,
202194
centroids: ArrayRef,
203195
rotation_signs: ArrayRef,
204196
) -> Vec<Option<ArrayRef>> {
205197
let mut slots = vec![None; Slot::COUNT];
206198
slots[Slot::Codes as usize] = Some(codes);
207-
slots[Slot::Norms as usize] = Some(norms);
208199
slots[Slot::Centroids as usize] = Some(centroids);
209200
slots[Slot::RotationSigns as usize] = Some(rotation_signs);
210201
slots
@@ -242,12 +233,6 @@ pub trait TurboQuantArrayExt: TypedArrayRef<TurboQuant> {
242233
.vortex_expect("TurboQuantArray codes slot")
243234
}
244235

245-
fn norms(&self) -> &ArrayRef {
246-
self.as_ref().slots()[Slot::Norms as usize]
247-
.as_ref()
248-
.vortex_expect("TurboQuantArray norms slot")
249-
}
250-
251236
fn centroids(&self) -> &ArrayRef {
252237
self.as_ref().slots()[Slot::Centroids as usize]
253238
.as_ref()

vortex-tensor/src/encodings/turboquant/array/slots.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,26 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
/// Slot positions for TurboQuantArray children.
5+
///
6+
/// Norms are not stored in the TurboQuantArray. They live in the external [`L2Denorm`]
7+
/// ScalarFnArray wrapper returned by [`turboquant_encode`].
8+
///
9+
/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm
10+
/// [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode
511
#[repr(usize)]
612
#[derive(Clone, Copy, Debug)]
713
pub(crate) enum Slot {
814
Codes = 0,
9-
Norms = 1,
10-
Centroids = 2,
11-
RotationSigns = 3,
15+
Centroids = 1,
16+
RotationSigns = 2,
1217
}
1318

1419
impl Slot {
15-
pub(crate) const COUNT: usize = 4;
20+
pub(crate) const COUNT: usize = 3;
1621

1722
pub(crate) fn name(self) -> &'static str {
1823
match self {
1924
Self::Codes => "codes",
20-
Self::Norms => "norms",
2125
Self::Centroids => "centroids",
2226
Self::RotationSigns => "rotation_signs",
2327
}
@@ -26,9 +30,8 @@ impl Slot {
2630
pub(crate) fn from_index(idx: usize) -> Self {
2731
match idx {
2832
0 => Self::Codes,
29-
1 => Self::Norms,
30-
2 => Self::Centroids,
31-
3 => Self::RotationSigns,
33+
1 => Self::Centroids,
34+
2 => Self::RotationSigns,
3235
_ => vortex_error::vortex_panic!("invalid slot index {idx}"),
3336
}
3437
}

vortex-tensor/src/encodings/turboquant/compute/slice.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@ impl SliceReduce for TurboQuant {
1717
array: ArrayView<'_, TurboQuant>,
1818
range: Range<usize>,
1919
) -> VortexResult<Option<ArrayRef>> {
20-
let sliced_codes = array.codes().slice(range.clone())?;
21-
let sliced_norms = array.norms().slice(range)?;
20+
let sliced_codes = array.codes().slice(range)?;
2221

2322
Ok(Some(
2423
TurboQuant::try_new_array(
2524
array.dtype().clone(),
2625
sliced_codes,
27-
sliced_norms,
2826
array.centroids().clone(),
2927
array.rotation_signs().clone(),
3028
)?

vortex-tensor/src/encodings/turboquant/compute/take.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@ impl TakeExecute for TurboQuant {
1919
) -> VortexResult<Option<ArrayRef>> {
2020
// FSL children handle per-row take natively.
2121
let taken_codes = array.codes().take(indices.clone())?;
22-
let taken_norms = array.norms().take(indices.clone())?;
2322

2423
Ok(Some(
2524
TurboQuant::try_new_array(
2625
array.dtype().clone(),
2726
taken_codes,
28-
taken_norms,
2927
array.centroids().clone(),
3028
array.rotation_signs().clone(),
3129
)?

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@
1717
//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate)
1818
//! using MSE-optimal scalar quantization on coordinates of a rotated unit vector.
1919
//!
20+
//! The `TurboQuantArray` stores only the quantized unit-norm vector data (codes, centroids,
21+
//! rotation signs). Per-vector L2 norms are stored separately in an [`L2Denorm`] ScalarFnArray
22+
//! wrapper. The [`turboquant_encode`] function returns this wrapper:
23+
//!
24+
//! ```text
25+
//! ScalarFnArray(L2Denorm, [TurboQuantArray, norms])
26+
//! ```
27+
//!
28+
//! When executed, the TQ array decompresses to unit-norm vectors, and the [`L2Denorm`] function
29+
//! lazily re-applies the stored norms to reconstruct the original magnitudes.
30+
//!
31+
//! [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm
32+
//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode
33+
//!
2034
//! The TurboQuant paper analyzes a full random orthogonal rotation. The current Vortex
2135
//! implementation instead uses a fixed 3-round Walsh-Hadamard-based structured transform with
2236
//! random sign diagonals. This is a practical approximation chosen for encode/decode efficiency,
@@ -48,9 +62,10 @@
4862
//! # Compression ratios
4963
//!
5064
//! Each vector is stored as `padded_dim * bit_width / 8` bytes of quantized codes plus one stored
51-
//! norm. In the current implementation, that norm uses the vector's element float type, not a
52-
//! separate fixed storage precision. Non-power-of-2 dimensions are padded to the next power of 2
53-
//! for the structured rotation, which reduces the effective ratio for those dimensions.
65+
//! norm (in the [`L2Denorm`] wrapper). In the current implementation, that norm uses the vector's
66+
//! element float type, not a separate fixed storage precision. Non-power-of-2 dimensions are
67+
//! padded to the next power of 2 for the structured rotation, which reduces the effective ratio
68+
//! for those dimensions.
5469
//!
5570
//! The table below assumes f32 input, so the stored norm is 4 bytes.
5671
//!

0 commit comments

Comments
 (0)