Skip to content

Commit 93e5dc7

Browse files
committed
fixing biases with empirical distribution
Signed-off-by: Will Manning <will@willmanning.io>
1 parent 2c5017a commit 93e5dc7

5 files changed

Lines changed: 145 additions & 12 deletions

File tree

vortex-file/src/strategy.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,10 @@ impl WriteStrategyBuilder {
235235
pub fn with_compact_encodings(mut self) -> Self {
236236
let mut builder = self.builder.unwrap_or_default();
237237
builder = builder.include([
238-
string::ZstdScheme.id(),
239-
integer::PcoScheme.id(),
240-
float::PcoScheme.id(),
241-
]);
238+
string::ZstdScheme.id(),
239+
integer::PcoScheme.id(),
240+
float::PcoScheme.id(),
241+
]);
242242
self.builder = Some(builder);
243243
self
244244
}

vortex-tensor/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ workspace = true
2020
unstable_encodings = [
2121
"dep:half",
2222
"dep:rand",
23+
"dep:rand_distr",
2324
"dep:vortex-compressor",
2425
"dep:vortex-fastlanes",
2526
"dep:vortex-utils",
@@ -39,7 +40,7 @@ itertools = { workspace = true }
3940
num-traits = { workspace = true }
4041
prost = { workspace = true }
4142
rand = { workspace = true, optional = true }
43+
rand_distr = { workspace = true, optional = true }
4244

4345
[dev-dependencies]
44-
rand_distr = { workspace = true }
4546
rstest = { workspace = true }

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@
1111
1212
use std::sync::LazyLock;
1313

14+
use rand::SeedableRng;
15+
use rand::rngs::StdRng;
16+
use rand_distr::Distribution;
17+
use rand_distr::Normal;
18+
use vortex_error::VortexExpect;
1419
use vortex_error::VortexResult;
1520
use vortex_error::vortex_bail;
1621
use vortex_utils::aliases::dash_map::DashMap;
1722

23+
use crate::encodings::turboquant::rotation::RotationMatrix;
24+
1825
/// Number of numerical integration points for computing conditional expectations.
1926
const INTEGRATION_POINTS: usize = 1000;
2027

@@ -56,6 +63,13 @@ pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
5663
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
5764
/// where `C_d` is the normalizing constant.
5865
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
66+
// For non-power-of-2 dims, the SRHT's structured interaction with zero-padded
67+
// inputs produces a coordinate distribution that differs from the analytical
68+
// (1-x^2)^((d-3)/2) model. Use Monte Carlo sampling of the actual distribution.
69+
if !dimension.is_power_of_two() {
70+
return max_lloyd_centroids_empirical(dimension, bit_width);
71+
}
72+
5973
let num_centroids = 1usize << bit_width;
6074
let dim = dimension as f64;
6175

@@ -153,6 +167,121 @@ fn pdf_unnormalized(x_val: f64, exponent: f64) -> f64 {
153167
}
154168
}
155169

