Skip to content

Commit b2ae417

Browse files
lwwmanning and claude committed
fix[turboquant]: address PR review findings
- Reject nullable FixedSizeListArray input in both turboquant_encode_mse and turboquant_encode_qjl with a clear error message. TurboQuant is lossy and cannot preserve null positions.
- Fix with_vector_quantization composability: store TurboQuantConfig in the builder and apply at build() time, so it doesn't discard a previously-configured compressor. Document precedence rules.
- Export VECTOR_EXT_ID and FIXED_SHAPE_TENSOR_EXT_ID as public constants from vortex-turboquant; import in vortex-btrblocks instead of hardcoding duplicate string literals.
- Add QJL roundtrip and inner product bias tests for dim=768 (non-power-of-2 requiring padding to 1024).
- Move function-scoped imports to top of test module and benchmark file per CLAUDE.md conventions.
- Regenerate public-api.lock.

Total: 88 unit tests + 1 doctest.

Signed-off-by: Will Manning <will@spiraldb.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Will Manning <will@willmanning.io>
1 parent bfe80e7 commit b2ae417

6 files changed

Lines changed: 54 additions & 23 deletions

File tree

encodings/turboquant/public-api.lock

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ pub fn vortex_turboquant::TurboQuantQJLMetadata::clear(&mut self)
332332

333333
pub fn vortex_turboquant::TurboQuantQJLMetadata::encoded_len(&self) -> usize
334334

335+
pub const vortex_turboquant::FIXED_SHAPE_TENSOR_EXT_ID: &str
336+
337+
pub const vortex_turboquant::VECTOR_EXT_ID: &str
338+
335339
pub fn vortex_turboquant::initialize(session: &mut vortex_session::VortexSession)
336340

