Skip to content

Commit ff12040

Browse files
authored
TurboQuant again! (#7829)
## Summary

Tracking issue: #7830

Moves TurboQuant out of `vortex-tensor` into a new `vortex-turboquant` crate. The first commit was mostly copying and pasting a bunch of code, as well as adding the `unpack` method to replace canonicalization. The second commit was cleaning up everything holistically.

A lot of the code in `vortex-tensor` was reviewed pretty lightly because we knew that it was unstable, but now that we are more certain about the implementation (not necessarily about the exact design, but the actual implementation of the TQ algorithms), I think it is worth reviewing everything as a whole.

## Testing

These tests were mostly there before, but now there are more!

---------

Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent f3d5f09 commit ff12040

35 files changed

Lines changed: 3890 additions & 70 deletions

Cargo.lock

Lines changed: 22 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ members = [
1313
"vortex-proto",
1414
"vortex-array",
1515
"vortex-tensor",
16+
"vortex-turboquant",
1617
"vortex-compressor",
1718
"vortex-btrblocks",
1819
"vortex-layout",
@@ -296,6 +297,7 @@ vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-feat
296297
vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false }
297298
vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false }
298299
vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false }
300+
vortex-turboquant = { version = "0.1.0", path = "./vortex-turboquant", default-features = false }
299301
vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false }
300302
vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false }
301303
vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false }

vortex-tensor/public-api.lock

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,10 @@ pub fn vortex_tensor::vector::AnyVector::try_match<'a>(&'a vortex_array::dtype::
528528

529529
pub struct vortex_tensor::vector::Vector
530530

