Skip to content

Commit 58028c4

Browse files
committed
Preserve TurboQuant decoded vector norms
Signed-off-by: "Connor Tsui" <connor.tsui20@gmail.com>
1 parent cb477c2 commit 58028c4

12 files changed

Lines changed: 323 additions & 62 deletions

File tree

vortex-turboquant/src/lib.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,33 @@
1919
//! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector
2020
//! row, then normalizes each valid nonzero row internally before SORF transform and scalar
2121
//! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids,
22-
//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the
23-
//! stored norm.
22+
//! applies the inverse SORF transform, truncates back to the original dimension, and applies a
23+
//! stored inverse direction-norm correction before re-applying the stored norm.
2424
//!
2525
//! The encoded storage is a row-aligned extension tree:
2626
//!
2727
//! ```text
2828
//! Extension<TurboQuant>(
2929
//! Struct {
3030
//! norms: Primitive<element_ptype, vector_validity>,
31+
//! inv_direction_norms: Primitive<f32, vector_validity>,
3132
//! codes: FixedSizeList<Primitive<u8>, padded_dim, vector_validity>,
3233
//! }
3334
//! )
3435
//! ```
3536
//!
36-
//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized
37-
//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform.
37+
//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Scalar quantization
38+
//! perturbs the transformed unit vector, and inverse SORF plus truncation can leave the decoded
39+
//! quantized direction with norm different from `1.0`. If decode only multiplied that direction by
40+
//! the original row norm, the stored norm that the `L2Norm(TQDecode(_))` fast path returns would not equal the actual norm of the vector returned by
41+
//! `TQDecode`. TurboQuant therefore stores `inv_direction_norms = 1 / ||decoded_direction||` so
42+
//! decode can first renormalize the lossy quantized direction and then apply the original norm.
43+
//!
44+
//! Storing the correction also keeps future query kernels cheap. Inner product and cosine kernels can
45+
//! rotate a query once and gather against centroids directly; the per-row scale they need is already
46+
//! available as `norms * inv_direction_norms` for inner product and `inv_direction_norms` for cosine.
47+
//! Without this field, those kernels would have to recompute the norm of the inverse-SORF-transformed, truncated direction per row
48+
//! or give up the `TQDecode` norm-preservation invariant.
3849
//!
3950
//! # Source map
4051
//!

