Skip to content

Commit 5277b29

Browse files
committed
feat: heel_f64x8 — F64x8 SIMD kernel for 8 HEEL plane weighted distance
8 HEEL planes × 1 distance each = 8 f64 = one F64x8 register. LazyLock dispatch: AVX-512 (vmulpd + horizontal reduce-add) → AVX2 (2×vmulpd) → scalar. Functions: heel_weighted_distance(&[f64;8], &[f64;8]) → f64 (weighted dot); heel_plane_distances(&[u64;8], &[u64;8]) → [f64;8] (Hamming per plane); heel_weighted_hamming(a, b, weights) → f64 (full pipeline). Predefined weights: UNIFORM_WEIGHTS = [1.0; 8]; HEEL_7PLUS1_WEIGHTS = [1,1,1,1,1,1,1, 0.5] (contradiction at half). 6 tests passing. https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp
1 parent 226d7ce commit 5277b29

2 files changed

Lines changed: 180 additions & 0 deletions

File tree

src/hpc/heel_f64x8.rs

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//! F64x8 HEEL distance: 8 f64 distances across 8 HEEL planes in one SIMD pass.
2+
//!
3+
//! p64 has 8 HEEL planes (u64 each). For weighted f64 distance computation,
4+
//! each plane produces one f64 distance value → 8 values = one F64x8 register.
5+
//!
6+
//! This module provides the SIMD kernel; p64-bridge calls it.
7+
//! ndarray = hardware acceleration, consumers use the kernel.
8+
//!
9+
//! Dispatch: AVX-512 (native __m512d) → AVX2 (2×__m256d) → scalar.
10+
//! `LazyLock` selects the kernel once, on first use, and caches the choice.
11+
12+
use std::sync::LazyLock;
13+
14+
/// Function-pointer type shared by every HEEL dot-product kernel.
///
/// * First argument: 8 distances, one `f64` per HEEL plane.
/// * Second argument: 8 weights, one per plane.
/// * Returns the weighted sum `Σ(distance[i] × weight[i])`.
///
/// # Safety
///
/// The pointer is `unsafe` because the SIMD kernels in this module are
/// compiled with `#[target_feature]`. A kernel must only be called on a CPU
/// that supports the instruction set it was built for.
type HeelF64x8DotFn = unsafe fn(&[f64; 8], &[f64; 8]) -> f64;
19+
20+
#[cfg(target_arch = "x86_64")]
21+
#[target_feature(enable = "avx512f")]
22+
unsafe fn heel_dot_avx512(a: &[f64; 8], b: &[f64; 8]) -> f64 {
23+
use std::arch::x86_64::*;
24+
let va = _mm512_loadu_pd(a.as_ptr());
25+
let vb = _mm512_loadu_pd(b.as_ptr());
26+
let prod = _mm512_mul_pd(va, vb);
27+
_mm512_reduce_add_pd(prod)
28+
}
29+
30+
/// AVX2 kernel: processes the 8 lanes as two 256-bit halves (lanes 0..4 and
/// 4..8), adds the two product vectors lane-wise, then sums the 4 remaining
/// lanes horizontally.
///
/// # Safety
///
/// The caller must ensure the running CPU supports `avx2`
/// (for example via `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn heel_dot_avx2(a: &[f64; 8], b: &[f64; 8]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: `a` and `b` each reference exactly 8 contiguous `f64` values.
    // The loads at offset 0 and offset 4 each read 4 values, so all four
    // unaligned 256-bit loads stay in bounds. The caller guarantees `avx2`.
    unsafe {
        // Lower half: lanes 0..4.
        let va0 = _mm256_loadu_pd(a.as_ptr());
        let vb0 = _mm256_loadu_pd(b.as_ptr());
        let p0 = _mm256_mul_pd(va0, vb0);

        // Upper half: lanes 4..8.
        let va1 = _mm256_loadu_pd(a[4..].as_ptr());
        let vb1 = _mm256_loadu_pd(b[4..].as_ptr());
        let p1 = _mm256_mul_pd(va1, vb1);

        // Lane-wise add leaves 4 partial sums: p[i] + p[i + 4].
        let sum = _mm256_add_pd(p0, p1);

        // Horizontal sum of the 4 partial sums: fold 256 → 128 → 64 bits.
        let hi = _mm256_extractf128_pd(sum, 1);
        let lo = _mm256_castpd256_pd128(sum);
        let pair = _mm_add_pd(lo, hi);
        let hi64 = _mm_unpackhi_pd(pair, pair);
        let result = _mm_add_sd(pair, hi64);
        _mm_cvtsd_f64(result)
    }
}
52+
53+
/// Portable reference kernel: sums `a[i] * b[i]` for lanes 0 through 7 in
/// index order.
///
/// The accumulator starts at `0.0` (hence `fold` rather than `.sum()`), so
/// the products are added left to right from positive zero.
fn heel_dot_scalar(a: &[f64; 8], b: &[f64; 8]) -> f64 {
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |acc, (x, w)| acc + x * w)
}
60+
61+
static HEEL_DOT_KERNEL: LazyLock<HeelF64x8DotFn> = LazyLock::new(|| {
62+
#[cfg(target_arch = "x86_64")]
63+
{
64+
if is_x86_feature_detected!("avx512f") {
65+
return heel_dot_avx512 as HeelF64x8DotFn;
66+
}
67+
if is_x86_feature_detected!("avx2") {
68+
return heel_dot_avx2 as HeelF64x8DotFn;
69+
}
70+
}
71+
heel_dot_scalar as HeelF64x8DotFn
72+
});
73+
74+
/// Compute weighted dot product of 8 HEEL plane distances.
75+
///
76+
/// `distances[i]` = distance for HEEL plane i.
77+
/// `weights[i]` = importance weight for plane i.
78+
/// Returns: Σ(distances[i] × weights[i]).
79+
///
80+
/// One SIMD pass on AVX-512 (single `vmulpd` + `vreducepd`).
81+
/// Two passes on AVX2. Scalar fallback for non-x86.
82+
#[inline]
83+
pub fn heel_weighted_distance(distances: &[f64; 8], weights: &[f64; 8]) -> f64 {
84+
unsafe { HEEL_DOT_KERNEL(distances, weights) }
85+
}
86+
87+
/// Per-plane Hamming distances between two sets of 8 HEEL planes.
///
/// For each plane i: `distance[i] = popcount(a[i] XOR b[i])` converted to
/// `f64`, so every entry is a whole number in `0.0..=64.0`.
///
/// No weighting happens here. The distances are returned as `f64` so they can
/// be fed straight into `heel_weighted_distance`, where each plane's
/// contribution is scaled by its weight.
pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] {
    // `count_ones` yields a `u32` (at most 64), which converts to `f64` losslessly.
    std::array::from_fn(|i| f64::from((a[i] ^ b[i]).count_ones()))
}
101+
102+
/// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar distance.
103+
#[inline]
104+
pub fn heel_weighted_hamming(
105+
a_planes: &[u64; 8],
106+
b_planes: &[u64; 8],
107+
weights: &[f64; 8],
108+
) -> f64 {
109+
let dists = heel_plane_distances(a_planes, b_planes);
110+
heel_weighted_distance(&dists, weights)
111+
}
112+
113+
/// Weight vector that counts every HEEL plane equally (all entries `1.0`).
pub const UNIFORM_WEIGHTS: [f64; 8] = [1.0_f64; 8];
115+
116+
/// "7 + 1" weighting: planes 0 through 6 (constructive) keep full weight,
/// while the contradiction plane (index 7) is halved to `0.5`.
pub const HEEL_7PLUS1_WEIGHTS: [f64; 8] = {
    let mut weights = [1.0_f64; 8];
    weights[7] = 0.5;
    weights
};
119+
120+
#[cfg(test)]
mod tests {
    use super::*;

