Skip to content

Commit b2bcd38

Browse files
lwwmanningclaude
andcommitted
refactor[turboquant]: simplify code from review findings
- Consolidate encode_decode_mse and encode_decode_qjl test helpers into a single closure-parameterized encode_decode function - Replace 14 copy-pasted benchmark functions (~200 lines) with a turboquant_bench! macro (~40 lines) - Extract QJL correction scale factor to a named function with doc comment explaining the derivation - Precompute centroid decision boundaries (midpoints) once before the row loop, replacing per-coordinate distance comparisons with a single partition_point lookup. This removes two abs() calls and a branch from the innermost quantization loop. Net: -150 lines. Signed-off-by: Will Manning <will@spiraldb.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Will Manning <will@willmanning.io>
1 parent 0c289ab commit b2bcd38

6 files changed

Lines changed: 103 additions & 253 deletions

File tree

encodings/turboquant/public-api.lock

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ pub mod vortex_turboquant
22

33
pub mod vortex_turboquant::centroids
44

5-
pub fn vortex_turboquant::centroids::find_nearest_centroid(value: f32, centroids: &[f32]) -> u8
5+
pub fn vortex_turboquant::centroids::compute_boundaries(centroids: &[f32]) -> alloc::vec::Vec<f32>
6+
7+
pub fn vortex_turboquant::centroids::find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8
68

79
pub fn vortex_turboquant::centroids::get_centroids(dimension: u32, bit_width: u8) -> vortex_error::VortexResult<alloc::vec::Vec<f32>>
810

encodings/turboquant/src/centroids.rs

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -147,32 +147,24 @@ fn pdf_unnormalized(x_val: f64, exponent: f64) -> f64 {
147147
base.powf(exponent)
148148
}
149149

150-
/// Find the index of the nearest centroid to the given value.
150+
/// Precompute decision boundaries (midpoints between adjacent centroids).
151151
///
152-
/// Centroids must be sorted in ascending order. Uses binary search for efficiency.
153-
#[inline]
154-
pub fn find_nearest_centroid(value: f32, centroids: &[f32]) -> u8 {
155-
debug_assert!(!centroids.is_empty());
156-
157-
let idx = centroids.partition_point(|&c_val| c_val < value);
158-
159-
if idx == 0 {
160-
return 0;
161-
}
162-
if idx >= centroids.len() {
163-
#[allow(clippy::cast_possible_truncation)]
164-
return (centroids.len() - 1) as u8;
165-
}
166-
167-
let dist_left = (value - centroids[idx - 1]).abs();
168-
let dist_right = (value - centroids[idx]).abs();
152+
/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps
153+
/// to centroid 0, a value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`,
154+
/// and a value >= `boundaries[k-2]` maps to centroid `k-1`.
155+
pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> {
156+
centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect()
157+
}
169158

170-
#[allow(clippy::cast_possible_truncation)]
171-
if dist_left <= dist_right {
172-
(idx - 1) as u8
173-
} else {
174-
idx as u8
175-
}
159+
/// Find the index of the nearest centroid using precomputed decision boundaries.
160+
///
161+
/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding
162+
/// centroids. Uses binary search on the midpoints, avoiding distance comparisons
163+
/// in the inner loop.
164+
#[inline]
165+
#[allow(clippy::cast_possible_truncation)]
166+
pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
167+
boundaries.partition_point(|&b| b < value) as u8
176168
}
177169

178170
#[cfg(test)]
@@ -263,14 +255,15 @@ mod tests {
263255
#[test]
264256
fn find_nearest_basic() -> VortexResult<()> {
265257
let centroids = get_centroids(128, 2)?;
266-
assert_eq!(find_nearest_centroid(-1.0, &centroids), 0);
258+
let boundaries = compute_boundaries(&centroids);
259+
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
267260
#[allow(clippy::cast_possible_truncation)]
268261
let last_idx = (centroids.len() - 1) as u8;
269-
assert_eq!(find_nearest_centroid(1.0, &centroids), last_idx);
262+
assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx);
270263
for (idx, &cv) in centroids.iter().enumerate() {
271264
#[allow(clippy::cast_possible_truncation)]
272265
let expected = idx as u8;
273-
assert_eq!(find_nearest_centroid(cv, &centroids), expected);
266+
assert_eq!(find_nearest_centroid(cv, &boundaries), expected);
274267
}
275268
Ok(())
276269
}

