Skip to content

Commit eb802bf

Browse files
committed
fix tq norm validation and other logic
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 8f6c1ef commit eb802bf

7 files changed

Lines changed: 197 additions & 43 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_CENTROIDS: usize
1616

1717
pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32
1818

19+
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuant::new_array_unchecked(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_tensor::encodings::turboquant::TurboQuantArray
20+
1921
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
2022

2123
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_tensor::vector::VectorMatcherMetadata>
@@ -176,6 +178,8 @@ pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
176178

177179
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
178180

181+
pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
182+
179183
pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::array::typed::Array<vortex_tensor::encodings::turboquant::TurboQuant>
180184

181185
pub mod vortex_tensor::fixed_shape
@@ -454,6 +458,8 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::return_dtype(&self, _opti
454458

455459
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
456460

461+
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(options: &vortex_tensor::scalar_fns::ApproxOptions, input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
462+
457463
pub mod vortex_tensor::scalar_fns::l2_norm
458464

459465
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,17 @@
8686
//! use vortex_array::arrays::ExtensionArray;
8787
//! use vortex_array::arrays::FixedSizeListArray;
8888
//! use vortex_array::arrays::PrimitiveArray;
89+
//! use vortex_array::arrays::Extension;
90+
//! use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
8991
//! use vortex_array::dtype::extension::ExtDType;
9092
//! use vortex_array::extension::EmptyMetadata;
9193
//! use vortex_array::validity::Validity;
9294
//! use vortex_buffer::BufferMut;
9395
//! use vortex_array::session::ArraySession;
9496
//! use vortex_session::VortexSession;
95-
//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode};
97+
//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_unchecked};
98+
//! use vortex_tensor::scalar_fns::ApproxOptions;
99+
//! use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm;
96100
//! use vortex_tensor::vector::Vector;
97101
//!
98102
//! // Create a Vector extension array of 100 random 128-d vectors.
@@ -110,14 +114,23 @@
110114
//! .unwrap().erased();
111115
//! let ext = ExtensionArray::new(ext_dtype, fsl.into_array());
112116
//!
113-
//! // Quantize at 2 bits per coordinate.
114-
//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 };
117+
//! // Normalize, then quantize the normalized child at 2 bits per coordinate.
115118
//! let session = VortexSession::empty().with::<ArraySession>();
116119
//! let mut ctx = session.create_execution_ctx();
117-
//! let encoded = turboquant_encode(ext.as_view(), &config, &mut ctx).unwrap();
120+
//! let l2_denorm = normalize_as_l2_denorm(
121+
//! &ApproxOptions::Exact, ext.into_array(), &mut ctx,
122+
//! ).unwrap();
123+
//! let normalized = l2_denorm.child_at(0).clone();
124+
//!
125+
//! let normalized_ext = normalized.as_opt::<Extension>().unwrap();
126+
//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 };
127+
//! // SAFETY: We just normalized the input.
128+
//! let tq = unsafe {
129+
//! turboquant_encode_unchecked(normalized_ext, &config, &mut ctx).unwrap()
130+
//! };
118131
//!
119132
//! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input.
120-
//! assert!(encoded.nbytes() < 51200);
133+
//! assert!(tq.nbytes() < 51200);
121134
//! ```
122135
123136
mod array;
@@ -137,6 +150,7 @@ mod scheme;
137150
pub use scheme::TurboQuantScheme;
138151
pub use scheme::compress::TurboQuantConfig;
139152
pub use scheme::compress::turboquant_encode;
153+
pub use scheme::compress::turboquant_encode_unchecked;
140154

141155
#[cfg(test)]
142156
mod tests;

vortex-tensor/src/encodings/turboquant/scheme/compress.rs

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm),
99
//! which the [`TurboQuantScheme`](super::TurboQuantScheme) calls before invoking this function.
1010
11+
use num_traits::ToPrimitive;
1112
use vortex_array::ArrayRef;
1213
use vortex_array::ArrayView;
1314
use vortex_array::ExecutionCtx;
@@ -19,6 +20,7 @@ use vortex_array::arrays::extension::ExtensionArrayExt;
1920
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
2021
use vortex_array::dtype::Nullability;
2122
use vortex_array::dtype::PType;
23+
use vortex_array::match_each_float_ptype;
2224
use vortex_array::validity::Validity;
2325
use vortex_buffer::BufferMut;
2426
use vortex_error::VortexExpect;
@@ -33,6 +35,13 @@ use crate::encodings::turboquant::array::centroids::find_nearest_centroid;
3335
use crate::encodings::turboquant::array::centroids::get_centroids;
3436
use crate::encodings::turboquant::array::rotation::RotationMatrix;
3537
use crate::encodings::turboquant::vtable::TurboQuantArray;
38+
use crate::scalar_fns::ApproxOptions;
39+
use crate::scalar_fns::l2_norm::L2Norm;
40+
use crate::vector::AnyVector;
41+
42+
/// Tolerance for the unit-norm check in [`turboquant_encode`]. Each row's L2 norm must be within
43+
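/// this distance of 1.0 (or be exactly 0.0 for zero vectors). NOTE(review): the norms are read in the vector's element float type, and the f32 values adjacent to 1.0 differ from it by ~6e-8 (far more for f16), so for f32/f16 input this tolerance only admits norms that are exactly 1.0 — confirm this strictness is intended.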
44+
const UNIT_NORM_TOLERANCE: f64 = 1e-10;
3645

3746
/// Configuration for TurboQuant encoding.
3847
#[derive(Clone, Debug)]
@@ -99,8 +108,9 @@ struct QuantizationResult {
99108

100109
/// Core quantization: rotate and quantize already-normalized rows.
101110
///
102-
/// The input `fsl` must contain unit-norm vectors (already L2-normalized). The rotation and
103-
/// centroid lookup happen in f32.
111+
/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null
112+
/// vectors are not supported and must be zeroed out before reaching this function. The rotation
113+
/// and centroid lookup happen in f32.
104114
fn turboquant_quantize_core(
105115
fsl: &FixedSizeListArray,
106116
seed: u64,
@@ -186,7 +196,12 @@ fn build_turboquant(
186196
/// [`TurboQuantArray`].
187197
///
188198
/// The input must be a non-nullable Vector extension array whose rows are already unit-norm.
189-
/// Normalization is handled externally (e.g. by [`normalize_as_l2_denorm`]).
199+
/// **Null vectors are not supported.** The caller must normalize and strip nullability before
200+
/// calling this function, for example via [`normalize_as_l2_denorm`].
201+
///
202+
/// This function validates that every row has L2 norm within `UNIT_NORM_TOLERANCE` of 1.0 (or is
203+
/// exactly 0.0). Use [`turboquant_encode_unchecked`] to skip this check when the caller has just
204+
/// performed normalization.
190205
///
191206
/// The returned array is a plain [`TurboQuantArray`] that decompresses to unit-norm vectors.
192207
/// The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray if the original
@@ -200,13 +215,61 @@ pub fn turboquant_encode(
200215
ctx: &mut ExecutionCtx,
201216
) -> VortexResult<ArrayRef> {
202217
let ext_dtype = ext.dtype().clone();
203-
let storage = ext.storage_array();
204-
let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
205218

206219
vortex_ensure!(
207220
!ext_dtype.is_nullable(),
208221
"TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}",
209222
);
223+
224+
// Validate that all rows are unit-norm (or zero).
225+
let num_rows = ext.as_ref().len();
226+
if num_rows > 0 {
227+
let norms_sfn =
228+
L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?;
229+
let norms: PrimitiveArray = norms_sfn.into_array().execute(ctx)?;
230+
231+
let element_ptype = ext_dtype
232+
.as_extension()
233+
.metadata::<AnyVector>()
234+
.element_ptype();
235+
236+
match_each_float_ptype!(element_ptype, |T| {
237+
for (i, &norm) in norms.as_slice::<T>().iter().enumerate() {
238+
let norm_f64: f64 = ToPrimitive::to_f64(&norm).unwrap_or(f64::NAN);
239+
vortex_ensure!(
240+
norm_f64 == 0.0 || (norm_f64 - 1.0).abs() < UNIT_NORM_TOLERANCE,
241+
"TurboQuant requires unit-norm input, but row {i} has L2 norm {norm_f64:.6} \
242+
(expected 1.0 or 0.0)",
243+
);
244+
}
245+
});
246+
}
247+
248+
// SAFETY: We just validated that the input is non-nullable and all rows are unit-norm.
249+
unsafe { turboquant_encode_unchecked(ext, config, ctx) }
250+
}
251+
252+
/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
253+
/// [`TurboQuantArray`], without validating the unit-norm precondition.
254+
///
255+
/// # Safety
256+
///
257+
/// The caller must ensure:
258+
///
259+
/// - The input dtype is non-nullable.
260+
/// - Every row is L2-normalized (unit norm) or is a zero vector.
261+
///
262+
/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently
263+
/// incorrect quantization results.
264+
pub unsafe fn turboquant_encode_unchecked(
265+
ext: ArrayView<Extension>,
266+
config: &TurboQuantConfig,
267+
ctx: &mut ExecutionCtx,
268+
) -> VortexResult<ArrayRef> {
269+
let ext_dtype = ext.dtype().clone();
270+
let storage = ext.storage_array();
271+
let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
272+
210273
vortex_ensure!(
211274
config.bit_width >= 1 && config.bit_width <= TurboQuant::MAX_BIT_WIDTH,
212275
"bit_width must be 1-{}, got {}",

vortex-tensor/src/encodings/turboquant/scheme/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use vortex_error::VortexResult;
2727

2828
use crate::encodings::turboquant::TurboQuant;
2929
use crate::encodings::turboquant::TurboQuantConfig;
30-
use crate::encodings::turboquant::turboquant_encode;
30+
use crate::encodings::turboquant::turboquant_encode_unchecked;
3131
use crate::scalar_fns::ApproxOptions;
3232
use crate::scalar_fns::l2_denorm::L2Denorm;
3333
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
@@ -112,7 +112,9 @@ impl Scheme for TurboQuantScheme {
112112
.as_opt::<Extension>()
113113
.vortex_expect("normalized child should be an Extension array");
114114
let config = TurboQuantConfig::default();
115-
let tq = turboquant_encode(normalized_ext, &config, &mut ctx)?;
115+
// SAFETY: We just normalized the input via `normalize_as_l2_denorm`, so all rows are
116+
// guaranteed to be unit-norm (or zero for originally-null rows).
117+
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx)? };
116118

117119
// Reassemble L2Denorm(TurboQuant, norms).
118120
Ok(L2Denorm::try_new_array(&ApproxOptions::Exact, tq, norms, num_rows)?.into_array())

vortex-tensor/src/encodings/turboquant/tests.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rstest::rstest;
1111
use vortex_array::ArrayRef;
1212
use vortex_array::IntoArray;
1313
use vortex_array::VortexSessionExecute;
14+
use vortex_array::arrays::Extension;
1415
use vortex_array::arrays::ExtensionArray;
1516
use vortex_array::arrays::FixedSizeListArray;
1617
use vortex_array::arrays::PrimitiveArray;
@@ -24,6 +25,7 @@ use vortex_array::extension::EmptyMetadata;
2425
use vortex_array::session::ArraySession;
2526
use vortex_array::validity::Validity;
2627
use vortex_buffer::BufferMut;
28+
use vortex_error::VortexExpect;
2729
use vortex_error::VortexResult;
2830
use vortex_session::VortexSession;
2931

@@ -32,6 +34,7 @@ use crate::encodings::turboquant::TurboQuantArrayExt;
3234
use crate::encodings::turboquant::TurboQuantConfig;
3335
use crate::encodings::turboquant::array::rotation::RotationMatrix;
3436
use crate::encodings::turboquant::turboquant_encode;
37+
use crate::encodings::turboquant::turboquant_encode_unchecked;
3538
use crate::scalar_fns::ApproxOptions;
3639
use crate::scalar_fns::l2_denorm::L2Denorm;
3740
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
@@ -100,7 +103,7 @@ fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray {
100103
/// Full encode pipeline: normalize, then TQ-encode, then wrap in L2Denorm.
101104
///
102105
/// This mirrors what `TurboQuantScheme::compress()` does: normalize via `normalize_as_l2_denorm`,
103-
/// then quantize the normalized child via `turboquant_encode`, then reassemble.
106+
/// then quantize the normalized child via `turboquant_encode_unchecked`, then reassemble.
104107
fn normalize_and_encode(
105108
ext: &ExtensionArray,
106109
config: &TurboQuantConfig,
@@ -112,9 +115,10 @@ fn normalize_and_encode(
112115
let num_rows = l2_denorm.len();
113116

114117
let normalized_ext = normalized
115-
.as_opt::<vortex_array::arrays::Extension>()
116-
.expect("normalized child should be an Extension array");
117-
let tq = turboquant_encode(normalized_ext, config, ctx)?;
118+
.as_opt::<Extension>()
119+
.vortex_expect("normalized child should be an Extension array");
120+
// SAFETY: We just normalized the input via `normalize_as_l2_denorm`.
121+
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx)? };
118122

119123
Ok(L2Denorm::try_new_array(&ApproxOptions::Exact, tq, norms, num_rows)?.into_array())
120124
}

vortex-tensor/src/encodings/turboquant/vtable.rs

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use vortex_array::serde::ArrayChildren;
2626
use vortex_array::validity::Validity;
2727
use vortex_array::vtable::VTable;
2828
use vortex_array::vtable::ValidityVTable;
29+
use vortex_error::VortexExpect;
2930
use vortex_error::VortexResult;
3031
use vortex_error::vortex_ensure;
3132
use vortex_error::vortex_ensure_eq;
@@ -89,7 +90,8 @@ impl TurboQuant {
8990
/// Nullability is handled externally by the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)
9091
/// ScalarFnArray wrapper.
9192
///
92-
/// Internally calls [`TurboQuantData::validate`] and [`TurboQuantData::try_new`].
93+
/// Internally calls [`TurboQuantData::validate`], then
94+
/// delegates to [`new_array_unchecked`](Self::new_array_unchecked).
9395
pub fn try_new_array(
9496
dtype: DType,
9597
codes: ArrayRef,
@@ -98,25 +100,66 @@ impl TurboQuant {
98100
) -> VortexResult<TurboQuantArray> {
99101
TurboQuantData::validate(&dtype, &codes, &centroids, &rotation_signs)?;
100102

103+
Ok(unsafe { Self::new_array_unchecked(dtype, codes, centroids, rotation_signs) })
104+
}
105+
106+
/// Creates a new [`TurboQuantArray`] without validation.
107+
///
108+
/// # Safety
109+
///
110+
/// The caller must ensure all invariants required by [`TurboQuantData::validate`] hold:
111+
///
112+
/// - `dtype` is a non-nullable [`Vector`](crate::vector::Vector) extension type with
113+
/// dimension >= [`MIN_DIMENSION`](Self::MIN_DIMENSION).
114+
/// - `codes` is a non-nullable `FixedSizeList<u8>` with `list_size == padded_dim`.
115+
/// - `centroids` is a non-nullable `Primitive<f32>` with a power-of-2 length in
116+
/// `[2, MAX_CENTROIDS]` (or empty for degenerate arrays).
117+
/// - `rotation_signs` is a non-nullable `FixedSizeList<u8>` with `list_size == padded_dim`, whose length (the number of rounds) fits in a `u8` because it is narrowed with an `as u8` cast.
118+
///
119+
/// Violating these invariants may produce incorrect results during decompression or panics
120+
/// during array access.
121+
pub unsafe fn new_array_unchecked(
122+
dtype: DType,
123+
codes: ArrayRef,
124+
centroids: ArrayRef,
125+
rotation_signs: ArrayRef,
126+
) -> TurboQuantArray {
127+
#[cfg(debug_assertions)]
128+
TurboQuantData::validate(&dtype, &codes, &centroids, &rotation_signs)
129+
.vortex_expect("[DEBUG ASSERTION]: TurboQuantData arrays are invalid");
130+
101131
let len = codes.len();
102-
let vector_metadata = TurboQuant::validate_dtype(&dtype)?;
132+
133+
let dimension = dtype
134+
.as_extension_opt()
135+
.and_then(|ext| ext.metadata_opt::<AnyVector>())
136+
.map(|m| m.dimensions())
137+
.unwrap_or(0);
103138

104139
let bit_width = if centroids.is_empty() {
105140
0
106141
} else {
107-
u8::try_from(centroids.len().trailing_zeros())
108-
.map_err(|_| vortex_err!("centroids bit_width does not fit in u8"))?
142+
#[expect(
143+
clippy::cast_possible_truncation,
144+
reason = "bit_width is guaranteed <= 8"
145+
)]
146+
(centroids.len().trailing_zeros() as u8)
109147
};
110148

111-
// Derive num_rounds from the FSL rotation_signs length (0 for degenerate arrays).
112-
let num_rounds = u8::try_from(rotation_signs.len())
113-
.map_err(|_| vortex_err!("rotation_signs num_rounds does not fit in u8"))?;
149+
#[expect(
150+
clippy::cast_possible_truncation,
151+
reason = "num_rounds fits in u8 by the caller's invariants"
152+
)]
153+
let num_rounds = rotation_signs.len() as u8;
114154

115-
let data = TurboQuantData::try_new(vector_metadata.dimensions(), bit_width, num_rounds)?;
155+
// SAFETY: The caller guarantees that dimension, bit_width, and num_rounds satisfy the
156+
// invariants documented on `TurboQuantData::new_unchecked`.
157+
let data = unsafe { TurboQuantData::new_unchecked(dimension, bit_width, num_rounds) };
116158
let parts = ArrayParts::new(TurboQuant, dtype, len, data)
117159
.with_slots(TurboQuantData::make_slots(codes, centroids, rotation_signs));
118160

119-
Array::try_from_parts(parts)
161+
// SAFETY: The caller guarantees the parts are logically consistent.
162+
unsafe { Array::from_parts_unchecked(parts) }
120163
}
121164
}
122165

0 commit comments

Comments
 (0)