Skip to content

Commit 5590fee

Browse files
lwwmanning and claude committed
refactor[turboquant]: restructure into subdirectory modules, delete dead code
Restructure the turboquant crate to follow the fastlanes encoding pattern where each encoding type gets its own subdirectory with array/ and vtable/ subdirectories:

mse/
  mod.rs — marker struct + re-exports
  array/mod.rs — TurboQuantMSEArray struct + accessors
  vtable/mod.rs — VTable + ValidityChild impls
qjl/
  mod.rs — marker struct + re-exports
  array/mod.rs — TurboQuantQJLArray struct + accessors
  vtable/mod.rs — VTable + ValidityChild impls

Delete all dead code:
- Remove old monolithic array.rs (TurboQuantArray, TurboQuantVariant)
- Remove old mse_array.rs, qjl_array.rs flat files
- Remove old rules.rs
- Remove legacy decode functions from decompress.rs
- Remove TurboQuantVariant from TurboQuantConfig (now just bit_width + seed)

Update all consumers:
- BtrBlocks compressor (already using new API)
- Benchmarks: turboquant_encode → turboquant_encode_mse
- lib.rs: use glob re-exports (pub use mse::*, pub use qjl::*)
- Docstring example updated for new API

Signed-off-by: Will Manning <will@spiraldb.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Will Manning <will@willmanning.io>
1 parent 3901305 commit 5590fee

13 files changed

Lines changed: 574 additions & 1968 deletions

File tree

encodings/turboquant/public-api.lock

Lines changed: 153 additions & 543 deletions
Large diffs are not rendered by default.

encodings/turboquant/src/array.rs

Lines changed: 0 additions & 410 deletions
This file was deleted.

encodings/turboquant/src/compress.rs

Lines changed: 7 additions & 286 deletions
Original file line numberDiff line numberDiff line change
@@ -16,82 +16,24 @@ use vortex_error::vortex_bail;
1616
use vortex_error::vortex_ensure;
1717
use vortex_fastlanes::bitpack_compress::bitpack_encode;
1818

19-
use crate::array::TurboQuantArray;
20-
use crate::array::TurboQuantVariant;
2119
use crate::centroids::find_nearest_centroid;
2220
use crate::centroids::get_centroids;
23-
use crate::mse_array::TurboQuantMSEArray;
24-
use crate::qjl_array::TurboQuantQJLArray;
21+
use crate::mse::array::TurboQuantMSEArray;
22+
use crate::qjl::array::TurboQuantQJLArray;
2523
use crate::rotation::RotationMatrix;
2624

2725
/// Configuration for TurboQuant encoding.
2826
#[derive(Clone, Debug)]
2927
pub struct TurboQuantConfig {
30-
/// Bits per coordinate (1-4).
28+
/// Bits per coordinate.
29+
///
30+
/// For MSE encoding: 1-8.
31+
/// For QJL encoding: 2-9 (the MSE inner uses `bit_width - 1`).
3132
pub bit_width: u8,
32-
/// Which variant to use.
33-
pub variant: TurboQuantVariant,
3433
/// Optional seed for the rotation matrix. If None, a random seed is generated.
3534
pub seed: Option<u64>,
3635
}
3736

