Skip to content

Commit 0ea2b2b

Browse files
committed
clean up some stuff
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 0170156 commit 0ea2b2b

File tree

13 files changed

+141
-113
lines changed

13 files changed

+141
-113
lines changed

vortex-tensor/public-api.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ pub fn T::padded_dim(&self) -> u32
182182

183183
pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
184184

185-
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: &vortex_array::arrays::extension::vtable::ExtensionArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
185+
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
186186

187187
pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::array::typed::Array<vortex_tensor::encodings::turboquant::TurboQuant>
188188

vortex-tensor/src/encodings/turboquant/array/centroids.rs

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,53 +12,58 @@
1212
use std::sync::LazyLock;
1313

1414
use vortex_error::VortexResult;
15-
use vortex_error::vortex_bail;
15+
use vortex_error::vortex_ensure;
1616
use vortex_utils::aliases::dash_map::DashMap;
1717

1818
use crate::encodings::turboquant::TurboQuant;
1919

20-
/// Number of numerical integration points for computing conditional expectations.
21-
const INTEGRATION_POINTS: usize = 1000;
20+
/// The maximum iterations for Max-Lloyd algorithm when computing centroids.
21+
const MAX_ITERATIONS: usize = 200;
2222

23-
/// Max-Lloyd convergence threshold.
23+
/// The Max-Lloyd convergence threshold for stopping early when computing centroids.
2424
const CONVERGENCE_EPSILON: f64 = 1e-12;
2525

26-
/// Maximum iterations for Max-Lloyd algorithm.
27-
const MAX_ITERATIONS: usize = 200;
26+
/// Number of numerical integration points for computing conditional expectations.
27+
const INTEGRATION_POINTS: usize = 1000;
2828

29+
// TODO(connor): Maybe we should just store an `ArrayRef` here?
2930
/// Global centroid cache keyed by (dimension, bit_width).
3031
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default);
3132

3233
/// Get or compute cached centroids for the given dimension and bit width.
3334
///
34-
/// Returns `2^bit_width` centroids sorted in ascending order, representing
35-
/// optimal scalar quantization levels for the coordinate distribution after
36-
/// random rotation in `dimension`-dimensional space.
35+
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
36+
/// quantization levels for the coordinate distribution after random rotation in
37+
/// `dimension`-dimensional space.
3738
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
38-
if !(1..=8).contains(&bit_width) {
39-
vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}");
40-
}
41-
if dimension < TurboQuant::MIN_DIMENSION {
42-
vortex_bail!(
43-
"TurboQuant dimension must be >= {}, got {dimension}",
44-
TurboQuant::MIN_DIMENSION
45-
);
46-
}
39+
vortex_ensure!(
40+
(1..=8).contains(&bit_width),
41+
"TurboQuant bit_width must be 1-8, got {bit_width}"
42+
);
43+
vortex_ensure!(
44+
dimension >= TurboQuant::MIN_DIMENSION,
45+
"TurboQuant dimension must be >= {}, got {dimension}",
46+
TurboQuant::MIN_DIMENSION
47+
);
4748

4849
if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) {
4950
return Ok(centroids.clone());
5051
}
5152

5253
let centroids = max_lloyd_centroids(dimension, bit_width);
5354
CENTROID_CACHE.insert((dimension, bit_width), centroids.clone());
55+
5456
Ok(centroids)
5557
}
5658