vortex-turboquant/src/scalar_fns/compute/l2_norm.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,9 @@ pub(super) fn register(session: &VortexSession) {
3838
/// matches `ExactScalarFn<TQDecode>`. Returns `Ok(None)` for any other shape so the canonical
3939
/// `L2Norm` path runs unchanged.
4040
//
41-
// TODO(vortex-data/vortex#TODO): The TurboQuant storage `norms` field is pre-quantization — it
42-
// is the L2 norm of each original vector before SORF transform and scalar quantization. The
43-
// lossy contract (see `vortex-turboquant/src/lib.rs`) means decoded vectors are not guaranteed
44-
// to be unit-norm, so strictly `l2_norm(tq_decode(x))` may differ slightly from the stored
45-
// norm. We treat the stored norms as authoritative here for parity with the `L2Denorm` fast
46-
// path in `vortex-tensor/src/scalar_fns/l2_norm.rs`. A future fix should recompute norms
47-
// post-quantization.
41+
// Returning the stored norms is semantically correct because TurboQuant stores per-row inverse direction norms and
42+
// `TQDecode` applies that correction before re-applying the original row norm. In other words,
43+
// valid nonzero decoded rows preserve the stored L2 norm even though coordinates are lossy.
4844
fn l2_norm_tq_decode_execute_parent(
4945
child: &ArrayRef,
5046
parent: &ArrayRef,

vortex-turboquant/src/scalar_fns/decode.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,10 @@ impl ScalarFnVTable for TQDecode {
153153

154154
/// Decode a `TurboQuant` extension array back into a `Vector` extension array.
155155
///
156-
/// The decoded directions are inverse-transformed, truncated to the original dimension, and
157-
/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with
158-
/// [`TQEncode`](crate::TQEncode).
156+
/// The decoded directions are inverse-transformed, truncated to the original dimension, normalized
157+
/// by the stored inverse direction norms, and multiplied by the stored row norms. The conversion is
158+
/// lossy and does not roundtrip with [`TQEncode`](crate::TQEncode), but valid nonzero decoded rows
159+
/// preserve the original stored L2 norm.
159160
pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
160161
let parsed = parse_storage(input, ctx)?;
161162
let metadata = parsed.metadata;
@@ -177,6 +178,7 @@ pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexRe
177178
sorf_matrix: &transform,
178179
centroids: &centroids,
179180
norms: &parsed.norms,
181+
inv_direction_norms: &parsed.inv_direction_norms,
180182
codes: &parsed.codes,
181183
},
182184
parsed.vector_validity,
@@ -208,6 +210,7 @@ struct DecodeInputs<'a> {
208210
sorf_matrix: &'a SorfMatrix,
209211
centroids: &'a [f32],
210212
norms: &'a PrimitiveArray,
213+
inv_direction_norms: &'a PrimitiveArray,
211214
codes: &'a PrimitiveArray,
212215
}
213216

@@ -226,6 +229,7 @@ where
226229
let padded_dim = decode.sorf_matrix.padded_dim();
227230
let centroids = decode.centroids;
228231
let norms = decode.norms.as_slice::<T>();
232+
let inv_direction_norms = decode.inv_direction_norms.as_slice::<f32>();
229233
let codes = decode.codes.as_slice::<u8>();
230234
let mask = vector_validity.execute_mask(num_vectors, ctx)?;
231235

@@ -249,11 +253,12 @@ where
249253
decode.sorf_matrix.inverse_transform(&decoded, &mut inverse);
250254

251255
let norm = norms[i];
256+
let inv_direction_norm = inv_direction_norms[i];
252257
for &value in inverse.iter().take(dimensions) {
253258
// `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`,
254259
// `f64`): values outside `f16` range saturate to `±inf` rather than returning
255260
// `None`.
256-
let value = T::from_f32(value)
261+
let value = T::from_f32(value * inv_direction_norm)
257262
.vortex_expect("from_f32 is infallible for supported float types");
258263

259264
// SAFETY: total pushes across all match arms equal `output_len`.

vortex-turboquant/src/scalar_fns/encode.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use vortex_array::IntoArray;
1212
use vortex_array::arrays::Extension;
1313
use vortex_array::arrays::ExtensionArray;
1414
use vortex_array::arrays::FixedSizeListArray;
15+
use vortex_array::arrays::PrimitiveArray;
1516
use vortex_array::arrays::ScalarFnArray;
1617
use vortex_array::arrays::extension::ExtensionArrayExt;
1718
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
@@ -209,7 +210,14 @@ pub(crate) fn encode_vector(
209210
// SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child.
210211
unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? }
211212
};
212-
let codes = build_codes_child(num_vectors, core, vector_validity.clone())?;
213+
let inv_direction_norms =
214+
PrimitiveArray::new::<f32>(core.inv_direction_norms, vector_validity.clone()).into_array();
215+
let codes = build_codes_child(
216+
num_vectors,
217+
core.all_indices,
218+
core.padded_dim,
219+
vector_validity.clone(),
220+
)?;
213221

214222
let metadata = TurboQuantMetadata {
215223
element_ptype,
@@ -218,7 +226,13 @@ pub(crate) fn encode_vector(
218226
seed: config.seed(),
219227
num_rounds: config.num_rounds(),
220228
};
221-
let storage = build_storage(norms, codes, num_vectors, vector_validity)?;
229+
let storage = build_storage(
230+
norms,
231+
inv_direction_norms,
232+
codes,
233+
num_vectors,
234+
vector_validity,
235+
)?;
222236

223237
Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array())
224238
}

vortex-turboquant/src/tests/encode_decode.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ use vortex_array::dtype::PType;
1616
use vortex_array::validity::Validity;
1717
use vortex_buffer::Buffer;
1818
use vortex_error::VortexResult;
19+
use vortex_tensor::scalar_fns::l2_norm::L2Norm;
1920

2021
use super::execute_tq_decode;
2122
use super::execute_tq_encode;
2223
use super::f32_vector_array;
24+
use super::tensor_test_session;
2325
use super::test_session;
2426
use super::turboquant_storage;
2527
use super::vector_array;
@@ -29,6 +31,7 @@ use super::vector_values_f32;
2931
use crate::TurboQuantConfig;
3032
use crate::centroids::compute_or_get_centroids;
3133
use crate::vector::normalize::tq_normalize_as_l2_denorm;
34+
use crate::vector::storage::parse_storage;
3235

