Skip to content

Commit 0c289ab

Browse files
lwwmanningclaude
andcommitted
fix[turboquant]: second-round review fixes and merge conflict resolution
- Add TurboQuantMSE and TurboQuantQJL to ALLOWED_ENCODINGS in vortex-file so TurboQuant-encoded files can be deserialized - Fix as_ptype() panic: use primitive.ptype() after to_canonical() instead of calling the panicking as_ptype() on the raw dtype - Move rand_distr to dev-dependencies (only used in tests) - Remove unused vortex-mask dependency - Handle nullable storage in compress_turboquant: return None to fall through to default compression instead of failing - Remove apply_inverse_srht_from_bits (dead code, only used in its own test) and apply_signs_from_bits helper - Fix function-scoped import in gen_random_signs - Add TODO for double f32 extraction in QJL encode - Fix execute() signature after merge with develop (Arc<Array<Self>>) - Collapse nested if-let per clippy Signed-off-by: Will Manning <will@spiraldb.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Will Manning <will@willmanning.io>
1 parent b2ae417 commit 0c289ab

10 files changed

Lines changed: 41 additions & 94 deletions

File tree

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/turboquant/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,15 @@ workspace = true
1919
[dependencies]
2020
prost = { workspace = true }
2121
rand = { workspace = true }
22-
rand_distr = { workspace = true }
2322
vortex-array = { workspace = true }
2423
vortex-buffer = { workspace = true }
2524
vortex-error = { workspace = true }
2625
vortex-fastlanes = { workspace = true }
27-
vortex-mask = { workspace = true }
2826
vortex-session = { workspace = true }
2927
vortex-utils = { workspace = true }
3028
parking_lot = { workspace = true }
3129

3230
[dev-dependencies]
31+
rand_distr = { workspace = true }
3332
rstest = { workspace = true }
3433
vortex-array = { workspace = true, features = ["_test-harness"] }