59+
// TODO(connor): It would potentially be more performant if this was modelled as const generic
60+
// parameters to functions.
5761
/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`.
5862
///
59-
/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd)
60-
/// or a half-integer (when `d` is even). This type makes that invariant explicit and
61-
/// avoids floating-point comparison in the hot path.
63+
/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a
64+
/// half-integer (when `d` is even).
65+
///
66+
/// This type makes that invariant explicit and avoids floating-point comparison in the hot path.
6267
#[derive(Clone, Copy, Debug)]
6368
struct HalfIntExponent {
6469
int_part: i32,
@@ -70,12 +75,7 @@ impl HalfIntExponent {
7075
///
7176
/// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative.
7277
fn from_numerator(numerator: i32) -> Self {
73-
// Integer division truncates toward zero; for negative odd numerators
74-
// (e.g., d=2 → num=-1) this gives int_part=0, has_half=true,
75-
// representing -0.5 = 0 + (-0.5). The sign is handled by adjusting
76-
// int_part: -1/2 = 0 with has_half, but we need the floor division.
77-
// Rust's `/` truncates toward zero, so -1/2 = 0. We want floor: -1.
78-
// Use divmod that rounds toward negative infinity.
78+
// Use Euclidean division to get floor division toward negative infinity.
7979
let int_part = numerator.div_euclid(2);
8080
let has_half = numerator.rem_euclid(2) != 0;
8181
Self { int_part, has_half }
@@ -84,12 +84,14 @@ impl HalfIntExponent {
8484

8585
/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm.
8686
///
87-
/// Operates on the marginal distribution of a single coordinate of a randomly
88-
/// rotated unit vector in d dimensions. The PDF is:
87+
/// Operates on the marginal distribution of a single coordinate of a randomly rotated unit vector
88+
/// in d dimensions.
89+
///
90+
/// The probability distribution function is:
8991
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
9092
/// where `C_d` is the normalizing constant.
91-
#[allow(clippy::cast_possible_truncation)] // f64→f32 centroid values are intentional
9293
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
94+
debug_assert!((1..=8).contains(&bit_width));
9395
let num_centroids = 1usize << bit_width;
9496

9597
// For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
@@ -114,7 +116,7 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
114116
for idx in 0..num_centroids {
115117
let lo = boundaries[idx];
116118
let hi = boundaries[idx + 1];
117-
let new_centroid = conditional_mean(lo, hi, exponent);
119+
let new_centroid = mean_between_centroids(lo, hi, exponent);
118120
max_change = max_change.max((new_centroid - centroids[idx]).abs());
119121
centroids[idx] = new_centroid;
120122
}
@@ -124,14 +126,19 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
124126
}
125127
}
126128

129+
#[expect(
130+
clippy::cast_possible_truncation,
131+
reason = "all values are in [-1, 1] so this just loses precision"
132+
)]
127133
centroids.into_iter().map(|val| val as f32).collect()
128134
}
129135

130136
/// Compute the conditional mean of the coordinate distribution on interval [lo, hi].
131137
///
132-
/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent`
133-
/// on [-1, 1].
134-
fn conditional_mean(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 {
138+
/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1].
139+
///
140+
/// Since there is no closed form for the integrals, we compute this numerically.
141+
fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 {
135142
if (hi - lo).abs() < 1e-15 {
136143
return (lo + hi) / 2.0;
137144
}
@@ -164,9 +171,9 @@ fn conditional_mean(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 {
164171

165172
/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`.
166173
///
167-
/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents
168-
/// that arise from `(d-3)/2`. This is significantly faster than the general
169-
/// `powf` which goes through `exp(exponent * ln(base))`.
174+
/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`.
175+
/// This is significantly faster than the general `powf` which goes through
176+
/// `exp(exponent * ln(base))`.
170177
#[inline]
171178
fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 {
172179
let base = (1.0 - x_val * x_val).max(0.0);
@@ -182,10 +189,10 @@ fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 {
182189

183190
/// Precompute decision boundaries (midpoints between adjacent centroids).
184191
///
185-
/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps
186-
/// to centroid 0, a value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`,
187-
/// and a value >= `boundaries[k-2]` maps to centroid `k-1`.
188-
pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> {
192+
/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a
193+
/// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a
194+
/// value `>= boundaries[k-2]` maps to centroid `k-1`.
195+
pub fn compute_centroid_boundaries(centroids: &[f32]) -> Vec<f32> {
189196
centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect()
190197
}
191198

@@ -195,14 +202,21 @@ pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> {
195202
/// centroids. Uses binary search on the midpoints, avoiding distance comparisons
196203
/// in the inner loop.
197204
#[inline]
198-
#[allow(clippy::cast_possible_truncation)] // bounded by num_centroids <= 256
199205
pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
200206
debug_assert!(
201207
boundaries.windows(2).all(|w| w[0] <= w[1]),
202208
"boundaries must be sorted"
203209
);
210+
debug_assert!(
211+
boundaries.len() <= 256, // 1 << 8
212+
"boundaries must be sorted"
213+
);
204214

205-
boundaries.partition_point(|&b| b < value) as u8
215+
#[expect(
216+
clippy::cast_possible_truncation,
217+
reason = "num_centroids <= 256 and partition_point will return at most 255"
218+
)]
219+
(boundaries.partition_point(|&b| b < value) as u8)
206220
}
207221

208222
#[cfg(test)]
@@ -294,7 +308,7 @@ mod tests {
294308
#[test]
295309
fn find_nearest_basic() -> VortexResult<()> {
296310
let centroids = get_centroids(128, 2)?;
297-
let boundaries = compute_boundaries(&centroids);
311+
let boundaries = compute_centroid_boundaries(&centroids);
298312
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
299313

300314
let last_idx = (centroids.len() - 1) as u8;

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pub struct TurboQuantData {
3232
/// Stored as a convenience field to avoid repeatedly extracting it from `dtype`.
3333
pub(crate) dimension: u32,
3434

35-
/// The number of bits per coordinate (1-8), derived from `log2(centroids.len())`.
35+
/// The number of bits per coordinate (0-8), derived from `log2(centroids.len())`.
3636
///
3737
/// This is 0 for degenerate empty arrays.
3838
pub(crate) bit_width: u8,
@@ -56,6 +56,7 @@ impl TurboQuantData {
5656
bit_width <= 8,
5757
"bit_width is expected to be between 0 and 8, got {bit_width}"
5858
);
59+
5960
Ok(Self {
6061
dimension,
6162
bit_width,

vortex-tensor/src/encodings/turboquant/array/mod.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,3 @@ pub(crate) mod centroids;
1111
pub(crate) mod rotation;
1212

1313
pub(crate) mod scheme;
14-
15-
use num_traits::Float;
16-
use num_traits::FromPrimitive;
17-
use vortex_error::VortexExpect;
18-
19-
/// Convert an f32 value to a float type `T`.
20-
///
21-
/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the
22-
/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening.
23-
pub(crate) fn float_from_f32<T: Float + FromPrimitive>(v: f32) -> T {
24-
FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible")
25-
}

vortex-tensor/src/encodings/turboquant/array/scheme.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
use vortex_array::ArrayRef;
77
use vortex_array::Canonical;
8+
use vortex_array::arrays::Extension;
89
use vortex_compressor::CascadingCompressor;
910
use vortex_compressor::ctx::CompressorContext;
1011
use vortex_compressor::scheme::Scheme;
@@ -75,11 +76,13 @@ impl Scheme for TurboQuantScheme {
7576
data: &mut ArrayAndStats,
7677
_ctx: CompressorContext,
7778
) -> VortexResult<ArrayRef> {
78-
// TODO(connor): Fix this once we ensure that the data array is always canonical.
79-
let ext_array = data.array().to_canonical()?.into_extension();
79+
let ext_array = data
80+
.array()
81+
.as_opt::<Extension>()
82+
.vortex_expect("expected an extension array");
8083

8184
let config = TurboQuantConfig::default();
82-
turboquant_encode(&ext_array, &config, &mut compressor.execution_ctx())
85+
turboquant_encode(ext_array, &config, &mut compressor.execution_ctx())
8386
}
8487
}
8588

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
66
use num_traits::ToPrimitive;
77
use vortex_array::ArrayRef;
8+
use vortex_array::ArrayView;
89
use vortex_array::ExecutionCtx;
910
use vortex_array::IntoArray;
10-
use vortex_array::arrays::ExtensionArray;
11+
use vortex_array::arrays::Extension;
1112
use vortex_array::arrays::FixedSizeListArray;
1213
use vortex_array::arrays::PrimitiveArray;
1314
use vortex_array::arrays::extension::ExtensionArrayExt;
@@ -25,7 +26,7 @@ use vortex_error::vortex_ensure;
2526
use vortex_fastlanes::bitpack_compress::bitpack_encode;
2627

2728
use crate::encodings::turboquant::TurboQuant;
28-
use crate::encodings::turboquant::array::centroids::compute_boundaries;
29+
use crate::encodings::turboquant::array::centroids::compute_centroid_boundaries;
2930
use crate::encodings::turboquant::array::centroids::find_nearest_centroid;
3031
use crate::encodings::turboquant::array::centroids::get_centroids;
3132
use crate::encodings::turboquant::array::rotation::RotationMatrix;
@@ -95,7 +96,7 @@ struct QuantizationResult {
9596
/// all-zero codes.
9697
#[allow(clippy::cast_possible_truncation)]
9798
fn turboquant_quantize_core(
98-
ext: &ExtensionArray,
99+
ext: ArrayView<Extension>,
99100
fsl: &FixedSizeListArray,
100101
seed: u64,
101102
bit_width: u8,
@@ -130,7 +131,7 @@ fn turboquant_quantize_core(
130131
let f32_elements = extract_f32_elements(fsl)?;
131132

132133
let centroids = get_centroids(padded_dim as u32, bit_width)?;
133-
let boundaries = compute_boundaries(&centroids);
134+
let boundaries = compute_centroid_boundaries(&centroids);
134135

135136
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
136137
let mut padded = vec![0.0f32; padded_dim];
@@ -213,7 +214,7 @@ fn build_turboquant(
213214
/// Nullable inputs are supported: null vectors get all-zero codes and null norms. The validity
214215
/// of the resulting TurboQuant array is carried by the norms child.
215216
pub fn turboquant_encode(
216-
ext: &ExtensionArray,
217+
ext: ArrayView<Extension>,
217218
config: &TurboQuantConfig,
218219
ctx: &mut ExecutionCtx,
219220
) -> VortexResult<ArrayRef> {

vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use vortex_error::vortex_ensure_eq;
4444

4545
use crate::encodings::turboquant::TurboQuant;
4646
use crate::encodings::turboquant::TurboQuantArrayExt;
47-
use crate::encodings::turboquant::array::float_from_f32;
47+
use crate::encodings::turboquant::compute::float_from_f32;
4848
use crate::vector::AnyVector;
4949

5050
/// Compute the per-row unit-norm dot products in f32 (centroids are always f32).

vortex-tensor/src/encodings/turboquant/compute/mod.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,21 @@
44
//! Compute pushdown implementations for TurboQuant.
55
66
pub(crate) mod cosine_similarity;
7+
78
mod ops;
8-
pub(crate) mod rules;
99
mod slice;
1010
mod take;
11+
12+
pub(crate) mod rules;
13+
14+
use num_traits::Float;
15+
use num_traits::FromPrimitive;
16+
use vortex_error::VortexExpect;
17+
18+
/// Convert an f32 value to a float type `T`.
19+
///
20+
/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the
21+
/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening.
22+
pub(crate) fn float_from_f32<T: Float + FromPrimitive>(v: f32) -> T {
23+
FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible")
24+
}

vortex-tensor/src/encodings/turboquant/decompress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ use vortex_error::VortexResult;
2222

2323
use crate::encodings::turboquant::TurboQuant;
2424
use crate::encodings::turboquant::TurboQuantArrayExt;
25-
use crate::encodings::turboquant::array::float_from_f32;
2625
use crate::encodings::turboquant::array::rotation::RotationMatrix;
26+
use crate::encodings::turboquant::compute::float_from_f32;
2727
use crate::vector::AnyVector;
2828

2929
/// Decompress a `TurboQuantArray` into a [`Vector`] extension array.

vortex-tensor/src/encodings/turboquant/metadata.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use vortex_error::vortex_err;
1111
/// Serialized metadata for TurboQuant arrays.
1212
#[derive(Clone, PartialEq, Message)]
1313
pub(super) struct TurboQuantMetadata {
14-
/// The number of bits per coordinate.
14+
/// The number of bits per coordinate, which must be <= 8.
1515
#[prost(uint32, required, tag = "1")]
1616
bit_width: u32,
1717
}

0 commit comments

Comments
 (0)