Commit 27f4be5

Harden TurboQuant L2Norm fast path and document the crate
Fix two correctness bugs in the L2Norm(TQDecode(_)) fast path.

(1) The kernel coerced the returned norms to the child's nullability rather than the parent's, so wider-child-validity storage shapes that parse_storage accepts errored out at the dtype invariant. The kernel now coerces to parent.dtype().nullability() and a new test mirrors the malformed.rs shape.

(2) The per-row inv_direction_norm computation could store a 0.0 sentinel for finite rows whose squared sum overflowed to +inf in f32 (or a +inf for denormal norm_squared), making decode emit zeros while the kernel returned the nonzero stored norm. Encode now rejects non-finite input norms up front and the denormal recip is guarded by is_normal(); regression tests cover both cases.

Several should-fix items go with the must-fix:

- parse_storage_norms_only lets the kernel skip executing the codes and inv_direction_norms children it does not consume;
- the parity test pins down the exact new = old * inv_direction_norm[row] relationship rather than asserting "the values differ";
- file roundtrip now asserts the new field survives serialization and the kernel still preserves stored norms;
- tests are parameterized over f16/f32/f64 and across padded vs unpadded dimensions;
- the kernel result is cross-checked against canonical L2Norm of the materialized decode.

The hypothetical defensive metadata check on the kernel is dropped (registry key plus TQDecode signature already enforce shape). The dev-dep on vortex-array switches to workspace = true to match sibling encodings. Over-long doc lines are reflowed.

Every type in the crate now has a doc comment, emphasizing the new inv_direction_norms storage child and the 0.0 sentinel semantics. Module docs single-source the storage schema rationale in storage.rs; lib.rs and the scalar-fn modules defer to it.

Verified: cargo check, cargo clippy --all-targets --all-features, cargo +nightly fmt --all --check, cargo doc --no-deps, and cargo nextest run (102 tests, +14 new) all clean.
Signed-off-by: "Connor Tsui" <connor.tsui20@gmail.com>
1 parent 58028c4 commit 27f4be5
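
The second fix in the commit message depends on how the reciprocal is guarded. The following is a minimal sketch of that idea, not the crate's code: the helper names, the signature, and the choice of `0.0` as the fallback are assumptions made for illustration. It relies only on the standard `f32::is_normal`, which is false for zero, subnormal, infinite, and NaN values.

```rust
/// Hypothetical helper (not the crate's actual function): computes
/// `1 / sqrt(norm_squared)` only when `norm_squared` is a normal, positive float.
fn guarded_inv_direction_norm(norm_squared: f32) -> f32 {
    // `is_normal()` is false for zero, subnormal, infinite, and NaN inputs,
    // so all of those fall through to the sentinel branch.
    if norm_squared.is_normal() && norm_squared > 0.0 {
        1.0 / norm_squared.sqrt()
    } else {
        0.0 // assumed sentinel meaning "no usable direction norm"
    }
}

/// Shows why an unguarded reciprocal is unsafe: dividing by the smallest
/// positive subnormal `f32` exceeds `f32::MAX` and rounds to `+inf`.
fn unguarded_recip_of_smallest_subnormal() -> f32 {
    1.0 / f32::from_bits(1)
}
```

With a guard of this shape, a subnormal or overflowed squared sum can no longer turn into a `+inf` or silently wrong reciprocal; how the real crate then treats such rows is defined by its own storage documentation.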

15 files changed

Lines changed: 481 additions & 83 deletions

vortex-turboquant/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] }
 divan = { workspace = true }
 rand = { workspace = true }
 rstest = { workspace = true }
-vortex-array = { path = "../vortex-array", features = ["_test-harness"] }
+vortex-array = { workspace = true, features = ["_test-harness"] }
 vortex-file = { workspace = true }
 vortex-io = { workspace = true }
 vortex-layout = { workspace = true }

vortex-turboquant/src/lib.rs

Lines changed: 12 additions & 19 deletions
@@ -34,31 +34,24 @@
 //! )
 //! ```
 //!
-//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Scalar quantization
-//! perturbs the transformed unit vector, and inverse SORF plus truncation can leave the decoded
-//! quantized direction with norm different from `1.0`. If decode only multiplied that direction by
-//! the original row norm, `L2Norm(TQDecode(_))` would not equal the norm of the vector returned by
-//! `TQDecode`. TurboQuant therefore stores `inv_direction_norms = 1 / ||decoded_direction||` so
-//! decode can first renormalize the lossy quantized direction and then apply the original norm.
-//!
-//! Storing the correction also keeps future query kernels cheap. Inner product and cosine kernels can
-//! rotate a query once and gather against centroids directly; the per-row scale they need is already
-//! available as `norms * inv_direction_norms` for inner product and `inv_direction_norms` for cosine.
-//! Without this field, those kernels would have to recompute the inverse SORF/truncated norm per row
-//! or give up the `TQDecode` norm-preservation invariant.
+//! Stored norms are authoritative for future TurboQuant-aware scalar functions. The rationale
+//! for the `inv_direction_norms` correction field lives next to the storage layout; see
+//! `vector/storage.rs`.
 //!
 //! # Source map
 //!
 //! Implementation details are documented next to the code that owns them:
 //!
-//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level
-//! validity for null vectors.
-//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor
-//! crate's null-row zeroing helper.
-//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather
-//! than quantized.
+//! - `vector/storage.rs`: physical storage shape and parsing.
+//! - `vector/normalize.rs`: TurboQuant-local normalization and the encode-time finite-norm
+//! guard.
+//! - `vector/quantize.rs`: SORF transform, centroid lookup, and the per-row
+//! `inv_direction_norm` computation.
+//! - `scalar_fns/compute/`: session-scoped optimizer kernels that intercept canonical scalar
+//! functions over TurboQuant inputs (currently `L2Norm(TQDecode(_))`).
 //! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching.
-//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream.
+//! - `sorf/`: Walsh-Hadamard-based structured transform plus the stable SplitMix64 sign
+//! stream.
 //!
 //! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL
 //! residual correction for unbiased inner-product estimation, and it still uses internal
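
The module docs removed in this hunk describe why `inv_direction_norms = 1 / ||decoded_direction||` is stored: decode first renormalizes the lossy direction and then applies the original row norm. A small numeric sketch of that relationship follows. The function names and the plain `&[f32]` row representation are assumptions for illustration and do not reflect the crate's internal types.

```rust
/// Euclidean norm of a row, accumulated in f32.
fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Hypothetical decode step: scale the lossy direction by
/// `stored_norm * inv_direction_norm` so the output norm equals `stored_norm`.
fn decode_row(direction: &[f32], inv_direction_norm: f32, stored_norm: f32) -> Vec<f32> {
    let scale = stored_norm * inv_direction_norm;
    direction.iter().map(|x| x * scale).collect()
}
```

Passing `inv_direction_norm = 1.0` reproduces the uncorrected behavior, where the decoded norm drifts with the quantization error. The two outputs differ by exactly the factor `inv_direction_norm`, which is the `new = old * inv_direction_norm[row]` relationship the parity test in the commit message pins down.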

vortex-turboquant/src/scalar_fns/compute/l2_norm.rs

Lines changed: 17 additions & 23 deletions
@@ -11,17 +11,18 @@ use vortex_array::arrays::PrimitiveArray;
 use vortex_array::arrays::ScalarFn;
 use vortex_array::arrays::scalar_fn::ExactScalarFn;
 use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
+use vortex_array::dtype::Nullability;
 use vortex_array::optimizer::kernels::ArrayKernelsExt;
 use vortex_array::optimizer::kernels::ExecuteParentFn;
 use vortex_array::scalar_fn::ScalarFnVTable;
+use vortex_array::validity::Validity;
 use vortex_error::VortexResult;
 use vortex_error::vortex_ensure_eq;
 use vortex_session::VortexSession;
 use vortex_tensor::scalar_fns::l2_norm::L2Norm;
 
 use crate::TQDecode;
-use crate::vector::storage::parse_storage;
-use crate::vtable::TurboQuant;
+use crate::vector::storage::parse_storage_norms_only;
 
 /// Register the `L2Norm(TQDecode(_))` execute-parent kernel on the session.
 pub(super) fn register(session: &VortexSession) {
@@ -34,13 +35,14 @@ pub(super) fn register(session: &VortexSession) {
 
 /// Intercepts `L2Norm(TQDecode(tq_arr))` and returns the stored TurboQuant `norms` field.
 ///
-/// The kernel only fires when both the parent matches `ExactScalarFn<L2Norm>` and the child
-/// matches `ExactScalarFn<TQDecode>`. Returns `Ok(None)` for any other shape so the canonical
-/// `L2Norm` path runs unchanged.
-//
-// This is semantically correct because TurboQuant stores per-row inverse direction norms and
-// `TQDecode` applies that correction before re-applying the original row norm. In other words,
-// valid nonzero decoded rows preserve the stored L2 norm even though coordinates are lossy.
+/// Semantically valid because [`TQDecode`] renormalizes the lossy quantized direction with the
+/// stored inverse direction-norm before re-applying the original row norm, so decoded rows
+/// preserve the stored L2 norm. The kernel returns `Ok(None)` for any non-matching parent /
+/// child pair so the canonical `L2Norm` path runs unchanged.
+///
+/// The result's nullability is coerced to the parent's expected dtype because the stored
+/// `norms` child may be wider than the outer struct (a shape [`parse_storage_norms_only`]
+/// accepts).
 fn l2_norm_tq_decode_execute_parent(
     child: &ArrayRef,
     parent: &ArrayRef,
@@ -55,24 +57,16 @@ fn l2_norm_tq_decode_execute_parent(
     }
 
     let tq_array = child.as_::<ScalarFn>().child_at(0).clone();
+    let parsed = parse_storage_norms_only(tq_array, ctx)?;
 
-    // Defensive: TQDecode's signature already guarantees this, but a misregistration or a
-    // future TQDecode that takes a wrapped child should fall back to the canonical path.
-    if tq_array
-        .dtype()
-        .as_extension_opt()
-        .and_then(|d| d.metadata_opt::<TurboQuant>())
-        .is_none()
-    {
-        return Ok(None);
-    }
-
-    let parsed = parse_storage(tq_array, ctx)?;
-    let norms_validity = parsed.norms.validity()?;
+    let norms_validity = match parent.dtype().nullability() {
+        Nullability::NonNullable => Validity::NonNullable,
+        Nullability::Nullable => parsed.vector_validity,
+    };
     let norms = PrimitiveArray::from_buffer_handle(
         parsed.norms.buffer_handle().clone(),
         parsed.norms.ptype(),
-        norms_validity.and(parsed.vector_validity)?,
+        norms_validity,
     )
     .into_array();
 
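
The nullability coercion in the hunk above can be isolated from the Vortex types. The sketch below uses local stand-in enums named `Nullability` and `Validity`; they are simplified assumptions that only mirror the two cases the kernel matches on and are not the `vortex_array` definitions.

```rust
/// Stand-in for the parent dtype's nullability (simplified, hypothetical).
#[derive(Debug, Clone, PartialEq)]
enum Nullability {
    NonNullable,
    Nullable,
}

/// Stand-in for array validity (simplified, hypothetical).
#[derive(Debug, Clone, PartialEq)]
enum Validity {
    NonNullable,
    Mask(Vec<bool>),
}

/// Picks the validity of the returned norms from the parent's expected
/// nullability instead of from whatever the stored `norms` child carries.
fn coerce_to_parent(parent: &Nullability, vector_validity: Validity) -> Validity {
    match parent {
        Nullability::NonNullable => Validity::NonNullable,
        Nullability::Nullable => vector_validity,
    }
}
```

Driving the result from the parent means a stored child with wider validity than the outer struct no longer produces an array whose dtype disagrees with what `L2Norm` is expected to return.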

vortex-turboquant/src/scalar_fns/compute/mod.rs

Lines changed: 9 additions & 4 deletions
@@ -3,15 +3,20 @@
 
 //! TurboQuant-specific session-scoped optimizer kernels.
 //!
-//! Each kernel module owns its own [`ArrayKernelsExt::register_execute_parent`] call. New
-//! kernels (e.g. for `InnerProduct` or `CosineSimilarity`) should be added as sibling modules
-//! and threaded through [`register_kernels`] with a single line.
+//! Each kernel module owns its own
+//! [`register_execute_parent`](vortex_array::optimizer::kernels::ArrayKernelsExt::register_execute_parent)
+//! call. New kernels (for example `InnerProduct` or `CosineSimilarity`) should be added as
+//! sibling modules and threaded through [`register_kernels`] with a single line.
 
 mod l2_norm;
 
 use vortex_session::VortexSession;
 
-/// Register every TurboQuant kernel on `session`.
+/// Register every TurboQuant-specific optimizer kernel on `session`.
+///
+/// Called from the crate-level [`crate::initialize`] after the TurboQuant extension type and
+/// the `TQEncode` / `TQDecode` scalar functions are registered, so kernels can resolve the
+/// scalar-fn ids they intercept.
 pub(crate) fn register_kernels(session: &VortexSession) {
     l2_norm::register(session);
 }

vortex-turboquant/src/scalar_fns/decode.rs

Lines changed: 12 additions & 0 deletions
@@ -205,12 +205,24 @@ fn build_empty_vector(
     })
 }
 
+/// Borrowed bundle of the per-array decode inputs passed to the typed inner loop.
+///
+/// Packaged as a struct rather than positional arguments because `decode_typed` runs through
+/// [`vortex_array::match_each_float_ptype!`] which expands once per supported element ptype.
+/// Each expansion takes the same set of inputs, and the struct keeps the call site short.
 struct DecodeInputs<'a> {
+    /// TurboQuant metadata recovered from the input extension dtype.
     metadata: &'a TurboQuantMetadata,
+    /// SORF transform reconstructed from `metadata.seed` and `metadata.num_rounds`.
     sorf_matrix: &'a SorfMatrix,
+    /// Centroid codebook for `(padded_dim, bit_width)`, in f32.
     centroids: &'a [f32],
+    /// Per-row stored L2 norm of the original input vector, in the element ptype.
     norms: &'a PrimitiveArray,
+    /// Per-row reciprocal of the decoded direction's L2 norm, always in f32. See
+    /// [`crate::vector::storage`] for the sentinel semantics.
     inv_direction_norms: &'a PrimitiveArray,
+    /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes.
     codes: &'a PrimitiveArray,
 }
 

vortex-turboquant/src/sorf/splitmix64.rs

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,9 @@ const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9;
 /// Second SplitMix64 mixing multiplier from the reference implementation.
 const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB;
 
-/// Frozen local SplitMix64 stream used to define SORF sign diagonals.
+/// Frozen local SplitMix64 stream used to define SORF sign diagonals. Bit-identical to the
+/// reference implementation linked at the module top, which makes the sign stream part of the
+/// encoding's wire contract.
 pub(crate) struct SplitMix64 {
     state: u64,
 }
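
For readers unfamiliar with SplitMix64, the step function of the public reference algorithm is short. The sketch below reproduces that well-known algorithm using the two multipliers visible in the diff and the standard increment constant. It is not the crate's code: the `sign_diagonal` helper and its top-bit sign rule are purely hypothetical, since this hunk does not show how the crate maps output words to signs.

```rust
/// Increment ("golden gamma") used by the reference SplitMix64 algorithm.
const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
/// Mixing multipliers, matching the constants shown in the diff.
const MUL1: u64 = 0xBF58_476D_1CE4_E5B9;
const MUL2: u64 = 0x94D0_49BB_1331_11EB;

/// Minimal SplitMix64 generator following the reference algorithm.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advance the state and return the next mixed 64-bit output.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(GOLDEN_GAMMA);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(MUL1);
        z = (z ^ (z >> 27)).wrapping_mul(MUL2);
        z ^ (z >> 31)
    }
}

/// Hypothetical sign derivation (an assumption, not the crate's rule):
/// one sign per output word, taken from the top bit.
fn sign_diagonal(seed: u64, len: usize) -> Vec<f32> {
    let mut rng = SplitMix64::new(seed);
    (0..len)
        .map(|_| if rng.next_u64() >> 63 == 0 { 1.0 } else { -1.0 })
        .collect()
}
```

Because the generator is a pure function of the seed, the same seed always yields the same sign sequence, which is why the doc comment treats the stream as part of the wire contract: a different stream would make previously written data decode incorrectly.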

vortex-turboquant/src/tests/encode_decode.rs

Lines changed: 124 additions & 0 deletions
@@ -200,6 +200,130 @@ fn decode_preserves_original_l2_norms_for_non_power_of_two_dimensions() -> Vorte
     Ok(())
 }
 
+/// Encode rejects rows whose L2 norm is non-finite. Without this guard, a row whose squared
+/// sum overflows would normalize to all-zero placeholders and decode-vs-kernel would silently
+/// diverge (`NaN` vs `+inf`).
+#[test]
+fn encode_rejects_non_finite_norms() -> VortexResult<()> {
+    let session = test_session();
+    let mut ctx = session.create_execution_ctx();
+
+    // A row of `1e30` repeated `dim=128` times has squared sum `128 * 1e60 ≈ 1.28e62`, which
+    // overflows `f32` (max ≈ 3.4e38) and produces `+inf` when `L2Norm` runs in `f32`.
+    let values = vec![1e30f32; 128];
+    let input = vector_array(128, &values, Validity::NonNullable)?;
+
+    let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx);
+    assert!(
+        result.is_err(),
+        "encode must reject non-finite norms (overflow case)"
+    );
+    let error = result.err().unwrap().to_string();
+    assert!(
+        error.contains("non-finite"),
+        "expected non-finite error, got: {error}"
+    );
+    Ok(())
+}
+
+/// Encode rejects rows containing `NaN` values, which propagate through `L2Norm` to produce
+/// a `NaN` stored norm.
+#[test]
+fn encode_rejects_nan_input() -> VortexResult<()> {
+    let session = test_session();
+    let mut ctx = session.create_execution_ctx();
+
+    let mut values = vec![1.0f32; 128];
+    values[0] = f32::NAN;
+    let input = vector_array(128, &values, Validity::NonNullable)?;
+
+    let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx);
+    assert!(result.is_err(), "encode must reject NaN input rows");
+    Ok(())
+}
+
+/// Decode preserves stored L2 norms across element ptypes and padded/unpadded dimensions.
+#[rstest]
+#[case::f16_dim_128(PType::F16, 128_u32, 1e-2_f32)]
+#[case::f16_dim_129(PType::F16, 129_u32, 1e-2_f32)]
+#[case::f32_dim_128(PType::F32, 128_u32, 1e-4_f32)]
+#[case::f32_dim_129(PType::F32, 129_u32, 1e-4_f32)]
+#[case::f32_dim_257(PType::F32, 257_u32, 1e-4_f32)]
+#[case::f64_dim_128(PType::F64, 128_u32, 1e-4_f32)]
+#[case::f64_dim_129(PType::F64, 129_u32, 1e-4_f32)]
+fn decode_preserves_original_l2_norms_across_ptypes_and_dims(
+    #[case] ptype: PType,
+    #[case] dim: u32,
+    #[case] tolerance: f32,
+) -> VortexResult<()> {
+    let session = tensor_test_session();
+    let mut ctx = session.create_execution_ctx();
+    let rows = 3;
+    let raw = (0..rows * dim as usize)
+        .map(|i| (i % 17) as f32 - 8.0)
+        .map(|v| v * 0.25)
+        .collect::<Vec<_>>();
+    let input = match ptype {
+        PType::F16 => {
+            let values: Vec<half::f16> = raw.iter().copied().map(half::f16::from_f32).collect();
+            vector_array(dim, &values, Validity::NonNullable)?
+        }
+        PType::F32 => vector_array(dim, &raw, Validity::NonNullable)?,
+        PType::F64 => {
+            let values: Vec<f64> = raw.iter().copied().map(f64::from).collect();
+            vector_array(dim, &values, Validity::NonNullable)?
+        }
+        _ => unreachable!("ptype must be float"),
+    };
+    let config = TurboQuantConfig::try_new(3, 42, 3)?;
+
+    let encoded = execute_tq_encode(input, &config, &mut ctx)?;
+    let decoded = execute_tq_decode(encoded, &mut ctx)?;
+    let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, rows)?
+        .into_array()
+        .execute(&mut ctx)?;
+
+    // L2Norm returns the element ptype; widen to f32 for comparison.
+    let actuals: Vec<f32> = match ptype {
+        PType::F16 => decoded_norms
+            .as_slice::<half::f16>()
+            .iter()
+            .map(|v| f32::from(*v))
+            .collect(),
+        PType::F32 => decoded_norms.as_slice::<f32>().to_vec(),
+        PType::F64 => decoded_norms
+            .as_slice::<f64>()
+            .iter()
+            .map(|v| {
+                #[expect(
+                    clippy::cast_possible_truncation,
+                    reason = "norms are bounded by the test's input magnitudes (~|raw| * dim^0.5), \
+                              well within f32 range"
+                )]
+                let widened = *v as f32;
+                widened
+            })
+            .collect(),
+        _ => unreachable!(),
+    };
+
+    // Recompute expected from the raw f32 input to avoid coupling to internal storage.
+    let expected: Vec<f32> = (0..rows)
+        .map(|i| {
+            let row = &raw[i * dim as usize..][..dim as usize];
+            row.iter().map(|v| v * v).sum::<f32>().sqrt()
+        })
+        .collect();
+
+    for (actual, exp) in actuals.iter().zip(expected.iter()) {
+        assert!(
+            (*actual - *exp).abs() <= tolerance * exp.max(1.0),
+            "decoded norm {actual} did not match expected {exp} (ptype {ptype:?}, dim {dim})"
+        );
+    }
+    Ok(())
+}
+
 #[test]
 fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> {
     let session = test_session();
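
The overflow that `encode_rejects_non_finite_norms` exercises is easy to reproduce without Vortex. The sketch below is a standalone illustration, assuming a hypothetical `checked_l2_norm` helper and a plain `String` error; the real encoder's error type is not shown in this diff, and the message wording here is an assumption chosen to mirror the test's `contains("non-finite")` check.

```rust
/// Hypothetical encode-time guard: returns the row's L2 norm computed in f32,
/// or an error when that norm is `+inf`, `-inf`, or `NaN`.
fn checked_l2_norm(row: &[f32]) -> Result<f32, String> {
    // Each square of 1e30 already exceeds f32::MAX (about 3.4e38), so the sum
    // becomes +inf; a NaN element makes the sum NaN.
    let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm.is_finite() {
        Ok(norm)
    } else {
        Err(format!("non-finite L2 norm: {norm}"))
    }
}
```

Rejecting such rows at encode time keeps a non-finite norm from ever being stored, so the decode path and the fast-path kernel never have to agree on what it would mean.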

vortex-turboquant/src/tests/file.rs

Lines changed: 51 additions & 0 deletions
@@ -3,13 +3,15 @@
 
 use vortex_array::IntoArray;
 use vortex_array::VortexSessionExecute;
+use vortex_array::arrays::PrimitiveArray;
 use vortex_array::stream::ArrayStreamExt;
 use vortex_array::validity::Validity;
 use vortex_error::VortexResult;
 use vortex_file::OpenOptionsSessionExt;
 use vortex_file::VortexWriteOptions;
 use vortex_io::runtime::BlockingRuntime;
 use vortex_io::runtime::single::SingleThreadRuntime;
+use vortex_tensor::scalar_fns::l2_norm::L2Norm;
 use vortex_tensor::vector::Vector;
 
 use super::execute_tq_decode_from_metadata;
@@ -19,6 +21,7 @@ use super::file_session;
 use super::vector_validity;
 use crate::TQDecode;
 use crate::TurboQuantConfig;
+use crate::vector::storage::parse_storage;
 use crate::vtable::tq_metadata;
 
 #[test]
@@ -46,6 +49,54 @@ fn file_roundtrip_with_initialize_session() -> VortexResult<()> {
     Ok(())
 }
 
+/// File-roundtrip preserves `inv_direction_norms` and the `L2Norm(TQDecode(_))` fast-path
+/// invariant. A regression that silently dropped the field at serialization would only show
+/// up downstream as norm divergence; this test surfaces it at the IO layer.
+#[test]
+fn file_roundtrip_preserves_inv_direction_norms_and_l2_norm_invariant() -> VortexResult<()> {
+    let runtime = SingleThreadRuntime::default();
+    let session = file_session(&runtime);
+    let mut ctx = session.create_execution_ctx();
+    let input = f32_vector_array(128, 4, 0.25, Validity::NonNullable)?;
+    let config = TurboQuantConfig::try_new(3, 42, 3)?;
+    let encoded = execute_tq_encode(input, &config, &mut ctx)?;
+    let original_norms: PrimitiveArray = parse_storage(encoded.clone(), &mut ctx)?.norms;
+
+    let mut file_bytes = Vec::new();
+    VortexWriteOptions::new(session.clone())
+        .blocking(&runtime)
+        .write(&mut file_bytes, encoded.to_array_iterator())?;
+
+    let file = session.open_options().open_buffer(file_bytes)?;
+    let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?;
+
+    // The inv_direction_norms field must survive serialization with finite-positive values for
+    // every valid row.
+    let parsed = parse_storage(read.clone(), &mut ctx)?;
+    let inv_direction_norms = parsed.inv_direction_norms.as_slice::<f32>();
+    assert_eq!(inv_direction_norms.len(), 4);
+    for &v in inv_direction_norms {
+        assert!(
+            v.is_finite() && v > 0.0,
+            "inv_direction_norm {v} after file roundtrip is not finite-positive"
+        );
+    }
+
+    // Fast-path `L2Norm(TQDecode(_))` must still return the originally stored row norms after
+    // the file roundtrip. If the kernel or the `inv_direction_norms` field had silently broken
+    // at serialization, this is where it would surface.
+    let decoded = TQDecode::try_new_array(read)?.into_array();
+    let kernel_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 4)?
+        .into_array()
+        .execute(&mut ctx)?;
+    assert_eq!(
+        kernel_norms.as_slice::<f32>(),
+        original_norms.as_slice::<f32>(),
+        "L2Norm(TQDecode(read_back)) must equal the originally stored row norms"
+    );
+    Ok(())
+}
+
 #[test]
 fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> {
     let runtime = SingleThreadRuntime::default();
