Skip to content

Commit 2c5cc3a

Browse files
committed
feat(hpc): SIMD-accelerate cam_pq + deepnsm hot paths via crate::simd
All consumer code uses crate::simd only. Zero raw intrinsics. LazyLock dispatch table selects AVX-512 vs AVX2 at startup.

cam_pq.rs — squared_l2():
- Called 1,536× per CAM-PQ query (6 subspaces × 256 centroids)
- Was: scalar iter().zip().map().sum()
- Now: F32x16 for 16D subvectors (one SIMD lane = one subspace dimension)
- Fast path: n==16 → single load-subtract-multiply-reduce
- Medium path: n>=16 → chunked F32x16 with mul_add + scalar remainder
- Estimated 16× speedup on hot path

deepnsm.rs — nsm_decompose() normalization:
- Was: scalar iter().sum() + scalar /= loop
- Now: F32x16 accumulation (4×16=64 elements) + scalar remainder (10)
- Normalize via F32x16 * splat(1/sum) + scalar tail

deepnsm.rs — nsm_to_fingerprint() XOR:
- Was: scalar for j in 0..1250 { result[j] ^= pattern[j] }
- Now: U8x64 XOR (19×64=1216 bytes) + scalar remainder (34 bytes)
- 64 bytes per SIMD operation vs 1 byte scalar

deepnsm.rs — nsm_similarity() cosine:
- Was: scalar 3-accumulator loop over 74 elements
- Now: F32x16 with mul_add for dot/mag_a/mag_b (4×16=64) + scalar tail (10)
- Three reductions in one pass

23 deepnsm tests + 7 dispatch tests passing. Zero regressions.

https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent e21fbe1 commit 2c5cc3a

2 files changed

Lines changed: 105 additions & 15 deletions

File tree

src/hpc/cam_pq.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,46 @@ pub fn train_hybrid(
456456

457457
// === Internal utilities ===
458458

459-
/// Squared L2 distance between two slices.
459+
/// Squared L2 distance between two slices via `crate::simd`.
460+
///
461+
/// For 16D subvectors (CAM-PQ subspace dimension), this is one F32x16
462+
/// load-subtract-multiply-reduce. Consumer never sees hardware details.
460463
#[inline(always)]
461464
fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
465+
debug_assert_eq!(a.len(), b.len());
466+
let n = a.len();
467+
468+
// Fast path: exactly 16 elements = one F32x16 vector (most common in CAM-PQ).
469+
if n == 16 {
470+
use crate::simd::F32x16;
471+
let va = F32x16::from_slice(a);
472+
let vb = F32x16::from_slice(b);
473+
let diff = va - vb;
474+
return (diff * diff).reduce_sum();
475+
}
476+
477+
// Medium path (n == 16 already returned above): process 16 elements at a time, then add the remainder with scalar code.
478+
if n >= 16 {
479+
use crate::simd::F32x16;
480+
let mut acc = F32x16::splat(0.0);
481+
let chunks = n / 16;
482+
for i in 0..chunks {
483+
let off = i * 16;
484+
let va = F32x16::from_slice(&a[off..off + 16]);
485+
let vb = F32x16::from_slice(&b[off..off + 16]);
486+
let diff = va - vb;
487+
acc = diff.mul_add(diff, acc);
488+
}
489+
let mut sum = acc.reduce_sum();
490+
// Scalar remainder
491+
for i in (chunks * 16)..n {
492+
let d = a[i] - b[i];
493+
sum += d * d;
494+
}
495+
return sum;
496+
}
497+
498+
// Scalar fallback for tiny slices.
462499
a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
463500
}
464501

src/hpc/deepnsm.rs

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -609,16 +609,34 @@ pub fn nsm_decompose(text: &str) -> NsmDecomposition {
609609
}
610610
}
611611

