Skip to content

Commit 5285a13

Browse files
committed
change defaults and constraints and tests
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d4c1cdf commit 5285a13

8 files changed

Lines changed: 201 additions & 38 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ impl vortex_tensor::encodings::turboquant::TurboQuant
1010

1111
pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId
1212

13+
pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32
14+
1315
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
1416

1517
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef>

vortex-tensor/src/encodings/turboquant/array/centroids.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use vortex_error::VortexResult;
1515
use vortex_error::vortex_bail;
1616
use vortex_utils::aliases::dash_map::DashMap;
1717

18+
use crate::encodings::turboquant::TurboQuant;
19+
1820
/// Number of numerical integration points for computing conditional expectations.
1921
const INTEGRATION_POINTS: usize = 1000;
2022

@@ -36,8 +38,11 @@ pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
3638
if !(1..=8).contains(&bit_width) {
3739
vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}");
3840
}
39-
if dimension < 3 {
40-
vortex_bail!("TurboQuant dimension must be >= 3, got {dimension}");
41+
if dimension < TurboQuant::MIN_DIMENSION {
42+
vortex_bail!(
43+
"TurboQuant dimension must be >= {}, got {dimension}",
44+
TurboQuant::MIN_DIMENSION
45+
);
4146
}
4247

4348
if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) {
@@ -306,6 +311,6 @@ mod tests {
306311
assert!(get_centroids(128, 0).is_err());
307312
assert!(get_centroids(128, 9).is_err());
308313
assert!(get_centroids(1, 2).is_err());
309-
assert!(get_centroids(2, 2).is_err());
314+
assert!(get_centroids(127, 2).is_err());
310315
}
311316
}

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ impl TurboQuantData {
9292
/// The caller must ensure:
9393
///
9494
/// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size
95-
/// is >= 3.
95+
/// is >= [`MIN_DIMENSION`](crate::encodings::turboquant::TurboQuant::MIN_DIMENSION).
9696
/// - `codes` is a non-nullable `FixedSizeListArray<u8>` with `list_size == padded_dim` and
9797
/// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes.
9898
/// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage

vortex-tensor/src/encodings/turboquant/array/scheme.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ mod tests {
114114
/// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x.
115115
/// f32 input at 1024-d (no padding) should give higher ratio since no waste.
116116
#[rstest]
117-
#[case::f32_768d(32, 768, 1000, 3.5, 8.0)]
118-
#[case::f32_1024d(32, 1024, 1000, 5.0, 9.0)]
119-
#[case::f32_1536d(32, 1536, 1000, 3.0, 8.0)]
120-
#[case::f32_128d(32, 128, 1000, 4.0, 8.0)]
121-
#[case::f64_768d(64, 768, 1000, 7.0, 16.0)]
122-
#[case::f16_768d(16, 768, 1000, 1.5, 4.5)]
117+
#[case::f32_768d(32, 768, 1000, 2.5, 4.0)]
118+
#[case::f32_1024d(32, 1024, 1000, 3.5, 5.0)]
119+
#[case::f32_1536d(32, 1536, 1000, 2.5, 4.0)]
120+
#[case::f32_128d(32, 128, 1000, 3.0, 5.0)]
121+
#[case::f64_768d(64, 768, 1000, 5.0, 7.0)]
122+
#[case::f16_768d(16, 768, 1000, 1.2, 2.0)]
123123
fn compression_ratio_in_expected_range(
124124
#[case] bits_per_element: usize,
125125
#[case] dim: u32,

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pub struct TurboQuantConfig {
4343
impl Default for TurboQuantConfig {
4444
fn default() -> Self {
4545
Self {
46-
bit_width: 4,
46+
bit_width: 8,
4747
seed: Some(42),
4848
}
4949
}
@@ -226,8 +226,9 @@ pub fn turboquant_encode(
226226
);
227227
let dimension = fsl.list_size();
228228
vortex_ensure!(
229-
dimension >= 3,
230-
"TurboQuant requires dimension >= 3, got {dimension}"
229+
dimension >= TurboQuant::MIN_DIMENSION,
230+
"TurboQuant requires dimension >= {}, got {dimension}",
231+
TurboQuant::MIN_DIMENSION
231232
);
232233

233234
if fsl.is_empty() {

vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
//! Approximate cosine similarity in the quantized domain.
55
//!
66
//! Since the SRHT is orthogonal, inner products are preserved in the rotated
7-
//! domain. For two vectors from the same TurboQuant column (same rotation and
8-
//! centroids), we can compute the dot product of their quantized representations
9-
//! without full decompression:
7+
//! domain. For two TurboQuant arrays that share the same SRHT rotation (i.e.,
8+
//! encoded from the same column), we can compute the dot product of their
9+
//! quantized representations without full decompression:
1010
//!
1111
//! ```text
1212
//! cos_approx(a, b) = sum(centroids[code_a[j]] × centroids[code_b[j]])
@@ -85,8 +85,12 @@ fn compute_unit_dots(
8585
Ok(dots)
8686
}
8787

88-
/// Compute approximate cosine similarity for all rows between two TurboQuant
89-
/// arrays (same rotation matrix and codebook) without full decompression.
88+
/// Compute approximate cosine similarity for all rows between two TurboQuant arrays without
89+
/// full decompression.
90+
///
91+
/// Both arrays must share the same rotation (i.e., were encoded from the same TurboQuant
92+
/// column). For this function, results are meaningless if the rotations differ (there are other
93+
/// methods that can allow this, but that is future work).
9094
///
9195
/// Since TurboQuant stores unit-normalized rotated vectors, the dot product of the quantized
9296
/// codes directly approximates cosine similarity without needing the stored norms.
@@ -120,8 +124,12 @@ pub fn cosine_similarity_quantized_column(
120124
})
121125
}
122126

123-
/// Compute approximate dot product for all rows between two TurboQuant
124-
/// arrays (same rotation matrix and codebook) without full decompression.
127+
/// Compute approximate dot product for all rows between two TurboQuant arrays without
128+
/// full decompression.
129+
///
130+
/// Both arrays must share the same SRHT rotation (i.e., were encoded from the same TurboQuant
131+
/// column). For this function, results are meaningless if the rotations differ (there are other
132+
/// methods that can allow this, but that is future work).
125133
///
126134
/// `dot_product(a, b) = ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])`
127135
///

vortex-tensor/src/encodings/turboquant/tests.rs

Lines changed: 155 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,9 @@ fn encode_decode(
152152
// -----------------------------------------------------------------------
153153

154154
#[rstest]
155-
#[case(32, 1)]
156-
#[case(32, 2)]
157-
#[case(32, 3)]
158-
#[case(32, 4)]
155+
#[case(128, 1)]
159156
#[case(128, 2)]
157+
#[case(128, 3)]
160158
#[case(128, 4)]
161159
#[case(128, 6)]
162160
#[case(128, 8)]
@@ -280,8 +278,9 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> {
280278

281279
#[rstest]
282280
#[case(1)]
283-
#[case(2)]
284-
fn rejects_dimension_below_3(#[case] dim: usize) {
281+
#[case(64)]
282+
#[case(127)]
283+
fn rejects_dimension_below_128(#[case] dim: usize) {
285284
let fsl = make_fsl_small(dim);
286285
let ext = make_vector_ext(&fsl);
287286
let config = TurboQuantConfig {
@@ -340,7 +339,7 @@ fn all_zero_vectors_roundtrip() -> VortexResult<()> {
340339
#[test]
341340
fn f64_input_encodes_successfully() -> VortexResult<()> {
342341
let num_rows = 10;
343-
let dim = 64;
342+
let dim = 128;
344343
let mut rng = StdRng::seed_from_u64(99);
345344
let normal = Normal::new(0.0f64, 1.0).unwrap();
346345

@@ -371,6 +370,48 @@ fn f64_input_encodes_successfully() -> VortexResult<()> {
371370
Ok(())
372371
}
373372

373+
/// Verify that f16 input is accepted and encoded (upcast to f32 internally).
374+
#[test]
375+
fn f16_input_encodes_successfully() -> VortexResult<()> {
376+
let num_rows = 10;
377+
let dim = 128;
378+
let mut rng = StdRng::seed_from_u64(99);
379+
let normal = Normal::new(0.0f32, 1.0).unwrap();
380+
381+
let mut buf = BufferMut::<half::f16>::with_capacity(num_rows * dim);
382+
for _ in 0..(num_rows * dim) {
383+
buf.push(half::f16::from_f32(normal.sample(&mut rng)));
384+
}
385+
let elements = PrimitiveArray::new::<half::f16>(buf.freeze(), Validity::NonNullable);
386+
let fsl = FixedSizeListArray::try_new(
387+
elements.into_array(),
388+
dim.try_into()
389+
.expect("somehow got dimension greater than u32::MAX"),
390+
Validity::NonNullable,
391+
num_rows,
392+
)?;
393+
394+
let ext = make_vector_ext(&fsl);
395+
let config = TurboQuantConfig {
396+
bit_width: 3,
397+
seed: Some(42),
398+
};
399+
let mut ctx = SESSION.create_execution_ctx();
400+
let encoded = turboquant_encode(&ext, &config, &mut ctx)?;
401+
let tq = encoded.as_opt::<TurboQuant>().unwrap();
402+
assert_eq!(tq.norms().len(), num_rows);
403+
assert_eq!(tq.dimension() as usize, dim);
404+
405+
// Verify roundtrip: decode and check reconstruction is reasonable.
406+
let decoded_ext = encoded.execute::<ExtensionArray>(&mut ctx)?;
407+
let decoded_fsl = decoded_ext
408+
.storage_array()
409+
.to_canonical()?
410+
.into_fixed_size_list();
411+
assert_eq!(decoded_fsl.len(), num_rows);
412+
Ok(())
413+
}
414+
374415
// -----------------------------------------------------------------------
375416
// Verification tests for stored metadata
376417
// -----------------------------------------------------------------------
@@ -494,7 +535,7 @@ fn slice_preserves_data() -> VortexResult<()> {
494535

495536
#[test]
496537
fn scalar_at_matches_decompress() -> VortexResult<()> {
497-
let fsl = make_fsl(10, 64, 42);
538+
let fsl = make_fsl(10, 128, 42);
498539
let ext = make_vector_ext(&fsl);
499540
let config = TurboQuantConfig {
500541
bit_width: 3,
@@ -593,7 +634,9 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> {
593634
.sum::<f32>()
594635
};
595636

596-
// 4-bit quantization: expect reasonable accuracy.
637+
// At 4-bit, the theoretical MSE bound per coordinate is ~0.0106 (Theorem 1). For cosine
638+
// similarity (bounded [-1, 1]), the error is bounded roughly by 2*sqrt(MSE) ~ 0.2. We use
639+
// 0.15 as a tighter empirical bound.
597640
let error = (exact_cos - approx_cos).abs();
598641
assert!(
599642
error < 0.15,
@@ -604,6 +647,105 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> {
604647
Ok(())
605648
}
606649

650+
/// Verify approximate dot product in the quantized domain.
651+
///
652+
/// NOTE: The MSE quantizer (TurboQuant_mse) has inherent **multiplicative bias** for inner
653+
/// products — the quantized dot product systematically over- or under-estimates the true value.
654+
/// This is a fundamental property: the paper's `TurboQuant_prod` variant adds QJL specifically
655+
/// to debias inner products, but we only implement the MSE-only variant.
656+
///
657+
/// Even at 8-bit (near-lossless reconstruction, MSE ~4e-5), the quantized-domain dot product
658+
/// can have ~10-15% relative error due to this bias. This tolerance is therefore intentionally
659+
/// loose — we're testing that the approximation is in the right ballpark, not that it's precise.
660+
///
661+
/// TODO(connor): Revisit these tolerances when we have TurboQuant_prod (QJL debiasing).
662+
#[test]
663+
fn dot_product_quantized_accuracy() -> VortexResult<()> {
664+
let fsl = make_fsl(20, 128, 42);
665+
let ext = make_vector_ext(&fsl);
666+
let config = TurboQuantConfig {
667+
bit_width: 8,
668+
seed: Some(123),
669+
};
670+
let mut ctx = SESSION.create_execution_ctx();
671+
let encoded = turboquant_encode(&ext, &config, &mut ctx)?;
672+
let tq = encoded.as_opt::<TurboQuant>().unwrap();
673+
674+
let input_prim = fsl.elements().to_canonical()?.into_primitive();
675+
let input_f32 = input_prim.as_slice::<f32>();
676+
677+
let mut ctx = SESSION.create_execution_ctx();
678+
let pd = tq.padded_dim() as usize;
679+
let norms_prim = tq.norms().clone().execute::<PrimitiveArray>(&mut ctx)?;
680+
let norms = norms_prim.as_slice::<f32>();
681+
let codes_fsl = tq.codes().clone().execute::<FixedSizeListArray>(&mut ctx)?;
682+
let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive();
683+
let all_codes = codes_prim.as_slice::<u8>();
684+
let centroids_prim = tq.centroids().clone().execute::<PrimitiveArray>(&mut ctx)?;
685+
let centroid_vals = centroids_prim.as_slice::<f32>();
686+
687+
for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] {
688+
let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128];
689+
let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128];
690+
691+
let exact_dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum();
692+
693+
let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd];
694+
let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd];
695+
let unit_dot: f32 = codes_a
696+
.iter()
697+
.zip(codes_b.iter())
698+
.map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize])
699+
.sum();
700+
let approx_dot = norms[row_a] * norms[row_b] * unit_dot;
701+
702+
// See doc comment above: 15% relative error is expected due to MSE quantizer bias.
703+
let scale = exact_dot.abs().max(1.0);
704+
let rel_error = (exact_dot - approx_dot).abs() / scale;
705+
assert!(
706+
rel_error < 0.15,
707+
"dot product error too large for ({row_a}, {row_b}): \
708+
exact={exact_dot:.4}, approx={approx_dot:.4}, rel_error={rel_error:.4}"
709+
);
710+
}
711+
Ok(())
712+
}
713+
714+
/// Roundtrip at large embedding dimensions to validate padding and SRHT at common sizes.
715+
///
716+
/// NOTE: The theoretical MSE bound (Theorem 1) is proved for Haar-distributed random orthogonal
717+
/// matrices, not SRHT. The SRHT is a practical O(d log d) approximation that doesn't exactly
718+
/// satisfy the Haar assumption, so empirical MSE can slightly exceed the theoretical bound. We
719+
/// use a 2x multiplier to account for this gap.
720+
///
721+
/// The 1024-d case uses 5-bit instead of 4-bit because at 4-bit the SRHT approximation error
722+
/// at d=1024 pushes MSE ~20% above the 1x theoretical bound (0.0127 vs bound 0.0106).
723+
///
724+
/// TODO(connor): Revisit after Stage 2 block decomposition — at d=768 with block_size=256,
725+
/// the per-block SRHT will be lower-dimensional and may have different error characteristics.
726+
#[rstest]
727+
#[case(768, 4)]
728+
#[case(1024, 5)]
729+
fn large_dimension_roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> {
730+
let num_rows = 10;
731+
let fsl = make_fsl(num_rows, dim, 42);
732+
let config = TurboQuantConfig {
733+
bit_width,
734+
seed: Some(123),
735+
};
736+
let (original, decoded) = encode_decode(&fsl, &config)?;
737+
assert_eq!(decoded.len(), original.len());
738+
739+
let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows);
740+
// 2x slack for the SRHT-vs-Haar gap (see doc comment above).
741+
let bound = 2.0 * theoretical_mse_bound(bit_width);
742+
assert!(
743+
normalized_mse < bound,
744+
"Normalized MSE {normalized_mse:.6} exceeds 2x bound {bound:.6} for dim={dim}, bits={bit_width}",
745+
);
746+
Ok(())
747+
}
748+
607749
/// Verify that the encoded array's dtype is a Vector extension type.
608750
#[test]
609751
fn encoded_dtype_is_vector_extension() -> VortexResult<()> {
@@ -702,7 +844,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> {
702844
#[test]
703845
fn nullable_norms_match_validity() -> VortexResult<()> {
704846
let validity = Validity::from_iter([true, false, true, false, true]);
705-
let fsl = make_fsl_with_validity(5, 64, 42, validity);
847+
let fsl = make_fsl_with_validity(5, 128, 42, validity);
706848
let ext = make_vector_ext(&fsl);
707849

708850
let config = TurboQuantConfig {
@@ -729,7 +871,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> {
729871
#[test]
730872
fn nullable_l2_norm_readthrough() -> VortexResult<()> {
731873
let validity = Validity::from_iter([true, false, true, false, true]);
732-
let fsl = make_fsl_with_validity(5, 64, 42, validity);
874+
let fsl = make_fsl_with_validity(5, 128, 42, validity);
733875
let ext = make_vector_ext(&fsl);
734876

735877
let config = TurboQuantConfig {
@@ -749,7 +891,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> {
749891
for row in 0..5 {
750892
if row % 2 == 0 {
751893
assert!(norms.is_valid(row)?, "row {row} should be valid");
752-
let expected: f32 = orig_f32[row * 64..(row + 1) * 64]
894+
let expected: f32 = orig_f32[row * 128..(row + 1) * 128]
753895
.iter()
754896
.map(|&v| v * v)
755897
.sum::<f32>()
@@ -773,7 +915,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {
773915
let validity = Validity::from_iter([
774916
true, true, false, true, true, false, true, false, true, true,
775917
]);
776-
let fsl = make_fsl_with_validity(10, 64, 42, validity);
918+
let fsl = make_fsl_with_validity(10, 128, 42, validity);
777919
let ext = make_vector_ext(&fsl);
778920

779921
let config = TurboQuantConfig {

0 commit comments

Comments
 (0)