38-
/// Encode a FixedSizeListArray of floats into a TurboQuantArray.
39-
///
40-
/// The input should be the storage array of a Vector or FixedShapeTensor extension type.
41-
/// Each row (fixed-size-list element) is treated as a d-dimensional vector to quantize.
42-
pub fn turboquant_encode(
43-
fsl: &FixedSizeListArray,
44-
config: &TurboQuantConfig,
45-
) -> VortexResult<TurboQuantArray> {
46-
match config.variant {
47-
TurboQuantVariant::Mse => vortex_ensure!(
48-
config.bit_width >= 1 && config.bit_width <= 8,
49-
"MSE variant bit_width must be 1-8, got {}",
50-
config.bit_width
51-
),
52-
TurboQuantVariant::Prod => vortex_ensure!(
53-
config.bit_width >= 2 && config.bit_width <= 9,
54-
"Prod variant bit_width must be 2-9, got {}",
55-
config.bit_width
56-
),
57-
}
58-
59-
let dimension = fsl.list_size();
60-
vortex_ensure!(
61-
dimension >= 2,
62-
"TurboQuant requires dimension >= 2, got {dimension}"
63-
);
64-
let num_rows = fsl.len();
65-
66-
if num_rows == 0 {
67-
return encode_empty(fsl, config, dimension);
68-
}
69-
70-
let seed = config.seed.unwrap_or_else(rand::random);
71-
72-
// Extract flat f32 elements from the FixedSizeListArray.
73-
let f32_elements = extract_f32_elements(fsl)?;
74-
75-
match config.variant {
76-
TurboQuantVariant::Mse => encode_mse(
77-
&f32_elements,
78-
num_rows,
79-
dimension,
80-
config.bit_width,
81-
seed,
82-
fsl,
83-
),
84-
TurboQuantVariant::Prod => encode_prod(
85-
&f32_elements,
86-
num_rows,
87-
dimension,
88-
config.bit_width,
89-
seed,
90-
fsl,
91-
),
92-
}
93-
}
94-
9537
/// Extract elements from a FixedSizeListArray as a flat f32 vec.
9638
#[allow(clippy::cast_possible_truncation)]
9739
fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<Vec<f32>> {
@@ -110,231 +52,12 @@ fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult<Vec<f32>> {
11052
}
11153
}
11254