531+
impl vortex_tensor::vector::Vector
532+
533+
pub fn vortex_tensor::vector::Vector::try_new_vector_array(vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
534+
531535
impl core::clone::Clone for vortex_tensor::vector::Vector
532536

533537
pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,8 @@ fn turboquant_quantize_core(
205205
let dimension = fsl.list_size() as usize;
206206
let num_rows = fsl.len();
207207

208-
let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?;
209-
let padded_dim = rotation.padded_dim();
208+
let padded_dim = dimension.next_power_of_two();
209+
let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?;
210210
let padded_dim_u32 =
211211
u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
212212

vortex-tensor/src/encodings/turboquant/tests/structural.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ fn sorf_transform_roundtrip_isolation() -> VortexResult<()> {
259259
}
260260

261261
// Forward transform + quantize (mimicking what turboquant_quantize_core does).
262-
let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?;
263-
let padded_dim = rotation.padded_dim();
262+
let padded_dim = dim.next_power_of_two();
263+
let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?;
264264
let centroids = compute_or_get_centroids(padded_dim as u32, 8)?;
265265
let boundaries = compute_centroid_boundaries(&centroids);
266266

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ impl InnerProduct {
366366
let mut padded_query = vec![0.0f32; padded_dim];
367367
padded_query[..dim].copy_from_slice(flat.as_slice::<f32>());
368368

369-
let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?;
369+
let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?;
370370
let mut rotated_query = vec![0.0f32; padded_dim];
371371
rotation.rotate(&padded_query, &mut rotated_query);
372372

@@ -930,7 +930,7 @@ mod tests {
930930
seed: u64,
931931
num_rounds: u8,
932932
) -> VortexResult<Vec<f32>> {
933-
let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?;
933+
let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?;
934934
let mut padded = vec![0.0f32; padded_dim];
935935
let mut rotated = vec![0.0f32; padded_dim];
936936
let mut out = Vec::with_capacity(num_rows * dim);

vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs

Lines changed: 105 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,25 @@ impl SorfMatrix {
7777
/// round-major, block-major order, with each `u64` contributing 64 sign bits in
7878
/// least-significant-bit-first order.
7979
pub fn try_new(seed: u64, dimensions: usize, num_rounds: usize) -> VortexResult<Self> {
80+
Self::try_new_padded(dimensions.next_power_of_two(), num_rounds, seed)
81+
}
82+
83+
/// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension.
84+
///
85+
/// `padded_dimensions` must already be a power of two. Callers that start from an unpadded
86+
/// logical dimension should call [`Self::try_new`] instead.
87+
pub(crate) fn try_new_padded(
88+
padded_dimensions: usize,
89+
num_rounds: usize,
90+
seed: u64,
91+
) -> VortexResult<Self> {
8092
vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}");
93+
vortex_ensure!(
94+
padded_dimensions.is_power_of_two(),
95+
"padded_dimensions must be a power of two, got {padded_dimensions}"
96+
);
8197

82-
let padded_dim = dimensions.next_power_of_two();
98+
let padded_dim = padded_dimensions;
8399
let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds);
84100

85101
// Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers.
@@ -132,8 +148,7 @@ impl SorfMatrix {
132148
/// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`.
133149
fn apply_srht(&self, buf: &mut [f32]) {
134150
for round in 0..self.num_rounds {
135-
let offset = round * self.padded_dim;
136-
apply_signs_xor(buf, &self.sign_masks[offset..offset + self.padded_dim]);
151+
self.apply_signs_xor(buf, round);
137152
walsh_hadamard_transform(buf);
138153
}
139154

@@ -148,14 +163,24 @@ impl SorfMatrix {
148163
fn apply_inverse_srht(&self, buf: &mut [f32]) {
149164
for round in (0..self.num_rounds).rev() {
150165
walsh_hadamard_transform(buf);
151-
let offset = round * self.padded_dim;
152-
apply_signs_xor(buf, &self.sign_masks[offset..offset + self.padded_dim]);
166+
self.apply_signs_xor(buf, round);
153167
}
154168

155169
let norm = self.norm_factor;
156170
buf.iter_mut().for_each(|val| *val *= norm);
157171
}
158172

173+
/// Apply one round's sign masks via XOR on the IEEE 754 sign bit.
174+
///
175+
/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to
176+
/// multiplying each element by +/-1.0, but avoids FP dependency chains.
177+
fn apply_signs_xor(&self, buf: &mut [f32], round: usize) {
178+
let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim];
179+
for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
180+
*val = f32::from_bits(val.to_bits() ^ mask);
181+
}
182+
}
183+
159184
/// Export the sign vectors as a flat `Vec<u8>` of 0/1 values in inverse application order
160185
/// `[D_k | ... | D₁]`.
161186
///
@@ -263,16 +288,6 @@ fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 {
263288
}
264289
}
265290

266-
/// Apply sign masks via XOR on the IEEE 754 sign bit.
267-
///
268-
/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to
269-
/// multiplying each element by +/-1.0, but avoids FP dependency chains.
270-
fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
271-
for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
272-
*val = f32::from_bits(val.to_bits() ^ mask);
273-
}
274-
}
275-
276291
/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative.
277292
///
278293
/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2`
@@ -327,14 +342,24 @@ mod tests {
327342
.collect()
328343
}
329344

345+
fn dim_to_usize(dim: u32) -> usize {
346+
usize::try_from(dim).unwrap()
347+
}
348+
349+
fn rounds_to_usize(num_rounds: u8) -> usize {
350+
usize::from(num_rounds)
351+
}
352+
330353
#[test]
331354
fn deterministic_from_seed() -> VortexResult<()> {
332-
let r1 = SorfMatrix::try_new(42, 64, 3)?;
333-
let r2 = SorfMatrix::try_new(42, 64, 3)?;
355+
let dim = dim_to_usize(64u32);
356+
let num_rounds = rounds_to_usize(3u8);
357+
let r1 = SorfMatrix::try_new(42u64, dim, num_rounds)?;
358+
let r2 = SorfMatrix::try_new(42u64, dim, num_rounds)?;
334359
let pd = r1.padded_dim();
335360

336361
let mut input = vec![0.0f32; pd];
337-
for i in 0..64 {
362+
for i in 0..dim {
338363
input[i] = i as f32;
339364
}
340365
let mut out1 = vec![0.0f32; pd];
@@ -349,41 +374,58 @@ mod tests {
349374

350375
#[test]
351376
fn export_inverse_signs_matches_golden_words() -> VortexResult<()> {
352-
let rot = SorfMatrix::try_new(42, 64, 2)?;
377+
let dim = dim_to_usize(64u32);
378+
let num_rounds = rounds_to_usize(2u8);
379+
let seed = 42u64;
380+
let rot = SorfMatrix::try_new(seed, dim, num_rounds)?;
381+
let padded_dim = rot.padded_dim();
353382
let actual = rot.export_inverse_signs_u8();
354-
let mut rng = SplitMix64::new(42);
383+
let mut rng = SplitMix64::new(seed);
355384
let round0_word = rng.next_u64();
356385
let round1_word = rng.next_u64();
357386

358-
let mut expected = Vec::with_capacity(128);
359-
expected.extend(unpack_sign_bits(round1_word, 64));
360-
expected.extend(unpack_sign_bits(round0_word, 64));
387+
let mut expected = Vec::with_capacity(num_rounds * padded_dim);
388+
expected.extend(unpack_sign_bits(round1_word, padded_dim));
389+
expected.extend(unpack_sign_bits(round0_word, padded_dim));
361390

362391
assert_eq!(actual, expected);
363392
Ok(())
364393
}
365394

366395
#[test]
367396
fn one_word_generates_64_signs_lsb_first() {
368-
let masks = gen_sign_masks_from_seed(42, 64, 1);
369-
assert_eq!(masks.len(), 64);
397+
let seed = 42u64;
398+
let padded_dim = dim_to_usize(64u32);
399+
let num_rounds = rounds_to_usize(1u8);
400+
let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds);
401+
assert_eq!(masks.len(), padded_dim);
370402

371-
let mut rng = SplitMix64::new(42);
403+
let mut rng = SplitMix64::new(seed);
372404
let word = rng.next_u64();
373-
let expected: Vec<_> = (0..64)
405+
let expected: Vec<_> = (0..padded_dim)
374406
.map(|bit_idx| sign_mask_from_word(word, bit_idx))
375407
.collect();
376408
assert_eq!(masks, expected);
377409
}
378410

411+
#[test]
412+
fn accepts_non_power_of_two_dimensions() -> VortexResult<()> {
413+
let rot = SorfMatrix::try_new(42u64, dim_to_usize(100u32), rounds_to_usize(3u8))?;
414+
assert_eq!(rot.padded_dim(), 128);
415+
Ok(())
416+
}
417+
379418
#[test]
380419
fn tail_block_uses_only_required_bits() {
381-
let masks = gen_sign_masks_from_seed(42, 32, 1);
382-
assert_eq!(masks.len(), 32);
420+
let seed = 42u64;
421+
let padded_dim = dim_to_usize(32u32);
422+
let num_rounds = rounds_to_usize(1u8);
423+
let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds);
424+
assert_eq!(masks.len(), padded_dim);
383425

384-
let mut rng = SplitMix64::new(42);
426+
let mut rng = SplitMix64::new(seed);
385427
let word = rng.next_u64();
386-
let expected: Vec<_> = (0..32)
428+
let expected: Vec<_> = (0..padded_dim)
387429
.map(|bit_idx| sign_mask_from_word(word, bit_idx))
388430
.collect();
389431
assert_eq!(masks, expected);
@@ -392,19 +434,21 @@ mod tests {
392434
/// Verify roundtrip is exact to f32 precision across many dimensions and round counts,
393435
/// including non-power-of-two dimensions that require padding.
394436
#[rstest]
395-
#[case(32, 3)]
396-
#[case(64, 3)]
397-
#[case(100, 3)]
398-
#[case(128, 1)]
399-
#[case(128, 2)]
400-
#[case(128, 3)]
401-
#[case(128, 5)]
402-
#[case(256, 3)]
403-
#[case(512, 3)]
404-
#[case(768, 3)]
405-
#[case(1024, 3)]
406-
fn roundtrip_exact(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> {
407-
let rot = SorfMatrix::try_new(42, dim, num_rounds)?;
437+
#[case(32u32, 3u8)]
438+
#[case(64u32, 3u8)]
439+
#[case(100u32, 3u8)]
440+
#[case(128u32, 1u8)]
441+
#[case(128u32, 2u8)]
442+
#[case(128u32, 3u8)]
443+
#[case(128u32, 5u8)]
444+
#[case(256u32, 3u8)]
445+
#[case(512u32, 3u8)]
446+
#[case(768u32, 3u8)]
447+
#[case(1024u32, 3u8)]
448+
fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> {
449+
let dim = dim_to_usize(dim);
450+
let num_rounds = rounds_to_usize(num_rounds);
451+
let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?;
408452
let padded_dim = rot.padded_dim();
409453

410454
let mut input = vec![0.0f32; padded_dim];
@@ -435,12 +479,14 @@ mod tests {
435479

436480
/// Verify norm preservation across dimensions and round counts.
437481
#[rstest]
438-
#[case(128, 1)]
439-
#[case(128, 3)]
440-
#[case(128, 5)]
441-
#[case(768, 3)]
442-
fn preserves_norm(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> {
443-
let rot = SorfMatrix::try_new(7, dim, num_rounds)?;
482+
#[case(128u32, 1u8)]
483+
#[case(128u32, 3u8)]
484+
#[case(128u32, 5u8)]
485+
#[case(768u32, 3u8)]
486+
fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> {
487+
let dim = dim_to_usize(dim);
488+
let num_rounds = rounds_to_usize(num_rounds);
489+
let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?;
444490
let padded_dim = rot.padded_dim();
445491

446492
let mut input = vec![0.0f32; padded_dim];
@@ -465,16 +511,15 @@ mod tests {
465511

466512
/// Verify that export -> [`from_u8_slice`] produces identical transform output.
467513
#[rstest]
468-
#[case(64, 3)]
469-
#[case(128, 1)]
470-
#[case(128, 3)]
471-
#[case(128, 5)]
472-
#[case(768, 3)]
473-
fn sign_export_import_roundtrip(
474-
#[case] dim: usize,
475-
#[case] num_rounds: usize,
476-
) -> VortexResult<()> {
477-
let rot = SorfMatrix::try_new(42, dim, num_rounds)?;
514+
#[case(64u32, 3u8)]
515+
#[case(128u32, 1u8)]
516+
#[case(128u32, 3u8)]
517+
#[case(128u32, 5u8)]
518+
#[case(768u32, 3u8)]
519+
fn sign_export_import_roundtrip(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> {
520+
let dim = dim_to_usize(dim);
521+
let num_rounds = rounds_to_usize(num_rounds);
522+
let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?;
478523
let padded_dim = rot.padded_dim();
479524

480525
let signs_u8 = rot.export_inverse_signs_u8();

vortex-tensor/src/scalar_fns/sorf_transform/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ fn forward_rotate_and_quantize(
6565
}
6666
}
6767

68-
let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?;
69-
let padded_dim = rotation.padded_dim();
68+
let padded_dim = dim.next_power_of_two();
69+
let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?;
7070
let centroids = compute_or_get_centroids(padded_dim as u32, bit_width)?;
7171
let boundaries = compute_centroid_boundaries(&centroids);
7272

vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ impl ScalarFnVTable for SorfTransform {
164164
let f32_elements = elements_prim.into_buffer::<f32>();
165165

166166
// Reconstruct the orthogonal transform matrix from the seed.
167-
let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?;
167+
let rotation =
168+
SorfMatrix::try_new_padded(padded_dim, options.num_rounds as usize, options.seed)?;
168169

169170
// Inverse transform each row, truncate to original dimension, cast to target type.
170171
match_each_float_ptype!(options.element_ptype, |T| {

vortex-tensor/src/types/vector/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ impl Vector {
4949
/// # Errors
5050
///
5151
/// Returns an error if the [`Vector`] extension dtype rejects the storage array.
52-
pub(crate) fn try_new_vector_array(storage: ArrayRef) -> VortexResult<ArrayRef> {
52+
pub fn try_new_vector_array(storage: ArrayRef) -> VortexResult<ArrayRef> {
5353
ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage)
5454
.map(|ext| ext.into_array())
5555
}

0 commit comments

Comments (0)