encodings/turboquant/public-api.lock

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ pub fn vortex_turboquant::rotation::RotationMatrix::rotate(&self, input: &[f32],
2828

2929
pub fn vortex_turboquant::rotation::RotationMatrix::try_new(seed: u64, dimension: usize) -> vortex_error::VortexResult<Self>
3030

31-
pub fn vortex_turboquant::rotation::apply_inverse_srht_from_bits(buf: &mut [f32], signs_bytes: &[u8], padded_dim: usize, norm_factor: f32)
32-
3331
pub struct vortex_turboquant::TurboQuantConfig
3432

3533
pub vortex_turboquant::TurboQuantConfig::bit_width: u8
@@ -86,7 +84,7 @@ pub fn vortex_turboquant::TurboQuantMSE::deserialize(bytes: &[u8], _dtype: &vort
8684

8785
pub fn vortex_turboquant::TurboQuantMSE::dtype(array: &vortex_turboquant::TurboQuantMSEArray) -> &vortex_array::dtype::DType
8886

89-
pub fn vortex_turboquant::TurboQuantMSE::execute(array: alloc::sync::Arc<Self::Array>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::executor::ExecutionResult>
87+
pub fn vortex_turboquant::TurboQuantMSE::execute(array: alloc::sync::Arc<vortex_array::vtable::typed::Array<Self>>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::executor::ExecutionResult>
9088

9189
pub fn vortex_turboquant::TurboQuantMSE::id(&self) -> vortex_array::vtable::dyn_::ArrayId
9290

@@ -232,7 +230,7 @@ pub fn vortex_turboquant::TurboQuantQJL::deserialize(bytes: &[u8], _dtype: &vort
232230

233231
pub fn vortex_turboquant::TurboQuantQJL::dtype(array: &vortex_turboquant::TurboQuantQJLArray) -> &vortex_array::dtype::DType
234232

235-
pub fn vortex_turboquant::TurboQuantQJL::execute(array: alloc::sync::Arc<Self::Array>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::executor::ExecutionResult>
233+
pub fn vortex_turboquant::TurboQuantQJL::execute(array: alloc::sync::Arc<vortex_array::vtable::typed::Array<Self>>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::executor::ExecutionResult>
236234

237235
pub fn vortex_turboquant::TurboQuantQJL::id(&self) -> vortex_array::vtable::dyn_::ArrayId
238236

encodings/turboquant/src/compress.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ pub struct TurboQuantConfig {
3939
#[allow(clippy::cast_possible_truncation)]
4040
fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<Vec<f32>> {
4141
let elements = fsl.elements();
42-
let ptype = elements.dtype().as_ptype();
4342
let primitive = elements.to_canonical()?.into_primitive();
43+
let ptype = primitive.ptype();
4444

4545
match ptype {
4646
PType::F32 => Ok(primitive.as_slice::<f32>().to_vec()),
@@ -196,6 +196,8 @@ pub fn turboquant_encode_qjl(
196196
return build_empty_qjl_array(fsl, config.bit_width, padded_dim, seed);
197197
}
198198

199+
// TODO(perf): `turboquant_encode_mse` above already extracts f32 elements
200+
// internally. Refactor to share the buffer to avoid double materialization.
199201
let f32_elements = extract_f32_elements(fsl)?;
200202
#[allow(clippy::cast_possible_truncation)]
201203
let centroids = get_centroids(padded_dim as u32, mse_bit_width)?;

encodings/turboquant/src/mse/vtable/mod.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//! VTable implementation for TurboQuant MSE encoding.
55
66
use std::hash::Hash;
7+
use std::ops::Deref;
78
use std::sync::Arc;
89

910
use vortex_array::ArrayEq;
@@ -22,6 +23,7 @@ use vortex_array::dtype::Nullability;
2223
use vortex_array::dtype::PType;
2324
use vortex_array::serde::ArrayChildren;
2425
use vortex_array::stats::StatsSetRef;
26+
use vortex_array::vtable::Array;
2527
use vortex_array::vtable::ArrayId;
2628
use vortex_array::vtable::NotSupported;
2729
use vortex_array::vtable::VTable;
@@ -209,9 +211,11 @@ impl VTable for TurboQuantMSE {
209211
Ok(())
210212
}
211213

212-
fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
213-
let array = Arc::try_unwrap(array).unwrap_or_else(|arc| (*arc).clone());
214-
Ok(ExecutionResult::done(execute_decompress_mse(array, ctx)?))
214+
fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
215+
let inner = Arc::try_unwrap(array)
216+
.map(|a| a.into_inner())
217+
.unwrap_or_else(|arc| arc.as_ref().deref().clone());
218+
Ok(ExecutionResult::done(execute_decompress_mse(inner, ctx)?))
215219
}
216220
}
217221

encodings/turboquant/src/qjl/vtable/mod.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//! VTable implementation for TurboQuant QJL encoding.
55
66
use std::hash::Hash;
7+
use std::ops::Deref;
78
use std::sync::Arc;
89

910
use vortex_array::ArrayEq;
@@ -22,6 +23,7 @@ use vortex_array::dtype::Nullability;
2223
use vortex_array::dtype::PType;
2324
use vortex_array::serde::ArrayChildren;
2425
use vortex_array::stats::StatsSetRef;
26+
use vortex_array::vtable::Array;
2527
use vortex_array::vtable::ArrayId;
2628
use vortex_array::vtable::NotSupported;
2729
use vortex_array::vtable::VTable;
@@ -204,9 +206,11 @@ impl VTable for TurboQuantQJL {
204206
Ok(())
205207
}
206208

207-
fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
208-
let array = Arc::try_unwrap(array).unwrap_or_else(|arc| (*arc).clone());
209-
Ok(ExecutionResult::done(execute_decompress_qjl(array, ctx)?))
209+
fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
210+
let inner = Arc::try_unwrap(array)
211+
.map(|a| a.into_inner())
212+
.unwrap_or_else(|arc| arc.as_ref().deref().clone());
213+
Ok(ExecutionResult::done(execute_decompress_qjl(inner, ctx)?))
210214
}
211215
}
212216

encodings/turboquant/src/rotation.rs

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//! For dimensions that are not powers of 2, the input is zero-padded to the
1313
//! next power of 2 before the transform and truncated afterward.
1414
15+
use rand::RngExt;
1516
use rand::SeedableRng;
1617
use rand::rngs::StdRng;
1718
use vortex_array::arrays::BoolArray;
@@ -212,42 +213,8 @@ impl RotationMatrix {
212213
/// contains `3 * padded_dim` bits in inverse-application order `[D₃ | D₂ | D₁]`.
213214
/// Convention: bit set (1) = +1, bit unset (0) = -1 (negate).
214215
///
215-
/// Applies: H → D₃ → H → D₂ → H → D₁ → scale
216-
#[inline]
217-
pub fn apply_inverse_srht_from_bits(
218-
buf: &mut [f32],
219-
signs_bytes: &[u8],
220-
padded_dim: usize,
221-
norm_factor: f32,
222-
) {
223-
debug_assert!(padded_dim.is_power_of_two());
224-
debug_assert_eq!(buf.len(), padded_dim);
225-
226-
for round in 0..3 {
227-
walsh_hadamard_transform(buf);
228-
apply_signs_from_bits(buf, signs_bytes, round * padded_dim);
229-
}
230-
231-
for val in buf.iter_mut() {
232-
*val *= norm_factor;
233-
}
234-
}
235-
236-
/// Element-wise negate coordinates where the sign bit is unset (0 = -1).
237-
#[inline]
238-
fn apply_signs_from_bits(buf: &mut [f32], signs_bytes: &[u8], bit_offset: usize) {
239-
for (j, val) in buf.iter_mut().enumerate() {
240-
let idx = bit_offset + j;
241-
let is_positive = (signs_bytes[idx / 8] >> (idx % 8)) & 1 == 1;
242-
if !is_positive {
243-
*val = -*val;
244-
}
245-
}
246-
}
247-
248216
/// Generate a vector of random ±1 signs.
249217
fn gen_random_signs(rng: &mut StdRng, len: usize) -> Vec<f32> {
250-
use rand::RngExt;
251218
(0..len)
252219
.map(|_| {
253220
if rng.random_bool(0.5) {
@@ -416,48 +383,6 @@ mod tests {
416383
Ok(())
417384
}
418385

419-
/// Verify that the hot-path `apply_inverse_srht_from_bits` matches `inverse_rotate`.
420-
#[rstest]
421-
#[case(64)]
422-
#[case(128)]
423-
#[case(768)]
424-
fn hot_path_matches_inverse_rotate(#[case] dim: usize) -> VortexResult<()> {
425-
let rot = RotationMatrix::try_new(99, dim)?;
426-
let padded_dim = rot.padded_dim();
427-
let norm_factor = rot.norm_factor();
428-
429-
let signs_array = rot.export_inverse_signs_bool_array();
430-
let bit_buf = signs_array.to_bit_buffer();
431-
let (_, _, raw_buf) = bit_buf.into_inner();
432-
433-
// Create some rotated input.
434-
let mut input = vec![0.0f32; padded_dim];
435-
for i in 0..dim {
436-
input[i] = (i as f32 + 1.0) * 0.01;
437-
}
438-
let mut rotated = vec![0.0f32; padded_dim];
439-
rot.rotate(&input, &mut rotated);
440-
441-
// Inverse via the struct method.
442-
let mut recovered1 = vec![0.0f32; padded_dim];
443-
rot.inverse_rotate(&rotated, &mut recovered1);
444-
445-
// Inverse via the hot-path function.
446-
let mut recovered2 = rotated.clone();
447-
apply_inverse_srht_from_bits(&mut recovered2, raw_buf.as_ref(), padded_dim, norm_factor);
448-
449-
for i in 0..padded_dim {
450-
assert!(
451-
(recovered1[i] - recovered2[i]).abs() < 1e-10,
452-
"Hot-path mismatch at {i}: {} vs {}",
453-
recovered1[i],
454-
recovered2[i]
455-
);
456-
}
457-
458-
Ok(())
459-
}
460-
461386
#[test]
462387
fn wht_basic() {
463388
// WHT of [1, 0, 0, 0] should be [1, 1, 1, 1]

vortex-btrblocks/src/canonical_compressor.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,12 @@ impl CanonicalCompressor for BtrBlocksCompressor {
296296
}
297297

298298
// Compress tensor extension types with TurboQuant if configured.
299+
// Falls through to default compression for nullable storage.
299300
if let Some(tq_config) = &self.turboquant_config
300301
&& is_tensor_extension(&ext_array)
302+
&& let Some(compressed) = compress_turboquant(&ext_array, tq_config)?
301303
{
302-
return compress_turboquant(&ext_array, tq_config);
304+
return Ok(compressed);
303305
}
304306

305307
// Compress the underlying storage array.

vortex-btrblocks/src/compressor/turboquant.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ pub(crate) fn is_tensor_extension(ext_array: &ExtensionArray) -> bool {
1818
ext_id.as_ref() == VECTOR_EXT_ID || ext_id.as_ref() == FIXED_SHAPE_TENSOR_EXT_ID
1919
}
2020

21-
/// Compress a tensor extension array using TurboQuant.
21+
/// Try to compress a tensor extension array using TurboQuant.
22+
///
23+
/// Returns `Ok(Some(...))` on success, or `Ok(None)` if the storage is nullable
24+
/// (TurboQuant requires non-nullable input). The caller should fall through to
25+
/// default compression when `None` is returned.
2226
///
2327
/// Produces a `TurboQuantQJLArray` wrapping a `TurboQuantMSEArray`, stored inside
2428
/// the Extension wrapper. All children (codes, norms, centroids, rotation signs,
@@ -27,13 +31,19 @@ pub(crate) fn is_tensor_extension(ext_array: &ExtensionArray) -> bool {
2731
pub(crate) fn compress_turboquant(
2832
ext_array: &ExtensionArray,
2933
config: &TurboQuantConfig,
30-
) -> VortexResult<ArrayRef> {
34+
) -> VortexResult<Option<ArrayRef>> {
3135
let storage = ext_array.storage_array();
3236
let fsl = storage.to_canonical()?.into_fixed_size_list();
3337

38+
if fsl.dtype().is_nullable() {
39+
return Ok(None);
40+
}
41+
3442
// Produce the cascaded QJL(MSE) structure. The layout writer will
3543
// recursively descend into children and compress each one.
3644
let qjl_array = turboquant_encode_qjl(&fsl, config)?;
3745

38-
Ok(ExtensionArray::new(ext_array.ext_dtype().clone(), qjl_array.into_array()).into_array())
46+
Ok(Some(
47+
ExtensionArray::new(ext_array.ext_dtype().clone(), qjl_array.into_array()).into_array(),
48+
))
3949
}

vortex-file/src/strategy.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ use vortex_pco::Pco;
6060
use vortex_runend::RunEnd;
6161
use vortex_sequence::Sequence;
6262
use vortex_sparse::Sparse;
63+
use vortex_turboquant::TurboQuantMSE;
64+
use vortex_turboquant::TurboQuantQJL;
6365
use vortex_utils::aliases::hash_map::HashMap;
6466
use vortex_zigzag::ZigZag;
6567
#[cfg(feature = "zstd")]
@@ -109,6 +111,8 @@ pub static ALLOWED_ENCODINGS: LazyLock<ArrayRegistry> = LazyLock::new(|| {
109111
session.register(Sequence);
110112
session.register(Sparse);
111113
session.register(ZigZag);
114+
session.register(TurboQuantMSE);
115+
session.register(TurboQuantQJL);
112116

113117
#[cfg(feature = "zstd")]
114118
session.register(Zstd);

0 commit comments

Comments
 (0)