Skip to content

Commit 94cdaf2

Browse files
committed
fix tolerance and benchmarks
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent e40d6bd commit 94cdaf2

4 files changed

Lines changed: 104 additions & 16 deletions

File tree

vortex-tensor/src/encodings/turboquant/scheme/compress.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use vortex_array::arrays::FixedSizeListArray;
1818
use vortex_array::arrays::PrimitiveArray;
1919
use vortex_array::arrays::extension::ExtensionArrayExt;
2020
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
21+
use vortex_array::dtype::DType;
2122
use vortex_array::dtype::Nullability;
2223
use vortex_array::dtype::PType;
2324
use vortex_array::match_each_float_ptype;
@@ -39,9 +40,20 @@ use crate::scalar_fns::ApproxOptions;
3940
use crate::scalar_fns::l2_norm::L2Norm;
4041
use crate::vector::AnyVector;
4142

42-
/// Tolerance for the unit-norm check in [`turboquant_encode`]. Each row's L2 norm must be within
43-
/// this distance of 1.0 (or be exactly 0.0 for zero vectors).
44-
const UNIT_NORM_TOLERANCE: f64 = 1e-10;
43+
/// Returns the acceptable unit-norm drift for the given element precision.
44+
///
45+
/// The checked encode path validates the post-normalization storage values, so the tolerance has
46+
/// to account for quantization back into the vector element type.
47+
///
48+
/// These numbers are somewhat arbitrary and are derived from testing reasonable values.
49+
fn unit_norm_tolerance(element_ptype: PType) -> f64 {
50+
match element_ptype {
51+
PType::F16 => 2e-3,
52+
PType::F32 => 1e-6,
53+
PType::F64 => 1e-10,
54+
_ => unreachable!("TurboQuant requires float elements, got {element_ptype:?}"),
55+
}
56+
}
4557