612-
// Normalize weights to sum to 1.0
613-
let sum: f32 = weights.iter().sum();
614-
if sum > 0.0 {
615-
for w in weights.iter_mut() {
616-
*w /= sum;
612+
// Normalize weights to sum to 1.0 via crate::simd.
613+
// 74 elements = 4×F32x16 (64) + 10 scalar remainder.
614+
let weight_sum = {
615+
use crate::simd::F32x16;
616+
let mut simd_sum = F32x16::splat(0.0);
617+
for chunk in weights[..64].chunks_exact(16) {
618+
simd_sum = simd_sum + F32x16::from_slice(chunk);
617619
}
618-
}
620+
let mut s: f32 = simd_sum.reduce_sum();
621+
for &w in &weights[64..74] {
622+
s += w;
623+
}
624+
if s > 0.0 {
625+
let inv = F32x16::splat(1.0 / s);
626+
for chunk in weights[..64].chunks_exact_mut(16) {
627+
let v = F32x16::from_slice(chunk) * inv;
628+
v.copy_to_slice(chunk);
629+
}
630+
let inv_s = 1.0 / s;
631+
for w in weights[64..74].iter_mut() {
632+
*w *= inv_s;
633+
}
634+
}
635+
s
636+
};
619637

620638
// Determine dominant primes (weight > threshold)
621-
let threshold = if sum > 0.0 { 1.0 / 74.0 } else { 0.0 };
639+
let threshold = if weight_sum > 0.0 { 1.0 / 74.0 } else { 0.0 };
622640
let dominant: Vec<NsmPrime> = ALL_PRIMES
623641
.iter()
624642
.filter(|p| weights[**p as u8 as usize] > threshold)
@@ -656,24 +674,59 @@ pub fn nsm_to_fingerprint(decomp: &NsmDecomposition) -> [u8; 1250] {
656674
let mut reader = hasher.finalize_xof();
657675
reader.fill(&mut pattern);
658676

659-
for j in 0..1250 {
660-
result[j] ^= pattern[j];
677+
// XOR 1250 bytes via crate::simd::U8x64.
678+
// 1250 = 19×64 (1216) + 34 scalar remainder.
679+
{
680+
use crate::simd::U8x64;
681+
let chunks = 1250 / 64; // 19
682+
for c in 0..chunks {
683+
let off = c * 64;
684+
let vr = U8x64::from_slice(&result[off..off + 64]);
685+
let vp = U8x64::from_slice(&pattern[off..off + 64]);
686+
let xored = vr ^ vp;
687+
xored.copy_to_slice(&mut result[off..off + 64]);
688+
}
689+
// Scalar remainder (34 bytes).
690+
for j in (chunks * 64)..1250 {
691+
result[j] ^= pattern[j];
692+
}
661693
}
662694
}
663695

664696
result
665697
}
666698

667-
/// Cosine similarity between two NSM decompositions.
699+
/// Cosine similarity between two NSM decompositions via `crate::simd`.
700+
///
701+
/// 74 elements = 4×F32x16 (64) + 10 scalar remainder.
702+
/// Three accumulations in one pass: dot product, magnitude_a², magnitude_b².
668703
pub fn nsm_similarity(a: &NsmDecomposition, b: &NsmDecomposition) -> f32 {
669-
let mut dot = 0.0f32;
670-
let mut mag_a = 0.0f32;
671-
let mut mag_b = 0.0f32;
672-
for i in 0..74 {
704+
use crate::simd::F32x16;
705+
706+
let mut sdot = F32x16::splat(0.0);
707+
let mut smag_a = F32x16::splat(0.0);
708+
let mut smag_b = F32x16::splat(0.0);
709+
710+
// SIMD: first 64 elements (4 × 16 lanes).
711+
for i in (0..64).step_by(16) {
712+
let va = F32x16::from_slice(&a.weights[i..i + 16]);
713+
let vb = F32x16::from_slice(&b.weights[i..i + 16]);
714+
sdot = va.mul_add(vb, sdot);
715+
smag_a = va.mul_add(va, smag_a);
716+
smag_b = vb.mul_add(vb, smag_b);
717+
}
718+
719+
let mut dot = sdot.reduce_sum();
720+
let mut mag_a = smag_a.reduce_sum();
721+
let mut mag_b = smag_b.reduce_sum();
722+
723+
// Scalar: remaining 10 elements (indices 64..74).
724+
for i in 64..74 {
673725
dot += a.weights[i] * b.weights[i];
674726
mag_a += a.weights[i] * a.weights[i];
675727
mag_b += b.weights[i] * b.weights[i];
676728
}
729+
677730
let denom = (mag_a * mag_b).sqrt();
678731
if denom < 1e-10 {
679732
0.0

0 commit comments

Comments
 (0)