113-
fn encode_empty(
114-
fsl: &FixedSizeListArray,
115-
config: &TurboQuantConfig,
116-
dimension: u32,
117-
) -> VortexResult<TurboQuantArray> {
118-
let seed = config.seed.unwrap_or(0);
119-
let codes = PrimitiveArray::empty::<u8>(fsl.dtype().nullability());
120-
let norms = PrimitiveArray::empty::<f32>(fsl.dtype().nullability());
121-
122-
match config.variant {
123-
TurboQuantVariant::Mse => TurboQuantArray::try_new_mse(
124-
fsl.dtype().clone(),
125-
codes.into_array(),
126-
norms.into_array(),
127-
dimension,
128-
config.bit_width,
129-
seed,
130-
),
131-
TurboQuantVariant::Prod => {
132-
let qjl_signs = PrimitiveArray::empty::<u8>(fsl.dtype().nullability());
133-
let residual_norms = PrimitiveArray::empty::<f32>(fsl.dtype().nullability());
134-
TurboQuantArray::try_new_prod(
135-
fsl.dtype().clone(),
136-
codes.into_array(),
137-
norms.into_array(),
138-
qjl_signs.into_array(),
139-
residual_norms.into_array(),
140-
dimension,
141-
config.bit_width,
142-
seed,
143-
)
144-
}
145-
}
146-
}
147-
148-
fn encode_mse(
149-
elements: &[f32],
150-
num_rows: usize,
151-
dimension: u32,
152-
bit_width: u8,
153-
seed: u64,
154-
fsl: &FixedSizeListArray,
155-
) -> VortexResult<TurboQuantArray> {
156-
let dim = dimension as usize;
157-
let rotation = RotationMatrix::try_new(seed, dim)?;
158-
let padded_dim = rotation.padded_dim();
159-
#[allow(clippy::cast_possible_truncation)]
160-
let centroids = get_centroids(padded_dim as u32, bit_width)?;
161-
162-
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
163-
let mut norms_buf = BufferMut::<f32>::with_capacity(num_rows);
164-
165-
let mut padded = vec![0.0f32; padded_dim];
166-
let mut rotated = vec![0.0f32; padded_dim];
167-
168-
for row in 0..num_rows {
169-
let x = &elements[row * dim..(row + 1) * dim];
170-
171-
let norm = l2_norm(x);
172-
norms_buf.push(norm);
173-
174-
// Normalize, zero-pad to padded_dim, and rotate.
175-
padded.fill(0.0);
176-
if norm > 0.0 {
177-
let inv_norm = 1.0 / norm;
178-
for (dst, &src) in padded[..dim].iter_mut().zip(x.iter()) {
179-
*dst = src * inv_norm;
180-
}
181-
}
182-
rotation.rotate(&padded, &mut rotated);
183-
184-
// Quantize all padded_dim coordinates.
185-
for j in 0..padded_dim {
186-
all_indices.push(find_nearest_centroid(rotated[j], &centroids));
187-
}
188-
}
189-
190-
// Pack indices: bitpack for 1-7 bits, store raw u8 for 8 bits.
191-
let indices_array = PrimitiveArray::new::<u8>(all_indices.freeze(), Validity::NonNullable);
192-
let codes = if bit_width < 8 {
193-
bitpack_encode(&indices_array, bit_width, None)?.into_array()
194-
} else {
195-
indices_array.into_array()
196-
};
197-
198-
let norms_array = PrimitiveArray::new::<f32>(norms_buf.freeze(), Validity::NonNullable);
199-
200-
TurboQuantArray::try_new_mse(
201-
fsl.dtype().clone(),
202-
codes,
203-
norms_array.into_array(),
204-
dimension,
205-
bit_width,
206-
seed,
207-
)
208-
}
209-
210-
fn encode_prod(
211-
elements: &[f32],
212-
num_rows: usize,
213-
dimension: u32,
214-
bit_width: u8,
215-
seed: u64,
216-
fsl: &FixedSizeListArray,
217-
) -> VortexResult<TurboQuantArray> {
218-
let dim = dimension as usize;
219-
let mse_bit_width = bit_width - 1;
220-
221-
let rotation = RotationMatrix::try_new(seed, dim)?;
222-
let padded_dim = rotation.padded_dim();
223-
#[allow(clippy::cast_possible_truncation)]
224-
let centroids = get_centroids(padded_dim as u32, mse_bit_width)?;
225-
226-
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
227-
let mut norms_buf = BufferMut::<f32>::with_capacity(num_rows);
228-
let mut residual_norms_buf = BufferMut::<f32>::with_capacity(num_rows);
229-
230-
// QJL sign bits: num_rows * padded_dim bits, packed into bytes.
231-
let total_sign_bits = num_rows * padded_dim;
232-
let sign_byte_count = total_sign_bits.div_ceil(8);
233-
let mut sign_buf = BufferMut::<u8>::with_capacity(sign_byte_count);
234-
sign_buf.extend(std::iter::repeat_n(0u8, sign_byte_count));
235-
let sign_slice = sign_buf.as_mut_slice();
236-
237-
let mut padded = vec![0.0f32; padded_dim];
238-
let mut rotated = vec![0.0f32; padded_dim];
239-
let mut dequantized_rotated = vec![0.0f32; padded_dim];
240-
let mut dequantized = vec![0.0f32; padded_dim];
241-
let mut residual = vec![0.0f32; padded_dim];
242-
let mut projected = vec![0.0f32; padded_dim];
243-
244-
// QJL random sign matrix generator (using seed + 1).
245-
let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(1), dim)?;
246-
247-
for row in 0..num_rows {
248-
let x = &elements[row * dim..(row + 1) * dim];
249-
250-
let norm = l2_norm(x);
251-
norms_buf.push(norm);
252-
253-
// Normalize, zero-pad, and rotate.
254-
padded.fill(0.0);
255-
if norm > 0.0 {
256-
let inv_norm = 1.0 / norm;
257-
for (dst, &src) in padded[..dim].iter_mut().zip(x.iter()) {
258-
*dst = src * inv_norm;
259-
}
260-
}
261-
rotation.rotate(&padded, &mut rotated);
262-
263-
// MSE quantize at (bit_width - 1) bits over padded_dim coordinates.
264-
for j in 0..padded_dim {
265-
let idx = find_nearest_centroid(rotated[j], &centroids);
266-
all_indices.push(idx);
267-
dequantized_rotated[j] = centroids[idx as usize];
268-
}
269-
270-
// Dequantize MSE result (inverse rotate to full padded space, take first dim).
271-
rotation.inverse_rotate(&dequantized_rotated, &mut dequantized);
272-
if norm > 0.0 {
273-
for val in &mut dequantized {
274-
*val *= norm;
275-
}
276-
}
277-
278-
// Compute residual r = x - x_hat_mse (only first dim elements matter).
279-
residual.fill(0.0);
280-
for j in 0..dim {
281-
residual[j] = x[j] - dequantized[j];
282-
}
283-
let residual_norm = l2_norm(&residual[..dim]);
284-
residual_norms_buf.push(residual_norm);
285-
286-
// QJL: sign(S * r).
287-
projected.fill(0.0);
288-
if residual_norm > 0.0 {
289-
qjl_rotation.rotate(&residual, &mut projected);
290-
}
291-
292-
// Store sign bits for padded_dim positions.
293-
let bit_offset = row * padded_dim;
294-
for j in 0..padded_dim {
295-
if projected[j] >= 0.0 {
296-
let bit_idx = bit_offset + j;
297-
sign_slice[bit_idx / 8] |= 1 << (bit_idx % 8);
298-
}
299-
}
300-
}
301-
302-
// Pack MSE indices: bitpack for 1-7 bits, store raw u8 for 8 bits.
303-
let indices_array = PrimitiveArray::new::<u8>(all_indices.freeze(), Validity::NonNullable);
304-
let codes = if mse_bit_width < 8 {
305-
bitpack_encode(&indices_array, mse_bit_width, None)?.into_array()
306-
} else {
307-
indices_array.into_array()
308-
};
309-
310-
let norms_array = PrimitiveArray::new::<f32>(norms_buf.freeze(), Validity::NonNullable);
311-
let residual_norms_array =
312-
PrimitiveArray::new::<f32>(residual_norms_buf.freeze(), Validity::NonNullable);
313-
314-
let qjl_signs = PrimitiveArray::new::<u8>(sign_buf.freeze(), Validity::NonNullable);
315-
316-
TurboQuantArray::try_new_prod(
317-
fsl.dtype().clone(),
318-
codes,
319-
norms_array.into_array(),
320-
qjl_signs.into_array(),
321-
residual_norms_array.into_array(),
322-
dimension,
323-
bit_width,
324-
seed,
325-
)
326-
}
327-
32855
/// Compute the L2 (Euclidean) norm of a vector.
///
/// Squares every component of `x`, sums the squares as `f32`, and returns the
/// square root of that sum.
32956
#[inline]
33057
fn l2_norm(x: &[f32]) -> f32 {
33158
x.iter().map(|&v| v * v).sum::<f32>().sqrt()
33259
}
33360