4658
/// Configuration for TurboQuant encoding.
4759
#[derive(Clone, Debug)]
@@ -165,7 +177,7 @@ fn turboquant_quantize_core(
165177
fn build_turboquant(
166178
num_rows: usize,
167179
core: QuantizationResult,
168-
ext_dtype: &vortex_array::dtype::DType,
180+
ext_dtype: &DType,
169181
) -> VortexResult<TurboQuantArray> {
170182
let padded_dim = core.padded_dim;
171183
let padded_dim_u32 =
@@ -199,9 +211,9 @@ fn build_turboquant(
199211
/// **Null vectors are not supported.** The caller must normalize and strip nullability before
200212
/// calling this function, for example via [`normalize_as_l2_denorm`].
201213
///
202-
/// This function validates that every row has L2 norm within `UNIT_NORM_TOLERANCE` of 1.0 (or is
203-
/// exactly 0.0). Use [`turboquant_encode_unchecked`] to skip this check when the caller has just
204-
/// performed normalization.
214+
/// This function validates that every row has L2 norm within a storage-precision-aware tolerance
215+
/// of 1.0 (or is exactly 0.0). Use [`turboquant_encode_unchecked`] to skip this check when the
216+
/// caller has just performed normalization.
205217
///
206218
/// The returned array is a plain [`TurboQuantArray`] that decompresses to unit-norm vectors.
207219
/// The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray if the original
@@ -232,12 +244,13 @@ pub fn turboquant_encode(
232244
.as_extension()
233245
.metadata::<AnyVector>()
234246
.element_ptype();
247+
let tolerance = unit_norm_tolerance(element_ptype);
235248

236249
match_each_float_ptype!(element_ptype, |T| {
237250
for (i, &norm) in norms.as_slice::<T>().iter().enumerate() {
238251
let norm_f64: f64 = ToPrimitive::to_f64(&norm).unwrap_or(f64::NAN);
239252
vortex_ensure!(
240-
norm_f64 == 0.0 || (norm_f64 - 1.0).abs() < UNIT_NORM_TOLERANCE,
253+
norm_f64 == 0.0 || (norm_f64 - 1.0).abs() < tolerance,
241254
"TurboQuant requires unit-norm input, but row {i} has L2 norm {norm_f64:.6} \
242255
(expected 1.0 or 0.0)",
243256
);

vortex-tensor/src/encodings/turboquant/tests.rs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,17 @@ fn empty_turboquant_parts(
220220
))
221221
}
222222

223+
fn normalized_child(
224+
ext: &ExtensionArray,
225+
ctx: &mut vortex_array::ExecutionCtx,
226+
) -> VortexResult<ArrayRef> {
227+
Ok(
228+
normalize_as_l2_denorm(&ApproxOptions::Exact, ext.as_ref().clone(), ctx)?
229+
.child_at(0)
230+
.clone(),
231+
)
232+
}
233+
223234
// -----------------------------------------------------------------------
224235
// Roundtrip tests
225236
// -----------------------------------------------------------------------
@@ -399,6 +410,44 @@ fn rejects_dimension_below_128(#[case] dim: usize) {
399410
assert!(turboquant_encode(ext.as_view(), &config, &mut ctx).is_err());
400411
}
401412

413+
#[test]
414+
fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> {
415+
let num_rows = 10;
416+
let dim = 128;
417+
let mut rng = StdRng::seed_from_u64(99);
418+
let normal = Normal::new(0.0f32, 1.0).unwrap();
419+
420+
let mut buf = BufferMut::<half::f16>::with_capacity(num_rows * dim);
421+
for _ in 0..(num_rows * dim) {
422+
buf.push(half::f16::from_f32(normal.sample(&mut rng)));
423+
}
424+
let elements = PrimitiveArray::new::<half::f16>(buf.freeze(), Validity::NonNullable);
425+
let fsl = FixedSizeListArray::try_new(
426+
elements.into_array(),
427+
dim.try_into()
428+
.expect("somehow got dimension greater than u32::MAX"),
429+
Validity::NonNullable,
430+
num_rows,
431+
)?;
432+
433+
let ext = make_vector_ext(&fsl);
434+
let config = TurboQuantConfig {
435+
bit_width: 3,
436+
seed: Some(42),
437+
num_rounds: 3,
438+
};
439+
440+
let mut ctx = SESSION.create_execution_ctx();
441+
let normalized = normalized_child(&ext, &mut ctx)?;
442+
let normalized_ext = normalized
443+
.as_opt::<Extension>()
444+
.vortex_expect("normalized child should be an Extension array");
445+
446+
let encoded = turboquant_encode(normalized_ext, &config, &mut ctx)?;
447+
assert_eq!(encoded.len(), num_rows);
448+
Ok(())
449+
}
450+
402451
fn make_fsl_small(dim: usize) -> FixedSizeListArray {
403452
let mut buf = BufferMut::<f32>::with_capacity(dim);
404453
for i in 0..dim {
@@ -1092,8 +1141,9 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {
10921141
// -----------------------------------------------------------------------
10931142

10941143
/// Verify that a TurboQuant array (extracted from the L2Denorm wrapper) survives
1095-
/// serialize/deserialize. ScalarFnArray cannot be serialized yet, so we test the TQ child
1096-
/// directly.
1144+
/// serialize/deserialize.
1145+
///
1146+
/// TODO(connor): ScalarFnArray cannot be serialized yet, so we test the TQ child directly.
10971147
#[test]
10981148
fn serde_roundtrip() -> VortexResult<()> {
10991149
use vortex_array::ArrayContext;

vortex-tensor/src/encodings/turboquant/vtable.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use vortex_array::serde::ArrayChildren;
2626
use vortex_array::validity::Validity;
2727
use vortex_array::vtable::VTable;
2828
use vortex_array::vtable::ValidityVTable;
29+
#[cfg(debug_assertions)]
2930
use vortex_error::VortexExpect;
3031
use vortex_error::VortexResult;
3132
use vortex_error::vortex_ensure;

vortex/benches/single_encoding_throughput.rs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -439,16 +439,20 @@ mod turboquant_benches {
439439
use rand::SeedableRng;
440440
use rand::rngs::StdRng;
441441
use vortex::array::IntoArray;
442+
use vortex::array::arrays::Extension;
442443
use vortex::array::arrays::ExtensionArray;
443444
use vortex::array::arrays::FixedSizeListArray;
444445
use vortex::array::arrays::PrimitiveArray;
446+
use vortex::array::arrays::scalar_fn::ScalarFnArrayExt;
445447
use vortex::array::dtype::extension::ExtDType;
446448
use vortex::array::extension::EmptyMetadata;
447449
use vortex::array::validity::Validity;
448450
use vortex_array::VortexSessionExecute;
449451
use vortex_buffer::BufferMut;
450452
use vortex_tensor::encodings::turboquant::TurboQuantConfig;
451-
use vortex_tensor::encodings::turboquant::turboquant_encode;
453+
use vortex_tensor::encodings::turboquant::turboquant_encode_unchecked;
454+
use vortex_tensor::scalar_fns::ApproxOptions;
455+
use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm;
452456
use vortex_tensor::vector::Vector;
453457

454458
use super::SESSION;
@@ -492,18 +496,35 @@ mod turboquant_benches {
492496
}
493497
}
494498

499+
fn setup_normalized_vector_ext(dim: usize) -> ExtensionArray {
500+
let ext = setup_vector_ext(dim);
501+
let mut ctx = SESSION.create_execution_ctx();
502+
let normalized = normalize_as_l2_denorm(&ApproxOptions::Exact, ext.into_array(), &mut ctx)
503+
.unwrap()
504+
.child_at(0)
505+
.clone();
506+
normalized.execute::<ExtensionArray>(&mut ctx).unwrap()
507+
}
508+
495509
macro_rules! turboquant_bench {
496510
(compress, $dim:literal, $bits:literal, $name:ident) => {
497511
paste! {
498512
#[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit"))]
499513
fn $name(bencher: Bencher) {
500-
let ext = setup_vector_ext($dim);
514+
let normalized_ext = setup_normalized_vector_ext($dim);
501515
let config = turboquant_config($bits);
502516
with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64)
503-
.with_inputs(|| ext.clone())
517+
.with_inputs(|| normalized_ext.clone())
504518
.bench_refs(|a| {
505519
let mut ctx = SESSION.create_execution_ctx();
506-
turboquant_encode(a.as_view(), &config, &mut ctx).unwrap()
520+
let normalized = a
521+
.as_ref()
522+
.as_opt::<Extension>()
523+
.expect("normalized benchmark input should be an Extension array");
524+
// SAFETY: Benchmark inputs are normalized once up front so the timed
525+
// region measures only TurboQuant encoding.
526+
unsafe { turboquant_encode_unchecked(normalized, &config, &mut ctx) }
527+
.unwrap()
507528
});
508529
}
509530
}
@@ -512,10 +533,13 @@ mod turboquant_benches {
512533
paste! {
513534
#[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit"))]
514535
fn $name(bencher: Bencher) {
515-
let ext = setup_vector_ext($dim);
536+
let normalized_ext = setup_normalized_vector_ext($dim);
516537
let config = turboquant_config($bits);
517538
let mut ctx = SESSION.create_execution_ctx();
518-
let compressed = turboquant_encode(ext.as_view(), &config, &mut ctx).unwrap();
539+
let compressed = unsafe {
540+
turboquant_encode_unchecked(normalized_ext.as_view(), &config, &mut ctx)
541+
}
542+
.unwrap();
519543
with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64)
520544
.with_inputs(|| &compressed)
521545
.bench_refs(|a| {

0 commit comments

Comments
 (0)