337341
pub fn vortex_turboquant::turboquant_encode_mse(fsl: &vortex_array::arrays::fixed_size_list::array::FixedSizeListArray, config: &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult<vortex_turboquant::TurboQuantMSEArray>

encodings/turboquant/src/compress.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use vortex_array::IntoArray;
77
use vortex_array::arrays::BoolArray;
88
use vortex_array::arrays::FixedSizeListArray;
99
use vortex_array::arrays::PrimitiveArray;
10+
use vortex_array::dtype::Nullability;
1011
use vortex_array::dtype::PType;
1112
use vortex_array::validity::Validity;
1213
use vortex_buffer::BitBufferMut;
@@ -59,10 +60,17 @@ fn l2_norm(x: &[f32]) -> f32 {
5960
}
6061

6162
/// Encode a FixedSizeListArray into a `TurboQuantMSEArray`.
63+
///
64+
/// The input must be non-nullable. TurboQuant is a lossy encoding that does not
65+
/// preserve null positions; callers must handle validity externally.
6266
pub fn turboquant_encode_mse(
6367
fsl: &FixedSizeListArray,
6468
config: &TurboQuantConfig,
6569
) -> VortexResult<TurboQuantMSEArray> {
70+
vortex_ensure!(
71+
fsl.dtype().nullability() == Nullability::NonNullable,
72+
"TurboQuant requires non-nullable input, got nullable FixedSizeListArray"
73+
);
6674
vortex_ensure!(
6775
config.bit_width >= 1 && config.bit_width <= 8,
6876
"MSE bit_width must be 1-8, got {}",
@@ -148,10 +156,16 @@ pub fn turboquant_encode_mse(
148156
/// Encode a FixedSizeListArray into a `TurboQuantQJLArray`.
149157
///
150158
/// Produces a cascaded structure: QJLArray wrapping an MSEArray at `bit_width - 1`.
159+
/// The input must be non-nullable. TurboQuant is a lossy encoding that does not
160+
/// preserve null positions; callers must handle validity externally.
151161
pub fn turboquant_encode_qjl(
152162
fsl: &FixedSizeListArray,
153163
config: &TurboQuantConfig,
154164
) -> VortexResult<TurboQuantQJLArray> {
165+
vortex_ensure!(
166+
fsl.dtype().nullability() == Nullability::NonNullable,
167+
"TurboQuant requires non-nullable input, got nullable FixedSizeListArray"
168+
);
155169
vortex_ensure!(
156170
config.bit_width >= 2 && config.bit_width <= 9,
157171
"QJL bit_width must be 2-9, got {}",

encodings/turboquant/src/lib.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ mod mse;
9494
mod qjl;
9595
pub mod rotation;
9696

97+
/// Extension ID for the `Vector` type from `vortex-tensor`.
98+
pub const VECTOR_EXT_ID: &str = "vortex.tensor.vector";
99+
100+
/// Extension ID for the `FixedShapeTensor` type from `vortex-tensor`.
101+
pub const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor";
102+
97103
use vortex_array::session::ArraySessionExt;
98104
use vortex_session::VortexSession;
99105

@@ -108,6 +114,11 @@ pub fn initialize(session: &mut VortexSession) {
108114
mod tests {
109115
use std::sync::LazyLock;
110116

117+
use rand::RngExt;
118+
use rand::SeedableRng;
119+
use rand::rngs::StdRng;
120+
use rand_distr::Distribution;
121+
use rand_distr::Normal;
111122
use rstest::rstest;
112123
use vortex_array::IntoArray;
113124
use vortex_array::VortexSessionExecute;
@@ -128,11 +139,6 @@ mod tests {
128139

129140
/// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal).
130141
fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray {
131-
use rand::SeedableRng;
132-
use rand::rngs::StdRng;
133-
use rand_distr::Distribution;
134-
use rand_distr::Normal;
135-
136142
let mut rng = StdRng::seed_from_u64(seed);
137143
let normal = Normal::new(0.0f32, 1.0).unwrap();
138144

@@ -339,6 +345,7 @@ mod tests {
339345
#[case(128, 6)]
340346
#[case(128, 8)]
341347
#[case(128, 9)]
348+
#[case(768, 3)]
342349
fn roundtrip_qjl(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> {
343350
let fsl = make_fsl(10, dim, 42);
344351
let config = TurboQuantConfig {
@@ -357,6 +364,8 @@ mod tests {
357364
#[case(128, 6)]
358365
#[case(128, 8)]
359366
#[case(128, 9)]
367+
#[case(768, 3)]
368+
#[case(768, 4)]
360369
fn qjl_inner_product_bias(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> {
361370
let num_rows = 100;
362371
let fsl = make_fsl(num_rows, dim, 42);
@@ -367,14 +376,10 @@ mod tests {
367376
let (original, decoded) = encode_decode_qjl(&fsl, &config)?;
368377

369378
let num_pairs = 500;
370-
let mut rng = {
371-
use rand::SeedableRng;
372-
rand::rngs::StdRng::seed_from_u64(0)
373-
};
379+
let mut rng = StdRng::seed_from_u64(0);
374380
let mut signed_errors = Vec::with_capacity(num_pairs);
375381

376382
for _ in 0..num_pairs {
377-
use rand::RngExt;
378383
let qi = rng.random_range(0..num_rows);
379384
let xi = rng.random_range(0..num_rows);
380385
if qi == xi {

vortex-btrblocks/src/compressor/turboquant.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@ use vortex_array::ArrayRef;
77
use vortex_array::IntoArray;
88
use vortex_array::arrays::ExtensionArray;
99
use vortex_error::VortexResult;
10+
use vortex_turboquant::FIXED_SHAPE_TENSOR_EXT_ID;
1011
use vortex_turboquant::TurboQuantConfig;
12+
use vortex_turboquant::VECTOR_EXT_ID;
1113
use vortex_turboquant::turboquant_encode_qjl;
1214

13-
/// Extension IDs for tensor types (from vortex-tensor).
14-
const VECTOR_EXT_ID: &str = "vortex.tensor.vector";
15-
const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor";
16-
1715
/// Check if an extension array has a tensor extension type.
1816
pub(crate) fn is_tensor_extension(ext_array: &ExtensionArray) -> bool {
1917
let ext_id = ext_array.ext_dtype().id();

vortex-file/src/strategy.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ pub static ALLOWED_ENCODINGS: LazyLock<ArrayRegistry> = LazyLock::new(|| {
125125
/// bulk decoding performance, and IOPS required to perform an indexed read.
126126
pub struct WriteStrategyBuilder {
127127
compressor: Option<Arc<dyn CompressorPlugin>>,
128+
turboquant_config: Option<vortex_turboquant::TurboQuantConfig>,
128129
row_block_size: usize,
129130
field_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
130131
allow_encodings: Option<ArrayRegistry>,
@@ -137,6 +138,7 @@ impl Default for WriteStrategyBuilder {
137138
fn default() -> Self {
138139
Self {
139140
compressor: None,
141+
turboquant_config: None,
140142
row_block_size: 8192,
141143
field_writers: HashMap::new(),
142144
allow_encodings: Some(ALLOWED_ENCODINGS.clone()),
@@ -237,18 +239,19 @@ impl WriteStrategyBuilder {
237239
/// The TurboQuant array's children (norms, codes) are recursively compressed by the
238240
/// BtrBlocks compressor.
239241
///
242+
/// This can be combined with other builder methods. If a custom compressor is also set
243+
/// via [`with_compressor`](Self::with_compressor), the custom compressor takes precedence
244+
/// and the TurboQuant config is ignored.
245+
///
240246
/// # Examples
241247
///
242248
/// ```ignore
243249
/// WriteStrategyBuilder::default()
244-
/// .with_vector_quantization(TurboQuantConfig { bit_width: 3, .. })
250+
/// .with_vector_quantization(TurboQuantConfig { bit_width: 3, seed: None })
245251
/// .build()
246252
/// ```
247253
pub fn with_vector_quantization(mut self, config: vortex_turboquant::TurboQuantConfig) -> Self {
248-
let btrblocks = BtrBlocksCompressorBuilder::default()
249-
.with_turboquant(config)
250-
.build();
251-
self.compressor = Some(Arc::new(btrblocks));
254+
self.turboquant_config = Some(config);
252255
self
253256
}
254257

@@ -270,6 +273,14 @@ impl WriteStrategyBuilder {
270273
// 5. compress each chunk
271274
let compressing = if let Some(ref compressor) = self.compressor {
272275
CompressingStrategy::new_opaque(buffered, compressor.clone())
276+
} else if let Some(tq_config) = self.turboquant_config {
277+
let btrblocks = BtrBlocksCompressorBuilder::default()
278+
.with_turboquant(tq_config)
279+
.build();
280+
CompressingStrategy::new_opaque(
281+
buffered,
282+
Arc::new(btrblocks) as Arc<dyn CompressorPlugin>,
283+
)
273284
} else {
274285
CompressingStrategy::new_btrblocks(buffered, true)
275286
};

vortex/benches/single_encoding_throughput.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ use rand::prelude::IndexedRandom;
1717
use rand::rngs::StdRng;
1818
use vortex::array::IntoArray;
1919
use vortex::array::ToCanonical;
20+
use vortex::array::arrays::FixedSizeListArray;
2021
use vortex::array::arrays::PrimitiveArray;
2122
use vortex::array::arrays::VarBinViewArray;
2223
use vortex::array::builders::dict::dict_encode;
2324
use vortex::array::builtins::ArrayBuiltins;
25+
use vortex::array::validity::Validity;
2426
use vortex::dtype::PType;
2527
use vortex::encodings::alp::RDEncoder;
2628
use vortex::encodings::alp::alp_encode;
@@ -39,6 +41,7 @@ use vortex::encodings::zstd::ZstdArray;
3941
use vortex_array::VortexSessionExecute;
4042
use vortex_array::dtype::Nullability;
4143
use vortex_array::session::ArraySession;
44+
use vortex_buffer::BufferMut;
4245
use vortex_sequence::SequenceArray;
4346
use vortex_session::VortexSession;
4447

@@ -410,10 +413,6 @@ fn bench_zstd_decompress_string(bencher: Bencher) {
410413

411414
// TurboQuant vector quantization benchmarks
412415

413-
use vortex::array::arrays::FixedSizeListArray;
414-
use vortex::array::validity::Validity;
415-
use vortex_buffer::BufferMut;
416-
417416
const NUM_VECTORS: usize = 1_000;
418417

419418
/// Generate `num_vectors` random f32 vectors of the given dimension using i.i.d.

0 commit comments

Comments
 (0)