|
| 1 | +//! F64x8 HEEL distance: 8 f64 distances across 8 HEEL planes in one SIMD pass. |
| 2 | +//! |
| 3 | +//! p64 has 8 HEEL planes (u64 each). For weighted f64 distance computation, |
| 4 | +//! each plane produces one f64 distance value → 8 values = one F64x8 register. |
| 5 | +//! |
| 6 | +//! Uses `crate::simd::F64x8` polyfill — automatic dispatch: |
| 7 | +//! AVX-512: native __m512d (one register) |
| 8 | +//! AVX2: 2× __m256d (two registers, same API) |
| 9 | +//! Scalar: [f64; 8] fallback |
| 10 | +//! Consumer writes `crate::simd::F64x8`. The polyfill handles the rest. |
| 11 | +
|
| 12 | +use crate::simd::F64x8; |
| 13 | + |
| 14 | +/// Compute weighted dot product of 8 HEEL plane distances. |
| 15 | +/// |
| 16 | +/// `distances[i]` = distance for HEEL plane i. |
| 17 | +/// `weights[i]` = importance weight for plane i. |
| 18 | +/// Returns: Σ(distances[i] × weights[i]). |
| 19 | +/// |
| 20 | +/// One F64x8 multiply + reduce_sum. On AVX-512: single vmulpd + vreducepd. |
| 21 | +/// On AVX2: 2× vmulpd + 2× haddpd. Scalar: 8 multiplies + sum. |
| 22 | +#[inline] |
| 23 | +pub fn heel_weighted_distance(distances: &[f64; 8], weights: &[f64; 8]) -> f64 { |
| 24 | + let vd = F64x8::from_slice(distances); |
| 25 | + let vw = F64x8::from_slice(weights); |
| 26 | + (vd * vw).reduce_sum() |
| 27 | +} |
| 28 | + |
/// Per-plane Hamming distance between two sets of 8 HEEL planes.
///
/// Entry `i` of the result is `popcount(a[i] XOR b[i])` converted to `f64`,
/// i.e. the number of bit positions in which plane `i` of `a` and `b` differ
/// (always in `0.0..=64.0`).
///
/// Hamming distance is appropriate here because HEEL planes are plain binary
/// data (in contrast to bgz17 `i16` values, which need an L1 metric).
pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] {
    let mut out = [0.0f64; 8];
    for (slot, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
        // `count_ones` yields a u32 (at most 64), so the widening is lossless.
        *slot = f64::from((lhs ^ rhs).count_ones());
    }
    out
}
| 41 | + |
| 42 | +/// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar. |
| 43 | +#[inline] |
| 44 | +pub fn heel_weighted_hamming( |
| 45 | + a_planes: &[u64; 8], |
| 46 | + b_planes: &[u64; 8], |
| 47 | + weights: &[f64; 8], |
| 48 | +) -> f64 { |
| 49 | + let dists = heel_plane_distances(a_planes, b_planes); |
| 50 | + heel_weighted_distance(&dists, weights) |
| 51 | +} |
| 52 | + |
/// Weight vector that treats all eight planes equally (every entry is `1.0`).
pub const UNIFORM_WEIGHTS: [f64; 8] = [1.0_f64; 8];
| 55 | + |
/// "7 + 1" weighting: the seven constructive planes (indices 0–6) count
/// fully, while the contradiction plane (index 7) contributes at half weight.
pub const HEEL_7PLUS1_WEIGHTS: [f64; 8] = [
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // planes 0..=6: full weight
    0.5, // plane 7 (contradiction): half weight
];
| 59 | + |
| 60 | +// ═══════════════════════════════════════════════════════════════════════════ |
| 61 | +// SIMD cosine similarity via F64x8 — for CLAM cosine clustering |
| 62 | +// ═══════════════════════════════════════════════════════════════════════════ |
| 63 | + |
| 64 | +/// SIMD dot product on f64 slices via F64x8. |
| 65 | +/// |
| 66 | +/// Processes 8 elements per iteration. Remainder handled scalar. |
| 67 | +/// Used by cosine_simd as the inner kernel. |
| 68 | +pub fn dot_f64_simd(a: &[f64], b: &[f64]) -> f64 { |
| 69 | + let n = a.len().min(b.len()); |
| 70 | + let chunks = n / 8; |
| 71 | + let remainder = n % 8; |
| 72 | + |
| 73 | + let mut acc = F64x8::splat(0.0); |
| 74 | + for i in 0..chunks { |
| 75 | + let va = F64x8::from_slice(&a[i * 8..]); |
| 76 | + let vb = F64x8::from_slice(&b[i * 8..]); |
| 77 | + acc = va.mul_add(vb, acc); // acc = va * vb + acc (FMA) |
| 78 | + } |
| 79 | + let mut sum = acc.reduce_sum(); |
| 80 | + |
| 81 | + // Scalar remainder |
| 82 | + let offset = chunks * 8; |
| 83 | + for i in 0..remainder { |
| 84 | + sum += a[offset + i] * b[offset + i]; |
| 85 | + } |
| 86 | + sum |
| 87 | +} |
| 88 | + |
| 89 | +/// SIMD sum of squares via F64x8. |
| 90 | +pub fn sum_sq_f64_simd(a: &[f64]) -> f64 { |
| 91 | + let n = a.len(); |
| 92 | + let chunks = n / 8; |
| 93 | + let remainder = n % 8; |
| 94 | + |
| 95 | + let mut acc = F64x8::splat(0.0); |
| 96 | + for i in 0..chunks { |
| 97 | + let va = F64x8::from_slice(&a[i * 8..]); |
| 98 | + acc = va.mul_add(va, acc); // acc = va * va + acc |
| 99 | + } |
| 100 | + let mut sum = acc.reduce_sum(); |
| 101 | + |
| 102 | + let offset = chunks * 8; |
| 103 | + for i in 0..remainder { |
| 104 | + sum += a[offset + i] * a[offset + i]; |
| 105 | + } |
| 106 | + sum |
| 107 | +} |
| 108 | + |
| 109 | +/// SIMD cosine similarity on f64 slices. |
| 110 | +/// |
| 111 | +/// Computes dot(a,b) / (||a|| × ||b||) using F64x8 FMA. |
| 112 | +/// Single pass: accumulates dot, norm_a, norm_b simultaneously. |
| 113 | +pub fn cosine_f64_simd(a: &[f64], b: &[f64]) -> f64 { |
| 114 | + let n = a.len().min(b.len()); |
| 115 | + let chunks = n / 8; |
| 116 | + let remainder = n % 8; |
| 117 | + |
| 118 | + let mut dot_acc = F64x8::splat(0.0); |
| 119 | + let mut na_acc = F64x8::splat(0.0); |
| 120 | + let mut nb_acc = F64x8::splat(0.0); |
| 121 | + |
| 122 | + for i in 0..chunks { |
| 123 | + let va = F64x8::from_slice(&a[i * 8..]); |
| 124 | + let vb = F64x8::from_slice(&b[i * 8..]); |
| 125 | + dot_acc = va.mul_add(vb, dot_acc); // dot += a*b |
| 126 | + na_acc = va.mul_add(va, na_acc); // na += a*a |
| 127 | + nb_acc = vb.mul_add(vb, nb_acc); // nb += b*b |
| 128 | + } |
| 129 | + |
| 130 | + let mut dot = dot_acc.reduce_sum(); |
| 131 | + let mut na = na_acc.reduce_sum(); |
| 132 | + let mut nb = nb_acc.reduce_sum(); |
| 133 | + |
| 134 | + let offset = chunks * 8; |
| 135 | + for i in 0..remainder { |
| 136 | + dot += a[offset + i] * b[offset + i]; |
| 137 | + na += a[offset + i] * a[offset + i]; |
| 138 | + nb += b[offset + i] * b[offset + i]; |
| 139 | + } |
| 140 | + |
| 141 | + let denom = (na * nb).sqrt(); |
| 142 | + if denom < 1e-12 { 0.0 } else { dot / denom } |
| 143 | +} |
| 144 | + |
| 145 | +/// SIMD cosine similarity on f32 slices (converts to f64 internally for precision). |
| 146 | +/// |
| 147 | +/// For hot paths where input is f32 but you need f64 precision cosine. |
| 148 | +/// Converts 8 f32 → 8 f64 per chunk via scalar widening, then F64x8 FMA. |
| 149 | +pub fn cosine_f32_to_f64_simd(a: &[f32], b: &[f32]) -> f64 { |
| 150 | + let n = a.len().min(b.len()); |
| 151 | + let chunks = n / 8; |
| 152 | + let remainder = n % 8; |
| 153 | + |
| 154 | + let mut dot_acc = F64x8::splat(0.0); |
| 155 | + let mut na_acc = F64x8::splat(0.0); |
| 156 | + let mut nb_acc = F64x8::splat(0.0); |
| 157 | + |
| 158 | + let mut buf_a = [0.0f64; 8]; |
| 159 | + let mut buf_b = [0.0f64; 8]; |
| 160 | + |
| 161 | + for i in 0..chunks { |
| 162 | + let off = i * 8; |
| 163 | + for j in 0..8 { |
| 164 | + buf_a[j] = a[off + j] as f64; |
| 165 | + buf_b[j] = b[off + j] as f64; |
| 166 | + } |
| 167 | + let va = F64x8::from_slice(&buf_a); |
| 168 | + let vb = F64x8::from_slice(&buf_b); |
| 169 | + dot_acc = va.mul_add(vb, dot_acc); |
| 170 | + na_acc = va.mul_add(va, na_acc); |
| 171 | + nb_acc = vb.mul_add(vb, nb_acc); |
| 172 | + } |
| 173 | + |
| 174 | + let mut dot = dot_acc.reduce_sum(); |
| 175 | + let mut na = na_acc.reduce_sum(); |
| 176 | + let mut nb = nb_acc.reduce_sum(); |
| 177 | + |
| 178 | + let offset = chunks * 8; |
| 179 | + for i in 0..remainder { |
| 180 | + let ai = a[offset + i] as f64; |
| 181 | + let bi = b[offset + i] as f64; |
| 182 | + dot += ai * bi; |
| 183 | + na += ai * ai; |
| 184 | + nb += bi * bi; |
| 185 | + } |
| 186 | + |
| 187 | + let denom = (na * nb).sqrt(); |
| 188 | + if denom < 1e-12 { 0.0 } else { dot / denom } |
| 189 | +} |
| 190 | + |
#[cfg(test)]
mod tests {
    use super::*;