    /// Unweighted sum through the dispatched kernel.
    #[test]
    fn dot_product_basic() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = heel_weighted_distance(&a, &b);
        assert!((result - 36.0).abs() < 1e-10, "1+2+3+4+5+6+7+8 = 36, got {}", result);
    }

    /// Only the first and last lanes carry non-zero weight.
    #[test]
    fn dot_product_weighted() {
        let distances = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        let weights = [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5];
        let result = heel_weighted_distance(&distances, &weights);
        assert!((result - 60.0).abs() < 1e-10, "10×2 + 80×0.5 = 60, got {}", result);
    }

    /// Identical planes have zero Hamming distance.
    #[test]
    fn plane_distances_self_zero() {
        let planes = [0x1234u64; 8];
        let dists = heel_plane_distances(&planes, &planes);
        for d in &dists {
            assert_eq!(*d, 0.0);
        }
    }

    /// Complementary planes differ in all 64 bits.
    #[test]
    fn plane_distances_opposite() {
        let a = [0u64; 8];
        let b = [u64::MAX; 8];
        let dists = heel_plane_distances(&a, &b);
        for d in &dists {
            assert_eq!(*d, 64.0);
        }
    }

    #[test]
    fn full_pipeline_uniform() {
        let a = [0xFFFF_0000_FFFF_0000u64; 8];
        let b = [0x0000_FFFF_0000_FFFFu64; 8];
        let d = heel_weighted_hamming(&a, &b, &UNIFORM_WEIGHTS);
        // Each plane: all bits differ = 64
        assert!((d - 64.0 * 8.0).abs() < 1e-10, "8 planes × 64 bits = 512, got {}", d);
    }

    #[test]
    fn seven_plus_one_weights() {
        let a = [0u64; 8];
        let b = [u64::MAX; 8];
        let d_uniform = heel_weighted_hamming(&a, &b, &UNIFORM_WEIGHTS);
        let d_7plus1 = heel_weighted_hamming(&a, &b, &HEEL_7PLUS1_WEIGHTS);
        // 7+1: plane 7 at 0.5× = 7×64 + 0.5×64 = 480 vs 512
        assert!((d_uniform - 512.0).abs() < 1e-10);
        assert!((d_7plus1 - 480.0).abs() < 1e-10, "7×64 + 0.5×64 = 480, got {}", d_7plus1);
    }

    /// The tests above only reach whichever kernel the dispatcher picked on
    /// this machine. This one calls every kernel the CPU can run and checks
    /// it against the scalar reference. All inputs and products are exactly
    /// representable, so the result does not depend on summation order and
    /// exact equality is valid.
    #[test]
    fn kernels_agree_with_scalar() {
        let distances = [64.0, 0.0, 17.0, 3.0, 42.0, 1.0, 8.0, 33.0];
        let weights = HEEL_7PLUS1_WEIGHTS;

        // 64 + 0 + 17 + 3 + 42 + 1 + 8 + 0.5×33 = 151.5
        let expected = heel_dot_scalar(&distances, &weights);
        assert_eq!(expected, 151.5);
        assert_eq!(heel_weighted_distance(&distances, &weights), expected);

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                // SAFETY: avx2 support was just confirmed at runtime.
                let got = unsafe { heel_dot_avx2(&distances, &weights) };
                assert_eq!(got, expected, "AVX2 kernel disagrees with scalar");
            }
            if is_x86_feature_detected!("avx512f") {
                // SAFETY: avx512f support was just confirmed at runtime.
                let got = unsafe { heel_dot_avx512(&distances, &weights) };
                assert_eq!(got, expected, "AVX-512 kernel disagrees with scalar");
            }
        }
    }
}

src/hpc/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ pub mod node;
5252
#[allow(missing_docs)]
5353
pub mod cascade;
5454
#[allow(missing_docs)]
55+
pub mod heel_f64x8;
56+
#[allow(missing_docs)]
5557
pub mod bf16_truth;
5658
#[allow(missing_docs)]
5759
pub mod causality;

0 commit comments

Comments
 (0)