|
11 | 11 |
|
12 | 12 | use std::sync::LazyLock; |
13 | 13 |
|
| 14 | +use rand::SeedableRng; |
| 15 | +use rand::rngs::StdRng; |
| 16 | +use rand_distr::Distribution; |
| 17 | +use rand_distr::Normal; |
| 18 | +use vortex_error::VortexExpect; |
14 | 19 | use vortex_error::VortexResult; |
15 | 20 | use vortex_error::vortex_bail; |
16 | 21 | use vortex_utils::aliases::dash_map::DashMap; |
17 | 22 |
|
| 23 | +use crate::encodings::turboquant::rotation::RotationMatrix; |
| 24 | + |
18 | 25 | /// Number of numerical integration points for computing conditional expectations. |
19 | 26 | const INTEGRATION_POINTS: usize = 1000; |
20 | 27 |
|
@@ -56,6 +63,13 @@ pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> { |
56 | 63 | /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` |
57 | 64 | /// where `C_d` is the normalizing constant. |
58 | 65 | fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> { |
| 66 | + // For non-power-of-2 dims, the SRHT's structured interaction with zero-padded |
| 67 | + // inputs produces a coordinate distribution that differs from the analytical |
| 68 | + // (1-x^2)^((d-3)/2) model. Use Monte Carlo sampling of the actual distribution. |
| 69 | + if !dimension.is_power_of_two() { |
| 70 | + return max_lloyd_centroids_empirical(dimension, bit_width); |
| 71 | + } |
| 72 | + |
59 | 73 | let num_centroids = 1usize << bit_width; |
60 | 74 | let dim = dimension as f64; |
61 | 75 |
|
@@ -153,6 +167,121 @@ fn pdf_unnormalized(x_val: f64, exponent: f64) -> f64 { |
153 | 167 | } |
154 | 168 | } |
155 | 169 |
|
| 170 | +/// Number of random SRHT instances for Monte Carlo sampling. |
| 171 | +const EMPIRICAL_NUM_SEEDS: usize = 20; |
| 172 | + |
| 173 | +/// Number of random unit vectors per SRHT instance. |
| 174 | +const EMPIRICAL_NUM_VECTORS: usize = 100; |
| 175 | + |
| 176 | +/// Compute optimal centroids via Monte Carlo sampling of the SRHT coordinate |
| 177 | +/// distribution for non-power-of-2 dimensions. |
| 178 | +/// |
| 179 | +/// For zero-padded vectors, the SRHT's structured butterfly interacts with the |
| 180 | +/// padding to produce a coordinate distribution that differs from the analytical |
| 181 | +/// `(1-x^2)^((d-3)/2)` model. This function samples the actual distribution by |
| 182 | +/// rotating many random unit vectors through many random SRHT instances, then |
| 183 | +/// runs 1D k-means (Max-Lloyd) on the collected samples. |
| 184 | +fn max_lloyd_centroids_empirical(dimension: u32, bit_width: u8) -> Vec<f32> { |
| 185 | + let dim = dimension as usize; |
| 186 | + let padded_dim = dim.next_power_of_two(); |
| 187 | + let num_centroids = 1usize << bit_width; |
| 188 | + |
| 189 | + // 1. Collect SRHT coordinate samples. |
| 190 | + let mut samples = Vec::with_capacity(EMPIRICAL_NUM_SEEDS * EMPIRICAL_NUM_VECTORS * padded_dim); |
| 191 | + let mut rng = StdRng::seed_from_u64(0); |
| 192 | + let normal = Normal::new(0.0f32, 1.0) |
| 193 | + .map_err(|e| vortex_error::vortex_err!("Normal distribution error: {e}")) |
| 194 | + .vortex_expect("infallible: Normal::new(0, 1)"); |
| 195 | + |
| 196 | + for _ in 0..EMPIRICAL_NUM_SEEDS { |
| 197 | + let srht_seed: u64 = rand::RngExt::random(&mut rng); |
| 198 | + let rotation = RotationMatrix::try_new(srht_seed, dim) |
| 199 | + .vortex_expect("dim >= 2 validated by get_centroids"); |
| 200 | + |
| 201 | + let mut padded = vec![0.0f32; padded_dim]; |
| 202 | + let mut rotated = vec![0.0f32; padded_dim]; |
| 203 | + |
| 204 | + for _ in 0..EMPIRICAL_NUM_VECTORS { |
| 205 | + // Random unit vector in R^dim, zero-padded to R^padded_dim. |
| 206 | + for val in padded[..dim].iter_mut() { |
| 207 | + *val = normal.sample(&mut rng); |
| 208 | + } |
| 209 | + padded[dim..].fill(0.0); |
| 210 | + let norm: f32 = padded[..dim].iter().map(|&v| v * v).sum::<f32>().sqrt(); |
| 211 | + if norm > 0.0 { |
| 212 | + let inv = 1.0 / norm; |
| 213 | + for val in padded[..dim].iter_mut() { |
| 214 | + *val *= inv; |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + rotation.rotate(&padded, &mut rotated); |
| 219 | + samples.extend_from_slice(&rotated); |
| 220 | + } |
| 221 | + } |
| 222 | + |
| 223 | + // 2. Sort for efficient conditional mean computation via binary search. |
| 224 | + samples.sort_unstable_by(|a, b| a.total_cmp(b)); |
| 225 | + |
| 226 | + // 3. 1D k-means (Max-Lloyd on sorted empirical samples). |
| 227 | + let n = samples.len(); |
| 228 | + let mut centroids: Vec<f64> = (0..num_centroids) |
| 229 | + .map(|idx| { |
| 230 | + // Initialize uniformly across the sample range. |
| 231 | + let lo = samples[0] as f64; |
| 232 | + let hi = samples[n - 1] as f64; |
| 233 | + lo + (hi - lo) * (2.0 * idx as f64 + 1.0) / (2.0 * num_centroids as f64) |
| 234 | + }) |
| 235 | + .collect(); |
| 236 | + |
| 237 | + let samples_f64: Vec<f64> = samples.iter().map(|&v| v as f64).collect(); |
| 238 | + |
| 239 | + for _ in 0..MAX_ITERATIONS { |
| 240 | + // Compute decision boundaries (midpoints between adjacent centroids). |
| 241 | + let mut boundaries = Vec::with_capacity(num_centroids + 1); |
| 242 | + boundaries.push(f64::NEG_INFINITY); |
| 243 | + for idx in 0..num_centroids - 1 { |
| 244 | + boundaries.push((centroids[idx] + centroids[idx + 1]) / 2.0); |
| 245 | + } |
| 246 | + boundaries.push(f64::INFINITY); |
| 247 | + |
| 248 | + // Update each centroid to the mean of samples in its Voronoi cell. |
| 249 | + let mut max_change = 0.0f64; |
| 250 | + for idx in 0..num_centroids { |
| 251 | + let lo = boundaries[idx]; |
| 252 | + let hi = boundaries[idx + 1]; |
| 253 | + |
| 254 | + // Binary search for the range of samples in [lo, hi). |
| 255 | + let start = samples_f64.partition_point(|&v| v < lo); |
| 256 | + let end = samples_f64.partition_point(|&v| v < hi); |
| 257 | + |
| 258 | + if start < end { |
| 259 | + let sum: f64 = samples_f64[start..end].iter().sum(); |
| 260 | + let count = (end - start) as f64; |
| 261 | + let new_centroid = sum / count; |
| 262 | + max_change = max_change.max((new_centroid - centroids[idx]).abs()); |
| 263 | + centroids[idx] = new_centroid; |
| 264 | + } |
| 265 | + } |
| 266 | + |
| 267 | + if max_change < CONVERGENCE_EPSILON { |
| 268 | + break; |
| 269 | + } |
| 270 | + } |
| 271 | + |
| 272 | + // Force symmetry: the SRHT coordinate distribution is symmetric around zero, |
| 273 | + // but Monte Carlo sampling introduces slight asymmetry. Average c[i] and |
| 274 | + // -c[k-1-i] to restore exact symmetry. |
| 275 | + let k = centroids.len(); |
| 276 | + for i in 0..k / 2 { |
| 277 | + let avg = (centroids[i].abs() + centroids[k - 1 - i].abs()) / 2.0; |
| 278 | + centroids[i] = -avg; |
| 279 | + centroids[k - 1 - i] = avg; |
| 280 | + } |
| 281 | + |
| 282 | + centroids.into_iter().map(|val| val as f32).collect() |
| 283 | +} |
| 284 | + |
156 | 285 | /// Precompute decision boundaries (midpoints between adjacent centroids). |
157 | 286 | /// |
158 | 287 | /// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps |
|
0 commit comments