    // ── HEEL distance tests ─────────────────────────────────────────

    /// Unit weights reduce the weighted distance to a plain sum.
    #[test]
    fn heel_dot_basic() {
        let ramp = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let ones = [1.0; 8];
        let result = heel_weighted_distance(&ramp, &ones);
        assert!((result - 36.0).abs() < 1e-10, "1+2+...+8 = 36, got {}", result);
    }

    /// Zero weights mask planes out; only planes 0 and 7 contribute.
    #[test]
    fn heel_dot_weighted() {
        let result = heel_weighted_distance(
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0],
            &[2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5],
        );
        assert!((result - 60.0).abs() < 1e-10, "10×2 + 80×0.5 = 60, got {}", result);
    }

    /// Comparing planes with themselves yields zero on every plane.
    #[test]
    fn plane_distances_self_zero() {
        let planes = [0x1234u64; 8];
        heel_plane_distances(&planes, &planes)
            .iter()
            .for_each(|d| assert_eq!(*d, 0.0));
    }

    /// All-zero versus all-one planes differ in all 64 bits of each plane.
    #[test]
    fn plane_distances_opposite() {
        let zeros = [0u64; 8];
        let ones = [u64::MAX; 8];
        heel_plane_distances(&zeros, &ones)
            .iter()
            .for_each(|d| assert_eq!(*d, 64.0));
    }

    /// Complementary bit patterns with uniform weights: 8 planes × 64 bits.
    #[test]
    fn full_pipeline_uniform() {
        let pattern = [0xFFFF_0000_FFFF_0000u64; 8];
        let inverse = [0x0000_FFFF_0000_FFFFu64; 8];
        let d = heel_weighted_hamming(&pattern, &inverse, &UNIFORM_WEIGHTS);
        assert!((d - 512.0).abs() < 1e-10, "8×64 = 512, got {}", d);
    }

    /// The contradiction plane counts at half weight in the 7+1 scheme.
    #[test]
    fn seven_plus_one_weights() {
        let zeros = [0u64; 8];
        let ones = [u64::MAX; 8];
        let d = heel_weighted_hamming(&zeros, &ones, &HEEL_7PLUS1_WEIGHTS);
        assert!((d - 480.0).abs() < 1e-10, "7×64 + 0.5×64 = 480, got {}", d);
    }

    // ── SIMD cosine tests ───────────────────────────────────────────

    /// A vector compared with itself has cosine 1.
    #[test]
    fn cosine_identical() {
        let wave: Vec<f64> = (0..1024).map(|i| (i as f64 * 0.01).sin()).collect();
        let c = cosine_f64_simd(&wave, &wave);
        assert!((c - 1.0).abs() < 1e-10, "self-cosine should be 1.0: {}", c);
    }

    /// A vector compared with its negation has cosine -1.
    #[test]
    fn cosine_opposite() {
        let ramp: Vec<f64> = (0..256).map(|i| i as f64 * 0.1).collect();
        let negated: Vec<f64> = ramp.iter().map(|v| -v).collect();
        let c = cosine_f64_simd(&ramp, &negated);
        assert!((c - (-1.0)).abs() < 1e-10, "opposite should be -1.0: {}", c);
    }

    /// Two distinct unit basis vectors have cosine 0.
    #[test]
    fn cosine_orthogonal() {
        let mut e0 = vec![0.0f64; 256];
        let mut e1 = vec![0.0f64; 256];
        e0[0] = 1.0;
        e1[1] = 1.0;
        let c = cosine_f64_simd(&e0, &e1);
        assert!(c.abs() < 1e-10, "orthogonal should be 0.0: {}", c);
    }

    /// SIMD result agrees with a straightforward scalar computation on a
    /// length (333) that is not a multiple of eight.
    #[test]
    fn cosine_matches_scalar() {
        let xs: Vec<f64> = (0..333).map(|i| (i as f64 * 0.037).sin()).collect();
        let ys: Vec<f64> = (0..333).map(|i| (i as f64 * 0.023).cos()).collect();

        let simd_cos = cosine_f64_simd(&xs, &ys);

        // Scalar reference
        let ref_dot: f64 = xs.iter().zip(&ys).map(|(x, y)| x * y).sum();
        let ref_na: f64 = xs.iter().map(|x| x * x).sum();
        let ref_nb: f64 = ys.iter().map(|y| y * y).sum();
        let scalar_cos = ref_dot / (ref_na * ref_nb).sqrt();

        assert!(
            (simd_cos - scalar_cos).abs() < 1e-10,
            "SIMD {:.12} vs scalar {:.12}",
            simd_cos,
            scalar_cos
        );
    }

    /// The f32 entry point agrees with the f64 one on pre-widened inputs.
    #[test]
    fn cosine_f32_matches_f64() {
        let a_f32: Vec<f32> = (0..500).map(|i| (i as f32 * 0.01).sin()).collect();
        let b_f32: Vec<f32> = (0..500).map(|i| (i as f32 * 0.02).cos()).collect();

        let widen = |src: &[f32]| -> Vec<f64> { src.iter().map(|&v| v as f64).collect() };
        let a_f64 = widen(&a_f32);
        let b_f64 = widen(&b_f32);

        let cos_f64 = cosine_f64_simd(&a_f64, &b_f64);
        let cos_f32 = cosine_f32_to_f64_simd(&a_f32, &b_f32);

        assert!(
            (cos_f64 - cos_f32).abs() < 1e-6,
            "f32 {:.10} vs f64 {:.10}",
            cos_f32,
            cos_f64
        );
    }

    /// 24 elements (three full groups of eight), each product equal to 2.
    #[test]
    fn dot_f64_simd_basic() {
        let ones = [1.0f64; 24];
        let twos = [2.0f64; 24];
        let d = dot_f64_simd(&ones, &twos);
        assert!((d - 48.0).abs() < 1e-10, "24×2 = 48, got {}", d);
    }
}
0 commit comments