encodings/turboquant/src/compress.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use vortex_error::vortex_bail;
1717
use vortex_error::vortex_ensure;
1818
use vortex_fastlanes::bitpack_compress::bitpack_encode;
1919

20+
use crate::centroids::compute_boundaries;
2021
use crate::centroids::find_nearest_centroid;
2122
use crate::centroids::get_centroids;
2223
use crate::mse::array::TurboQuantMSEArray;
@@ -96,6 +97,7 @@ pub fn turboquant_encode_mse(
9697
let f32_elements = extract_f32_elements(fsl)?;
9798
#[allow(clippy::cast_possible_truncation)]
9899
let centroids = get_centroids(padded_dim as u32, config.bit_width)?;
100+
let boundaries = compute_boundaries(&centroids);
99101

100102
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
101103
let mut norms_buf = BufferMut::<f32>::with_capacity(num_rows);
@@ -117,7 +119,7 @@ pub fn turboquant_encode_mse(
117119
rotation.rotate(&padded, &mut rotated);
118120

119121
for j in 0..padded_dim {
120-
all_indices.push(find_nearest_centroid(rotated[j], &centroids));
122+
all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
121123
}
122124
}
123125

@@ -201,6 +203,7 @@ pub fn turboquant_encode_qjl(
201203
let f32_elements = extract_f32_elements(fsl)?;
202204
#[allow(clippy::cast_possible_truncation)]
203205
let centroids = get_centroids(padded_dim as u32, mse_bit_width)?;
206+
let boundaries = compute_boundaries(&centroids);
204207

205208
// QJL uses a different rotation than the MSE stage to ensure statistical
206209
// independence between the quantization noise and the sign projection.
@@ -232,7 +235,7 @@ pub fn turboquant_encode_qjl(
232235
rotation.rotate(&padded, &mut rotated);
233236

234237
for j in 0..padded_dim {
235-
let idx = find_nearest_centroid(rotated[j], &centroids);
238+
let idx = find_nearest_centroid(rotated[j], &boundaries);
236239
dequantized_rotated[j] = centroids[idx as usize];
237240
}
238241

encodings/turboquant/src/decompress.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ use crate::mse::array::TurboQuantMSEArray;
1717
use crate::qjl::array::TurboQuantQJLArray;
1818
use crate::rotation::RotationMatrix;
1919

20+
/// QJL correction scale factor: `sqrt(π/2) / padded_dim`.
21+
///
22+
/// Accounts for the SRHT normalization (`1/padded_dim^{3/2}` per transform)
23+
/// combined with `E[|z|] = sqrt(2/π)` for half-normal sign expectations.
24+
/// Verified empirically via the `qjl_inner_product_bias` test suite.
25+
#[inline]
26+
fn qjl_correction_scale(padded_dim: usize) -> f32 {
27+
(std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32)
28+
}
29+
2030
/// Decompress a `TurboQuantMSEArray` into a `FixedSizeListArray` of floats.
2131
///
2232
/// Reads stored centroids and rotation signs from the array's children,
@@ -126,11 +136,7 @@ pub fn execute_decompress_qjl(
126136
let qjl_rot_signs_bool = array.rotation_signs.clone().execute::<BoolArray>(ctx)?;
127137
let qjl_rot = RotationMatrix::from_bool_array(&qjl_rot_signs_bool, dim)?;
128138

129-
// QJL correction scale: sqrt(π/2) / padded_dim.
130-
// This accounts for the SRHT normalization (1/padded_dim^{3/2} per transform)
131-
// combined with the E[|z|] = sqrt(2/π) expectation of half-normal signs.
132-
// Verified empirically via the `qjl_inner_product_bias` test suite.
133-
let qjl_scale = (std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32);
139+
let qjl_scale = qjl_correction_scale(padded_dim);
134140

135141
let mut output = BufferMut::<f32>::with_capacity(num_rows * dim);
136142
let mut qjl_signs_vec = vec![0.0f32; padded_dim];

encodings/turboquant/src/lib.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ mod tests {
120120
use rand_distr::Distribution;
121121
use rand_distr::Normal;
122122
use rstest::rstest;
123+
use vortex_array::ArrayRef;
123124
use vortex_array::IntoArray;
124125
use vortex_array::VortexSessionExecute;
125126
use vortex_array::arrays::FixedSizeListArray;
@@ -186,46 +187,43 @@ mod tests {
186187
total / num_rows as f32
187188
}
188189

189-
/// Encode via MSE and decode, returning (original, decoded) flat f32 slices.
190-
fn encode_decode_mse(
190+
/// Encode and decode, returning (original, decoded) flat f32 slices.
191+
fn encode_decode(
191192
fsl: &FixedSizeListArray,
192-
config: &TurboQuantConfig,
193+
encode_fn: impl FnOnce(&FixedSizeListArray) -> VortexResult<ArrayRef>,
193194
) -> VortexResult<(Vec<f32>, Vec<f32>)> {
194195
let original: Vec<f32> = {
195196
let prim = fsl.elements().to_canonical().unwrap().into_primitive();
196197
prim.as_slice::<f32>().to_vec()
197198
};
198-
let encoded = turboquant_encode_mse(fsl, config)?;
199+
let encoded = encode_fn(fsl)?;
199200
let mut ctx = SESSION.create_execution_ctx();
200-
let decoded = encoded
201-
.into_array()
202-
.execute::<FixedSizeListArray>(&mut ctx)?;
201+
let decoded = encoded.execute::<FixedSizeListArray>(&mut ctx)?;
203202
let decoded_elements: Vec<f32> = {
204203
let prim = decoded.elements().to_canonical().unwrap().into_primitive();
205204
prim.as_slice::<f32>().to_vec()
206205
};
207206
Ok((original, decoded_elements))
208207
}
209208

210-
/// Encode via QJL and decode, returning (original, decoded) flat f32 slices.
209+
fn encode_decode_mse(
210+
fsl: &FixedSizeListArray,
211+
config: &TurboQuantConfig,
212+
) -> VortexResult<(Vec<f32>, Vec<f32>)> {
213+
let config = config.clone();
214+
encode_decode(fsl, |fsl| {
215+
Ok(turboquant_encode_mse(fsl, &config)?.into_array())
216+
})
217+
}
218+
211219
fn encode_decode_qjl(
212220
fsl: &FixedSizeListArray,
213221
config: &TurboQuantConfig,
214222
) -> VortexResult<(Vec<f32>, Vec<f32>)> {
215-
let original: Vec<f32> = {
216-
let prim = fsl.elements().to_canonical().unwrap().into_primitive();
217-
prim.as_slice::<f32>().to_vec()
218-
};
219-
let encoded = turboquant_encode_qjl(fsl, config)?;
220-
let mut ctx = SESSION.create_execution_ctx();
221-
let decoded = encoded
222-
.into_array()
223-
.execute::<FixedSizeListArray>(&mut ctx)?;
224-
let decoded_elements: Vec<f32> = {
225-
let prim = decoded.elements().to_canonical().unwrap().into_primitive();
226-
prim.as_slice::<f32>().to_vec()
227-
};
228-
Ok((original, decoded_elements))
223+
let config = config.clone();
224+
encode_decode(fsl, |fsl| {
225+
Ok(turboquant_encode_qjl(fsl, &config)?.into_array())
226+
})
229227
}
230228

231229
// -----------------------------------------------------------------------

0 commit comments

Comments
 (0)