170+
/// Number of random SRHT instances for Monte Carlo sampling.
171+
const EMPIRICAL_NUM_SEEDS: usize = 20;
172+
173+
/// Number of random unit vectors per SRHT instance.
174+
const EMPIRICAL_NUM_VECTORS: usize = 100;
175+
176+
/// Compute optimal centroids via Monte Carlo sampling of the SRHT coordinate
177+
/// distribution for non-power-of-2 dimensions.
178+
///
179+
/// For zero-padded vectors, the SRHT's structured butterfly interacts with the
180+
/// padding to produce a coordinate distribution that differs from the analytical
181+
/// `(1-x^2)^((d-3)/2)` model. This function samples the actual distribution by
182+
/// rotating many random unit vectors through many random SRHT instances, then
183+
/// runs 1D k-means (Max-Lloyd) on the collected samples.
184+
fn max_lloyd_centroids_empirical(dimension: u32, bit_width: u8) -> Vec<f32> {
185+
let dim = dimension as usize;
186+
let padded_dim = dim.next_power_of_two();
187+
let num_centroids = 1usize << bit_width;
188+
189+
// 1. Collect SRHT coordinate samples.
190+
let mut samples = Vec::with_capacity(EMPIRICAL_NUM_SEEDS * EMPIRICAL_NUM_VECTORS * padded_dim);
191+
let mut rng = StdRng::seed_from_u64(0);
192+
let normal = Normal::new(0.0f32, 1.0)
193+
.map_err(|e| vortex_error::vortex_err!("Normal distribution error: {e}"))
194+
.vortex_expect("infallible: Normal::new(0, 1)");
195+
196+
for _ in 0..EMPIRICAL_NUM_SEEDS {
197+
let srht_seed: u64 = rand::RngExt::random(&mut rng);
198+
let rotation = RotationMatrix::try_new(srht_seed, dim)
199+
.vortex_expect("dim >= 2 validated by get_centroids");
200+
201+
let mut padded = vec![0.0f32; padded_dim];
202+
let mut rotated = vec![0.0f32; padded_dim];
203+
204+
for _ in 0..EMPIRICAL_NUM_VECTORS {
205+
// Random unit vector in R^dim, zero-padded to R^padded_dim.
206+
for val in padded[..dim].iter_mut() {
207+
*val = normal.sample(&mut rng);
208+
}
209+
padded[dim..].fill(0.0);
210+
let norm: f32 = padded[..dim].iter().map(|&v| v * v).sum::<f32>().sqrt();
211+
if norm > 0.0 {
212+
let inv = 1.0 / norm;
213+
for val in padded[..dim].iter_mut() {
214+
*val *= inv;
215+
}
216+
}
217+
218+
rotation.rotate(&padded, &mut rotated);
219+
samples.extend_from_slice(&rotated);
220+
}
221+
}
222+
223+
// 2. Sort for efficient conditional mean computation via binary search.
224+
samples.sort_unstable_by(|a, b| a.total_cmp(b));
225+
226+
// 3. 1D k-means (Max-Lloyd on sorted empirical samples).
227+
let n = samples.len();
228+
let mut centroids: Vec<f64> = (0..num_centroids)
229+
.map(|idx| {
230+
// Initialize uniformly across the sample range.
231+
let lo = samples[0] as f64;
232+
let hi = samples[n - 1] as f64;
233+
lo + (hi - lo) * (2.0 * idx as f64 + 1.0) / (2.0 * num_centroids as f64)
234+
})
235+
.collect();
236+
237+
let samples_f64: Vec<f64> = samples.iter().map(|&v| v as f64).collect();
238+
239+
for _ in 0..MAX_ITERATIONS {
240+
// Compute decision boundaries (midpoints between adjacent centroids).
241+
let mut boundaries = Vec::with_capacity(num_centroids + 1);
242+
boundaries.push(f64::NEG_INFINITY);
243+
for idx in 0..num_centroids - 1 {
244+
boundaries.push((centroids[idx] + centroids[idx + 1]) / 2.0);
245+
}
246+
boundaries.push(f64::INFINITY);
247+
248+
// Update each centroid to the mean of samples in its Voronoi cell.
249+
let mut max_change = 0.0f64;
250+
for idx in 0..num_centroids {
251+
let lo = boundaries[idx];
252+
let hi = boundaries[idx + 1];
253+
254+
// Binary search for the range of samples in [lo, hi).
255+
let start = samples_f64.partition_point(|&v| v < lo);
256+
let end = samples_f64.partition_point(|&v| v < hi);
257+
258+
if start < end {
259+
let sum: f64 = samples_f64[start..end].iter().sum();
260+
let count = (end - start) as f64;
261+
let new_centroid = sum / count;
262+
max_change = max_change.max((new_centroid - centroids[idx]).abs());
263+
centroids[idx] = new_centroid;
264+
}
265+
}
266+
267+
if max_change < CONVERGENCE_EPSILON {
268+
break;
269+
}
270+
}
271+
272+
// Force symmetry: the SRHT coordinate distribution is symmetric around zero,
273+
// but Monte Carlo sampling introduces slight asymmetry. Average c[i] and
274+
// -c[k-1-i] to restore exact symmetry.
275+
let k = centroids.len();
276+
for i in 0..k / 2 {
277+
let avg = (centroids[i].abs() + centroids[k - 1 - i].abs()) / 2.0;
278+
centroids[i] = -avg;
279+
centroids[k - 1 - i] = avg;
280+
}
281+
282+
centroids.into_iter().map(|val| val as f32).collect()
283+
}
284+
156285
/// Precompute decision boundaries (midpoints between adjacent centroids).
157286
///
158287
/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ fn turboquant_quantize_core(
9696

9797
let f32_elements = extract_f32_elements(fsl)?;
9898

99-
let centroids = get_centroids(padded_dim as u32, bit_width)?;
99+
let centroids = get_centroids(dimension as u32, bit_width)?;
100100
let boundaries = compute_boundaries(&centroids);
101101

102102
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -413,12 +413,15 @@ mod tests {
413413

414414
let mean_rel_error = qjl_mean_signed_relative_error(&original, &decoded, dim, num_rows);
415415

416-
// For power-of-2 dims, QJL bias should be small (< 0.15).
417-
// For non-power-of-2 dims (e.g., 768 padded to 1024), the bias is
418-
// inherently larger because the SRHT centroids are optimized for the
419-
// padded dimension's coordinate distribution, which differs from the
420-
// actual distribution of a zero-padded lower-dimensional vector.
421-
let threshold = if dim.is_power_of_two() { 0.15 } else { 0.25 };
416+
// With empirical centroids, 4+ bit QJL achieves < 0.15 bias for all
417+
// dims. At very low bit widths (2-3 bits), non-power-of-2 dims still
418+
// have elevated bias due to the interaction between high quantization
419+
// noise and the SRHT zero-padding structure.
420+
let threshold = if dim.is_power_of_two() || bit_width >= 4 {
421+
0.15
422+
} else {
423+
0.30
424+
};
422425
assert!(
423426
mean_rel_error.abs() < threshold,
424427
"QJL inner product bias too high: {mean_rel_error:.4} for dim={dim}, bits={bit_width} \

0 commit comments

Comments
 (0)