|
//! F64x8 HEEL distance: 8 f64 distances across 8 HEEL planes in one SIMD pass.
//!
//! p64 has 8 HEEL planes (one u64 each). For weighted f64 distance computation,
//! each plane produces one f64 distance value, so 8 values fill one F64x8 register.
//!
//! This module provides the SIMD kernel; p64-bridge calls it.
//! ndarray supplies the hardware acceleration; consumers just call the kernel.
//!
//! Dispatch: AVX-512 (native __m512d) → AVX2 (2×__m256d) → scalar.
//! A `LazyLock` selects the kernel once, on first use.
| 11 | +
|
| 12 | +use std::sync::LazyLock; |
| 13 | + |
/// Signature shared by every HEEL dot-product kernel.
///
/// The first argument holds 8 f64 distances (one per HEEL plane); the second
/// holds the 8 matching f64 weights (per-expert importance). The kernel returns
/// the weighted sum Σ(distance[i] × weight[i]).
///
/// The pointer type is `unsafe` because the SIMD kernels are compiled with
/// `#[target_feature]` and may only be called on CPUs that support that feature.
type HeelF64x8DotFn = unsafe fn(&[f64; 8], &[f64; 8]) -> f64;
| 19 | + |
| 20 | +#[cfg(target_arch = "x86_64")] |
| 21 | +#[target_feature(enable = "avx512f")] |
| 22 | +unsafe fn heel_dot_avx512(a: &[f64; 8], b: &[f64; 8]) -> f64 { |
| 23 | + use std::arch::x86_64::*; |
| 24 | + let va = _mm512_loadu_pd(a.as_ptr()); |
| 25 | + let vb = _mm512_loadu_pd(b.as_ptr()); |
| 26 | + let prod = _mm512_mul_pd(va, vb); |
| 27 | + _mm512_reduce_add_pd(prod) |
| 28 | +} |
| 29 | + |
/// AVX2 kernel: Σ(a[i] × b[i]) over 8 lanes, processed as two 256-bit halves.
///
/// The two 4-lane product vectors are added lane-wise first, and the resulting
/// 4 partial sums are then folded horizontally (4 → 2 → 1).
///
/// # Safety
///
/// The caller must ensure the running CPU supports `avx2`
/// (for example via `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn heel_dot_avx2(a: &[f64; 8], b: &[f64; 8]) -> f64 {
    use std::arch::x86_64::{
        _mm256_add_pd, _mm256_castpd256_pd128, _mm256_extractf128_pd, _mm256_loadu_pd,
        _mm256_mul_pd, _mm_add_pd, _mm_add_sd, _mm_cvtsd_f64, _mm_unpackhi_pd,
    };

    // Each 8-lane input is handled as a lower and an upper group of 4 lanes.
    let (a_lo, a_hi) = a.split_at(4);
    let (b_lo, b_hi) = b.split_at(4);

    // SAFETY: every half is exactly 4 contiguous f64 values (32 bytes), which is
    // what one unaligned 256-bit load reads; `avx2` availability is the caller's
    // obligation per this function's contract.
    unsafe {
        let prod_lo = _mm256_mul_pd(_mm256_loadu_pd(a_lo.as_ptr()), _mm256_loadu_pd(b_lo.as_ptr()));
        let prod_hi = _mm256_mul_pd(_mm256_loadu_pd(a_hi.as_ptr()), _mm256_loadu_pd(b_hi.as_ptr()));

        // Lane-wise sum of the two halves: 8 products → 4 partial sums.
        let quad = _mm256_add_pd(prod_lo, prod_hi);

        // Horizontal reduction: add the upper 128 bits onto the lower 128 bits (4 → 2) ...
        let upper = _mm256_extractf128_pd::<1>(quad);
        let lower = _mm256_castpd256_pd128(quad);
        let pair = _mm_add_pd(lower, upper);

        // ... then add the high element of the pair onto the low element (2 → 1).
        let high = _mm_unpackhi_pd(pair, pair);
        _mm_cvtsd_f64(_mm_add_sd(pair, high))
    }
}
| 52 | + |
/// Portable fallback kernel: Σ(a[i] × b[i]) over the 8 lanes.
///
/// Products are accumulated strictly left to right starting from `0.0`, so the
/// result is fully deterministic on every target.
fn heel_dot_scalar(a: &[f64; 8], b: &[f64; 8]) -> f64 {
    // `fold` with an explicit 0.0 seed (rather than `sum()`) keeps the exact
    // accumulation order and sign-of-zero behaviour of a plain `+=` loop.
    a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
}
| 60 | + |
| 61 | +static HEEL_DOT_KERNEL: LazyLock<HeelF64x8DotFn> = LazyLock::new(|| { |
| 62 | + #[cfg(target_arch = "x86_64")] |
| 63 | + { |
| 64 | + if is_x86_feature_detected!("avx512f") { |
| 65 | + return heel_dot_avx512 as HeelF64x8DotFn; |
| 66 | + } |
| 67 | + if is_x86_feature_detected!("avx2") { |
| 68 | + return heel_dot_avx2 as HeelF64x8DotFn; |
| 69 | + } |
| 70 | + } |
| 71 | + heel_dot_scalar as HeelF64x8DotFn |
| 72 | +}); |
| 73 | + |
| 74 | +/// Compute weighted dot product of 8 HEEL plane distances. |
| 75 | +/// |
| 76 | +/// `distances[i]` = distance for HEEL plane i. |
| 77 | +/// `weights[i]` = importance weight for plane i. |
| 78 | +/// Returns: Σ(distances[i] × weights[i]). |
| 79 | +/// |
| 80 | +/// One SIMD pass on AVX-512 (single `vmulpd` + `vreducepd`). |
| 81 | +/// Two passes on AVX2. Scalar fallback for non-x86. |
| 82 | +#[inline] |
| 83 | +pub fn heel_weighted_distance(distances: &[f64; 8], weights: &[f64; 8]) -> f64 { |
| 84 | + unsafe { HEEL_DOT_KERNEL(distances, weights) } |
| 85 | +} |
| 86 | + |
/// Compute the per-plane Hamming distances between two sets of 8 HEEL planes.
///
/// For each plane i: distance[i] = popcount(a[i] XOR b[i]) as f64, i.e. a value
/// between 0.0 and 64.0.
///
/// No weighting is applied here. The result is the binary Hamming distance of
/// each plane converted to f64 so it can be fed to [`heel_weighted_distance`],
/// where each plane's contribution is scaled by expert importance.
pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] {
    // `count_ones` yields a u32 (at most 64), which converts to f64 losslessly.
    std::array::from_fn(|i| f64::from((a[i] ^ b[i]).count_ones()))
}
| 101 | + |
| 102 | +/// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar distance. |
| 103 | +#[inline] |
| 104 | +pub fn heel_weighted_hamming( |
| 105 | + a_planes: &[u64; 8], |
| 106 | + b_planes: &[u64; 8], |
| 107 | + weights: &[f64; 8], |
| 108 | +) -> f64 { |
| 109 | + let dists = heel_plane_distances(a_planes, b_planes); |
| 110 | + heel_weighted_distance(&dists, weights) |
| 111 | +} |
| 112 | + |
/// Uniform weights: every one of the 8 planes contributes with weight 1.0.
pub const UNIFORM_WEIGHTS: [f64; 8] = [1.0; 8];
| 115 | + |
/// HEEL-weighted (7 constructive + 1 contradiction at reduced weight).
///
/// Planes 0..=6 keep weight 1.0; the contradiction plane (index 7) gets 0.5×.
pub const HEEL_7PLUS1_WEIGHTS: [f64; 8] = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5];
| 119 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing against hand-computed sums.
    const EPS: f64 = 1e-10;

    #[test]
    fn dot_product_basic() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = heel_weighted_distance(&a, &b);
        assert!((result - 36.0).abs() < EPS, "1+2+3+4+5+6+7+8 = 36, got {}", result);
    }

    #[test]
    fn dot_product_weighted() {
        let distances = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        let weights = [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5];
        let result = heel_weighted_distance(&distances, &weights);
        assert!((result - 60.0).abs() < EPS, "10×2 + 80×0.5 = 60, got {}", result);
    }

    #[test]
    fn plane_distances_self_zero() {
        let planes = [0x1234u64; 8];
        let dists = heel_plane_distances(&planes, &planes);
        for d in &dists {
            assert_eq!(*d, 0.0);
        }
    }

    #[test]
    fn plane_distances_opposite() {
        let a = [0u64; 8];
        let b = [u64::MAX; 8];
        let dists = heel_plane_distances(&a, &b);
        for d in &dists {
            assert_eq!(*d, 64.0);
        }
    }

    #[test]
    fn full_pipeline_uniform() {
        let a = [0xFFFF_0000_FFFF_0000u64; 8];
        let b = [0x0000_FFFF_0000_FFFFu64; 8];
        let d = heel_weighted_hamming(&a, &b, &UNIFORM_WEIGHTS);
        // Each plane: all bits differ = 64
        assert!((d - 64.0 * 8.0).abs() < EPS, "8 planes × 64 bits = 512, got {}", d);
    }

    #[test]
    fn seven_plus_one_weights() {
        let a = [0u64; 8];
        let b = [u64::MAX; 8];
        let d_uniform = heel_weighted_hamming(&a, &b, &UNIFORM_WEIGHTS);
        let d_7plus1 = heel_weighted_hamming(&a, &b, &HEEL_7PLUS1_WEIGHTS);
        // 7+1: plane 7 at 0.5× = 7×64 + 0.5×64 = 480 vs 512
        assert!((d_uniform - 512.0).abs() < EPS);
        assert!((d_7plus1 - 480.0).abs() < EPS, "7×64 + 0.5×64 = 480, got {}", d_7plus1);
    }

    /// The dispatcher picks one kernel per machine; this exercises the scalar
    /// fallback directly so it is covered even on SIMD-capable hosts.
    #[test]
    fn scalar_kernel_agrees_with_dispatch() {
        let d = [1.5, 2.0, 3.25, 4.0, 5.5, 6.0, 7.75, 8.0];
        let scalar = heel_dot_scalar(&d, &HEEL_7PLUS1_WEIGHTS);
        let dispatched = heel_weighted_distance(&d, &HEEL_7PLUS1_WEIGHTS);
        // 1.5+2+3.25+4+5.5+6+7.75 = 30, plus 8×0.5 = 4.
        assert!((scalar - 34.0).abs() < EPS, "expected 34, got {}", scalar);
        assert!((scalar - dispatched).abs() < EPS, "scalar {} vs dispatched {}", scalar, dispatched);
    }

    /// Runs every SIMD kernel the host CPU supports against the scalar result,
    /// so the AVX2 path is still covered on AVX-512 machines.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn simd_kernels_agree_with_scalar() {
        let d = [1.5, 2.0, 3.25, 4.0, 5.5, 6.0, 7.75, 8.0];
        let w = HEEL_7PLUS1_WEIGHTS;
        let expected = heel_dot_scalar(&d, &w);

        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified at runtime just above.
            let got = unsafe { heel_dot_avx2(&d, &w) };
            assert!((got - expected).abs() < EPS, "avx2 {} vs scalar {}", got, expected);
        }
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: AVX-512F support was verified at runtime just above.
            let got = unsafe { heel_dot_avx512(&d, &w) };
            assert!((got - expected).abs() < EPS, "avx512 {} vs scalar {}", got, expected);
        }
    }
}
0 commit comments