3336
#[rstest]
3437
#[case::zero_bits(0, 42, 3)]
@@ -105,6 +108,10 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> {
105108
.unmasked_field_by_name("norms")?
106109
.clone()
107110
.execute(&mut ctx)?;
111+
let inv_direction_norms: PrimitiveArray = storage
112+
.unmasked_field_by_name("inv_direction_norms")?
113+
.clone()
114+
.execute(&mut ctx)?;
108115
let codes: FixedSizeListArray = storage
109116
.unmasked_field_by_name("codes")?
110117
.clone()
@@ -114,13 +121,21 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> {
114121
assert!(!mask.value(1));
115122
assert!(mask.value(2));
116123
assert_eq!(norms.validity()?.nullability(), Nullability::Nullable);
124+
assert_eq!(
125+
inv_direction_norms.validity()?.nullability(),
126+
Nullability::Nullable
127+
);
117128
assert_eq!(codes.validity()?.nullability(), Nullability::Nullable);
118129

119130
let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?;
131+
let inv_direction_norms_validity = inv_direction_norms.validity()?.execute_mask(3, &mut ctx)?;
120132
let codes_validity = codes.validity()?.execute_mask(3, &mut ctx)?;
121133
assert!(norms_validity.value(0));
122134
assert!(!norms_validity.value(1));
123135
assert!(norms_validity.value(2));
136+
assert!(inv_direction_norms_validity.value(0));
137+
assert!(!inv_direction_norms_validity.value(1));
138+
assert!(inv_direction_norms_validity.value(2));
124139
assert!(codes_validity.value(0));
125140
assert!(!codes_validity.value(1));
126141
assert!(codes_validity.value(2));
@@ -134,6 +149,57 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> {
134149
Ok(())
135150
}
136151

152+
#[test]
153+
fn encode_stores_zero_inv_direction_norm_for_zero_rows() -> VortexResult<()> {
154+
let session = test_session();
155+
let mut ctx = session.create_execution_ctx();
156+
let mut values = vec![0.0f32; 3 * 128];
157+
values[0] = 3.0;
158+
values[1] = 4.0;
159+
values[256] = 1.0;
160+
let input = vector_array(128, &values, Validity::NonNullable)?;
161+
162+
let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?;
163+
let storage = turboquant_storage(encoded, &mut ctx)?;
164+
let inv_direction_norms: PrimitiveArray = storage
165+
.unmasked_field_by_name("inv_direction_norms")?
166+
.clone()
167+
.execute(&mut ctx)?;
168+
169+
let values = inv_direction_norms.as_slice::<f32>();
170+
assert!(values[0].is_finite() && values[0] > 0.0);
171+
assert_eq!(values[1], 0.0);
172+
assert!(values[2].is_finite() && values[2] > 0.0);
173+
Ok(())
174+
}
175+
176+
#[test]
177+
fn decode_preserves_original_l2_norms_for_non_power_of_two_dimensions() -> VortexResult<()> {
178+
let session = tensor_test_session();
179+
let mut ctx = session.create_execution_ctx();
180+
let input = f32_vector_array(129, 3, 0.25, Validity::NonNullable)?;
181+
let config = TurboQuantConfig::try_new(3, 42, 3)?;
182+
183+
let encoded = execute_tq_encode(input, &config, &mut ctx)?;
184+
let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms;
185+
let decoded = execute_tq_decode(encoded, &mut ctx)?;
186+
let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 3)?
187+
.into_array()
188+
.execute(&mut ctx)?;
189+
190+
for (actual, expected) in decoded_norms
191+
.as_slice::<f32>()
192+
.iter()
193+
.zip(expected_norms.as_slice::<f32>())
194+
{
195+
assert!(
196+
(*actual - *expected).abs() <= 1e-4 * expected.max(1.0),
197+
"decoded norm {actual} did not match stored norm {expected}"
198+
);
199+
}
200+
Ok(())
201+
}
202+
137203
#[test]
138204
fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> {
139205
let session = test_session();

vortex-turboquant/src/tests/kernels.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ const DIM: u32 = 128;
3030

3131
/// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit.
3232
///
33-
/// The slow path would recompute norms from lossily decoded vectors, which only approximately
34-
/// match the stored norms. Bit-exact equality is the strongest invariant that confirms the
35-
/// session-registered kernel fired.
33+
/// `TQDecode` applies the stored inverse direction-norm correction, so decoded vectors preserve
34+
/// these norms. Bit-exact equality is the strongest invariant that confirms the session-registered
35+
/// kernel fired instead of recomputing.
3636
#[test]
3737
fn l2_norm_over_tq_decode_returns_stored_norms() -> VortexResult<()> {
3838
let session = tensor_test_session();

0 commit comments

Comments
 (0)