1212use std:: sync:: LazyLock ;
1313
1414use vortex_error:: VortexResult ;
15- use vortex_error:: vortex_bail ;
15+ use vortex_error:: vortex_ensure ;
1616use vortex_utils:: aliases:: dash_map:: DashMap ;
1717
1818use crate :: encodings:: turboquant:: TurboQuant ;
1919
20- /// Number of numerical integration points for computing conditional expectations .
21- const INTEGRATION_POINTS : usize = 1000 ;
20+ /// The maximum iterations for Max-Lloyd algorithm when computing centroids .
21+ const MAX_ITERATIONS : usize = 200 ;
2222
23- /// Max-Lloyd convergence threshold.
23+ /// The Max-Lloyd convergence threshold for stopping early when computing centroids .
2424const CONVERGENCE_EPSILON : f64 = 1e-12 ;
2525
26- /// Maximum iterations for Max-Lloyd algorithm .
27- const MAX_ITERATIONS : usize = 200 ;
26+ /// Number of numerical integration points for computing conditional expectations .
27+ const INTEGRATION_POINTS : usize = 1000 ;
2828
29+ // TODO(connor): Maybe we should just store an `ArrayRef` here?
2930/// Global centroid cache keyed by (dimension, bit_width).
3031static CENTROID_CACHE : LazyLock < DashMap < ( u32 , u8 ) , Vec < f32 > > > = LazyLock :: new ( DashMap :: default) ;
3132
3233/// Get or compute cached centroids for the given dimension and bit width.
3334///
34- /// Returns `2^bit_width` centroids sorted in ascending order, representing
35- /// optimal scalar quantization levels for the coordinate distribution after
36- /// random rotation in `dimension`-dimensional space.
35+ /// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
36+ /// quantization levels for the coordinate distribution after random rotation in
37+ /// `dimension`-dimensional space.
3738pub fn get_centroids ( dimension : u32 , bit_width : u8 ) -> VortexResult < Vec < f32 > > {
38- if ! ( 1 ..= 8 ) . contains ( & bit_width ) {
39- vortex_bail ! ( "TurboQuant bit_width must be 1-8, got { bit_width}" ) ;
40- }
41- if dimension < TurboQuant :: MIN_DIMENSION {
42- vortex_bail ! (
43- "TurboQuant dimension must be >= {}, got {dimension}" ,
44- TurboQuant :: MIN_DIMENSION
45- ) ;
46- }
39+ vortex_ensure ! (
40+ ( 1 ..= 8 ) . contains ( & bit_width) ,
41+ "TurboQuant bit_width must be 1-8, got {bit_width}"
42+ ) ;
43+ vortex_ensure ! (
44+ dimension >= TurboQuant :: MIN_DIMENSION ,
45+ "TurboQuant dimension must be >= {}, got {dimension}" ,
46+ TurboQuant :: MIN_DIMENSION
47+ ) ;
4748
4849 if let Some ( centroids) = CENTROID_CACHE . get ( & ( dimension, bit_width) ) {
4950 return Ok ( centroids. clone ( ) ) ;
5051 }
5152
5253 let centroids = max_lloyd_centroids ( dimension, bit_width) ;
5354 CENTROID_CACHE . insert ( ( dimension, bit_width) , centroids. clone ( ) ) ;
55+
5456 Ok ( centroids)
5557}
5658
59+ // TODO(connor): It would potentially be more performant if this was modelled as const generic
60+ // parameters to functions.
5761/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`.
5862///
59- /// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd)
60- /// or a half-integer (when `d` is even). This type makes that invariant explicit and
61- /// avoids floating-point comparison in the hot path.
63+ /// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a
64+ /// half-integer (when `d` is even).
65+ ///
66+ /// This type makes that invariant explicit and avoids floating-point comparison in the hot path.
6267#[ derive( Clone , Copy , Debug ) ]
6368struct HalfIntExponent {
6469 int_part : i32 ,
@@ -70,12 +75,7 @@ impl HalfIntExponent {
7075 ///
7176 /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative.
7277 fn from_numerator ( numerator : i32 ) -> Self {
73- // Integer division truncates toward zero; for negative odd numerators
74- // (e.g., d=2 → num=-1) this gives int_part=0, has_half=true,
75- // representing -0.5 = 0 + (-0.5). The sign is handled by adjusting
76- // int_part: -1/2 = 0 with has_half, but we need the floor division.
77- // Rust's `/` truncates toward zero, so -1/2 = 0. We want floor: -1.
78- // Use divmod that rounds toward negative infinity.
78+ // Use Euclidean division to get floor division toward negative infinity.
7979 let int_part = numerator. div_euclid ( 2 ) ;
8080 let has_half = numerator. rem_euclid ( 2 ) != 0 ;
8181 Self { int_part, has_half }
@@ -84,12 +84,14 @@ impl HalfIntExponent {
8484
8585/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm.
8686///
87- /// Operates on the marginal distribution of a single coordinate of a randomly
88- /// rotated unit vector in d dimensions. The PDF is:
87+ /// Operates on the marginal distribution of a single coordinate of a randomly rotated unit vector
88+ /// in d dimensions.
89+ ///
90+ /// The probability distribution function is:
8991/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
9092/// where `C_d` is the normalizing constant.
91- #[ allow( clippy:: cast_possible_truncation) ] // f64→f32 centroid values are intentional
9293fn max_lloyd_centroids ( dimension : u32 , bit_width : u8 ) -> Vec < f32 > {
94+ debug_assert ! ( ( 1 ..=8 ) . contains( & bit_width) ) ;
9395 let num_centroids = 1usize << bit_width;
9496
9597 // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
@@ -114,7 +116,7 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
114116 for idx in 0 ..num_centroids {
115117 let lo = boundaries[ idx] ;
116118 let hi = boundaries[ idx + 1 ] ;
117- let new_centroid = conditional_mean ( lo, hi, exponent) ;
119+ let new_centroid = mean_between_centroids ( lo, hi, exponent) ;
118120 max_change = max_change. max ( ( new_centroid - centroids[ idx] ) . abs ( ) ) ;
119121 centroids[ idx] = new_centroid;
120122 }
@@ -124,14 +126,19 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
124126 }
125127 }
126128
129+ #[ expect(
130+ clippy:: cast_possible_truncation,
131+ reason = "all values are in [-1, 1] so this just loses precision"
132+ ) ]
127133 centroids. into_iter ( ) . map ( |val| val as f32 ) . collect ( )
128134}
129135
130136/// Compute the conditional mean of the coordinate distribution on interval [lo, hi].
131137///
132- /// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent`
133- /// on [-1, 1].
134- fn conditional_mean ( lo : f64 , hi : f64 , exponent : HalfIntExponent ) -> f64 {
138+ /// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1].
139+ ///
140+ /// Since there is no closed form for the integrals, we compute this numerically.
141+ fn mean_between_centroids ( lo : f64 , hi : f64 , exponent : HalfIntExponent ) -> f64 {
135142 if ( hi - lo) . abs ( ) < 1e-15 {
136143 return ( lo + hi) / 2.0 ;
137144 }
@@ -164,9 +171,9 @@ fn conditional_mean(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 {
164171
165172/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`.
166173///
167- /// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents
168- /// that arise from `(d-3)/2`. This is significantly faster than the general
169- /// `powf` which goes through ` exp(exponent * ln(base))`.
174+ /// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`.
175+ /// This is significantly faster than the general `powf` which goes through
176+ /// `exp(exponent * ln(base))`.
170177#[ inline]
171178fn pdf_unnormalized ( x_val : f64 , exponent : HalfIntExponent ) -> f64 {
172179 let base = ( 1.0 - x_val * x_val) . max ( 0.0 ) ;
@@ -182,10 +189,10 @@ fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 {
182189
/// Precompute decision boundaries (midpoints between adjacent centroids).
///
/// For `k` centroids, returns `k-1` boundaries; fewer than two centroids yield an empty vector.
///
/// When the result is used with `find_nearest_centroid` (which counts the boundaries strictly
/// below the value), a value `<= boundaries[0]` maps to centroid 0, a value in
/// `(boundaries[i-1], boundaries[i]]` maps to centroid `i`, and a value `> boundaries[k-2]` maps
/// to centroid `k-1`. A value lying exactly on a boundary is equidistant from both neighbours and
/// is assigned to the lower one.
pub fn compute_centroid_boundaries(centroids: &[f32]) -> Vec<f32> {
    centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect()
}
191198
@@ -195,14 +202,21 @@ pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> {
195202/// centroids. Uses binary search on the midpoints, avoiding distance comparisons
196203/// in the inner loop.
#[inline]
pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
    debug_assert!(
        boundaries.windows(2).all(|w| w[0] <= w[1]),
        "boundaries must be sorted"
    );
    // `partition_point` may return `boundaries.len()` (when every boundary is below `value`),
    // so the length itself must fit in a `u8`: at most 255 boundaries, i.e. 256 centroids.
    debug_assert!(
        boundaries.len() <= usize::from(u8::MAX),
        "at most 255 boundaries (256 centroids) are supported"
    );

    // Counts the boundaries strictly below `value`; that count is the centroid index. A value
    // exactly equal to a boundary therefore maps to the lower of the two adjacent centroids.
    #[expect(
        clippy::cast_possible_truncation,
        reason = "there are at most 255 boundaries, so partition_point returns at most 255"
    )]
    (boundaries.partition_point(|&b| b < value) as u8)
}
207221
208222#[ cfg( test) ]
@@ -294,7 +308,7 @@ mod tests {
294308 #[ test]
295309 fn find_nearest_basic ( ) -> VortexResult < ( ) > {
296310 let centroids = get_centroids ( 128 , 2 ) ?;
297- let boundaries = compute_boundaries ( & centroids) ;
311+ let boundaries = compute_centroid_boundaries ( & centroids) ;
298312 assert_eq ! ( find_nearest_centroid( -1.0 , & boundaries) , 0 ) ;
299313
300314 let last_idx = ( centroids. len ( ) - 1 ) as u8 ;
0 commit comments