334-
// ---------------------------------------------------------------------------
335-
// New encoding producing cascaded MSE/QJL arrays
336-
// ---------------------------------------------------------------------------
337-
33861
/// Encode a FixedSizeListArray into a `TurboQuantMSEArray`.
33962
pub fn turboquant_encode_mse(
34063
fsl: &FixedSizeListArray,
@@ -390,7 +113,7 @@ pub fn turboquant_encode_mse(
390113
}
391114
}
392115

393-
// Pack indices.
116+
// Pack indices: bitpack for 1-7 bits, store raw u8 for 8 bits.
394117
let indices_array = PrimitiveArray::new::<u8>(all_indices.freeze(), Validity::NonNullable);
395118
let codes = if config.bit_width < 8 {
396119
bitpack_encode(&indices_array, config.bit_width, None)?.into_array()
@@ -448,7 +171,6 @@ pub fn turboquant_encode_qjl(
448171
// First, encode the MSE inner at (bit_width - 1).
449172
let mse_config = TurboQuantConfig {
450173
bit_width: mse_bit_width,
451-
variant: TurboQuantVariant::Mse, // legacy field, not used in new path
452174
seed: Some(seed),
453175
};
454176
let mse_inner = turboquant_encode_mse(fsl, &mse_config)?;
@@ -581,7 +303,6 @@ fn build_empty_qjl_array(
581303
) -> VortexResult<TurboQuantQJLArray> {
582304
let mse_config = TurboQuantConfig {
583305
bit_width: bit_width - 1,
584-
variant: TurboQuantVariant::Mse,
585306
seed: Some(seed),
586307
};
587308
let mse_inner = turboquant_encode_mse(fsl, &mse_config)?;

0 commit comments

Comments
 (0)