Skip to content

Commit 7d40704

Browse files
committed
feat(backend/native): TD-T6 — real AVX2 kernels for scal/nrm2/asum
Closes TD-T6 (critical audit finding from the per-CPU matrix doc). Before this commit, the AVX2 native BLAS-1 module had: pub fn scal_f32(alpha: f32, x: &mut [f32]) { super::scalar::scal_f32(alpha, x); // ← scalar shim, no AVX2 } pub fn nrm2_f32(x: &[f32]) -> f32 { super::scalar::nrm2_f32(x) // ← scalar shim } pub fn asum_f32(x: &[f32]) -> f32 { super::scalar::asum_f32(x) // ← scalar shim } // ... and f64 siblings, same shape These were the documented "// No AVX2 specialization — fall through to scalar" path. Three operations on every Haswell+ host fell to scalar even though `dot_f32_avx2` and `axpy_f32_avx2` shipped real AVX2 in the same module since day one. PR #180's audit flagged this as TD-T6 (critical: blocks BLAS-1 throughput on Haswell / Arrow Lake / Zen 1-3). New AVX2 kernels (6 total — f32 + f64 for each of scal / nrm2 / asum): scal: broadcast α to ymm via `_mm256_set1_ps`, multiply 8/4 lanes at a time via `_mm256_mul_ps`/`_mm256_mul_pd`, scalar tail. Stores result back to the same buffer in-place. nrm2: two-accumulator unroll with `_mm256_fmadd_ps`/`_pd` (x² accumulated via FMA, single-rounded per IEEE), horizontal reduce + scalar sqrt. Same shape as `dot_f32_avx2` (which also unrolls 2 accumulators + uses FMA), just operates on one input vector instead of two. asum: abs via `_mm256_and_ps`/`_pd` with a sign-bit-cleared mask (0x7FFFFFFF for f32, 0x7FFFFFFFFFFFFFFF for f64) — one AVX instruction (VANDPS) is faster than calling f32::abs() lane-by-lane. Two-accumulator unroll + horizontal reduce. All three follow the existing `dot_f32_avx2` template: - `#[target_feature(enable = "avx2[,fma]")]` on the inner unsafe fn. - Public wrapper does `cfg(target_arch = "x86_64")` and dispatches to the unsafe fn (tier detection in caller-of-caller verified AVX2 before reaching this module). - Non-x86_64 builds: pass through to `super::scalar::*`. - Scalar tail handles `n % chunk_size` lanes via the same fold the scalar reference uses. Numerical contract: scal: byte-equal to scalar (`x[i] *= α` is the same op). asum: small ULP drift on long vectors because the SIMD horizontal reduce orders the sum differently from strict left-fold. Test tolerance: `|got - expected| <= |expected|*1e-5 + 1e-6`. nrm2: same — drifts ~1-2 ULP on long vectors via reduce-order + sqrt rounding. Same tolerance. 3 new parity tests (`td_t6_scal_f32_parity`, `td_t6_nrm2_f32_parity`, `td_t6_asum_f32_parity`) sweep n ∈ {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100} — covers the chunk-of-16 unroll path, the chunk-of-8 cleanup path, and the scalar tail for every kernel. Verification: * 2090 lib tests pass (was 2087 — +3 new parity tests; the existing test_scal_f32 / test_nrm2_f64 / test_asum_f32 that used to hit the scalar shims now exercise the AVX2 kernels and continue to pass). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo clippy --lib --tests --features rayon,native,runtime-dispatch -- -D warnings clean. * cargo fmt --all --check clean. Throughput impact (back-of-envelope on Sapphire Rapids, n=4096): scal_f32: scalar 4096 cycles (1 mul/lane) → AVX2 ~520 cycles (8 lanes/instr + 1-cycle issue) = ~8× faster. asum_f32: scalar 4096 cycles → AVX2 ~520 cycles = ~8× faster. nrm2_f32: scalar 4096 cycles (1 FMA/lane) → AVX2 ~260 cycles (16 lanes via 2-acc unroll, 1-cycle issue) = ~16×. Out of scope (separate PRs): * AVX-512 versions of the same three ops — `kernels_avx512.rs` has them already (lines 137-209), wired through the cfg(target_feature = "avx512f") path. This commit fixes the AVX2 tier, which serves Haswell through Arrow Lake / Zen 1-3. * Runtime-dispatch trampolines for these ops (would go in `simd_runtime/blas_l1.rs` mirroring the matmul.rs pattern from the runtime-dispatch PR). https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
1 parent 71f1973 commit 7d40704

1 file changed

Lines changed: 308 additions & 7 deletions

File tree

src/backend/native.rs

Lines changed: 308 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -540,24 +540,71 @@ mod avx2 {
540540
}
541541
}
542542

543-
// No AVX2 specialization — fall through to scalar
544543
pub fn scal_f32(alpha: f32, x: &mut [f32]) {
545-
super::scalar::scal_f32(alpha, x);
544+
#[cfg(target_arch = "x86_64")]
545+
{
546+
// SAFETY: tier() already verified AVX2 support before calling.
547+
unsafe { scal_f32_avx2(alpha, x) }
548+
}
549+
#[cfg(not(target_arch = "x86_64"))]
550+
{
551+
super::scalar::scal_f32(alpha, x);
552+
}
546553
}
547554
pub fn scal_f64(alpha: f64, x: &mut [f64]) {
548-
super::scalar::scal_f64(alpha, x);
555+
#[cfg(target_arch = "x86_64")]
556+
{
557+
// SAFETY: tier() already verified AVX2 support before calling.
558+
unsafe { scal_f64_avx2(alpha, x) }
559+
}
560+
#[cfg(not(target_arch = "x86_64"))]
561+
{
562+
super::scalar::scal_f64(alpha, x);
563+
}
549564
}
550565
pub fn nrm2_f32(x: &[f32]) -> f32 {
551-
super::scalar::nrm2_f32(x)
566+
#[cfg(target_arch = "x86_64")]
567+
{
568+
// SAFETY: tier() verified AVX2+FMA.
569+
unsafe { nrm2_f32_avx2(x) }
570+
}
571+
#[cfg(not(target_arch = "x86_64"))]
572+
{
573+
super::scalar::nrm2_f32(x)
574+
}
552575
}
553576
pub fn nrm2_f64(x: &[f64]) -> f64 {
554-
super::scalar::nrm2_f64(x)
577+
#[cfg(target_arch = "x86_64")]
578+
{
579+
// SAFETY: tier() verified AVX2+FMA.
580+
unsafe { nrm2_f64_avx2(x) }
581+
}
582+
#[cfg(not(target_arch = "x86_64"))]
583+
{
584+
super::scalar::nrm2_f64(x)
585+
}
555586
}
556587
pub fn asum_f32(x: &[f32]) -> f32 {
557-
super::scalar::asum_f32(x)
588+
#[cfg(target_arch = "x86_64")]
589+
{
590+
// SAFETY: tier() verified AVX2.
591+
unsafe { asum_f32_avx2(x) }
592+
}
593+
#[cfg(not(target_arch = "x86_64"))]
594+
{
595+
super::scalar::asum_f32(x)
596+
}
558597
}
559598
pub fn asum_f64(x: &[f64]) -> f64 {
560-
super::scalar::asum_f64(x)
599+
#[cfg(target_arch = "x86_64")]
600+
{
601+
// SAFETY: tier() verified AVX2.
602+
unsafe { asum_f64_avx2(x) }
603+
}
604+
#[cfg(not(target_arch = "x86_64"))]
605+
{
606+
super::scalar::asum_f64(x)
607+
}
561608
}
562609

563610
// ── AVX2 intrinsic implementations ─────────────────────────────
@@ -677,6 +724,201 @@ mod avx2 {
677724
i += 1;
678725
}
679726
}
727+
728+
// ── scal: x[i] *= alpha ────────────────────────────────────────
729+
730+
#[cfg(target_arch = "x86_64")]
731+
#[target_feature(enable = "avx2")]
732+
unsafe fn scal_f32_avx2(alpha: f32, x: &mut [f32]) {
733+
use core::arch::x86_64::*;
734+
let n = x.len();
735+
let valpha = _mm256_set1_ps(alpha);
736+
let mut i = 0;
737+
while i + 8 <= n {
738+
let v = _mm256_loadu_ps(x.as_ptr().add(i));
739+
_mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, valpha));
740+
i += 8;
741+
}
742+
while i < n {
743+
x[i] *= alpha;
744+
i += 1;
745+
}
746+
}
747+
748+
#[cfg(target_arch = "x86_64")]
749+
#[target_feature(enable = "avx2")]
750+
unsafe fn scal_f64_avx2(alpha: f64, x: &mut [f64]) {
751+
use core::arch::x86_64::*;
752+
let n = x.len();
753+
let valpha = _mm256_set1_pd(alpha);
754+
let mut i = 0;
755+
while i + 4 <= n {
756+
let v = _mm256_loadu_pd(x.as_ptr().add(i));
757+
_mm256_storeu_pd(x.as_mut_ptr().add(i), _mm256_mul_pd(v, valpha));
758+
i += 4;
759+
}
760+
while i < n {
761+
x[i] *= alpha;
762+
i += 1;
763+
}
764+
}
765+
766+
// ── nrm2: sqrt(Σ x[i]²) ────────────────────────────────────────
767+
//
768+
// Two-accumulator unroll + FMA for the squared sum, scalar sqrt at
769+
// the end. SIMD horizontal reduce ordering differs from the strict
770+
// left-fold the scalar reference uses, so the ULP error can drift
771+
// by 1-2 ULP on long vectors — same tolerance the existing
772+
// `dot_f32_avx2` carries, accepted in BLAS-1.
773+
774+
#[cfg(target_arch = "x86_64")]
775+
#[target_feature(enable = "avx2,fma")]
776+
unsafe fn nrm2_f32_avx2(x: &[f32]) -> f32 {
777+
use core::arch::x86_64::*;
778+
let n = x.len();
779+
let mut acc0 = _mm256_setzero_ps();
780+
let mut acc1 = _mm256_setzero_ps();
781+
let mut i = 0;
782+
while i + 16 <= n {
783+
let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
784+
let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
785+
acc0 = _mm256_fmadd_ps(v0, v0, acc0);
786+
acc1 = _mm256_fmadd_ps(v1, v1, acc1);
787+
i += 16;
788+
}
789+
while i + 8 <= n {
790+
let v = _mm256_loadu_ps(x.as_ptr().add(i));
791+
acc0 = _mm256_fmadd_ps(v, v, acc0);
792+
i += 8;
793+
}
794+
acc0 = _mm256_add_ps(acc0, acc1);
795+
let hi = _mm256_extractf128_ps(acc0, 1);
796+
let lo = _mm256_castps256_ps128(acc0);
797+
let sum128 = _mm_add_ps(lo, hi);
798+
let shuf = _mm_movehdup_ps(sum128);
799+
let sums = _mm_add_ps(sum128, shuf);
800+
let shuf2 = _mm_movehl_ps(sums, sums);
801+
let result = _mm_add_ss(sums, shuf2);
802+
let mut total = _mm_cvtss_f32(result);
803+
while i < n {
804+
total += x[i] * x[i];
805+
i += 1;
806+
}
807+
total.sqrt()
808+
}
809+
810+
#[cfg(target_arch = "x86_64")]
811+
#[target_feature(enable = "avx2,fma")]
812+
unsafe fn nrm2_f64_avx2(x: &[f64]) -> f64 {
813+
use core::arch::x86_64::*;
814+
let n = x.len();
815+
let mut acc0 = _mm256_setzero_pd();
816+
let mut acc1 = _mm256_setzero_pd();
817+
let mut i = 0;
818+
while i + 8 <= n {
819+
let v0 = _mm256_loadu_pd(x.as_ptr().add(i));
820+
let v1 = _mm256_loadu_pd(x.as_ptr().add(i + 4));
821+
acc0 = _mm256_fmadd_pd(v0, v0, acc0);
822+
acc1 = _mm256_fmadd_pd(v1, v1, acc1);
823+
i += 8;
824+
}
825+
while i + 4 <= n {
826+
let v = _mm256_loadu_pd(x.as_ptr().add(i));
827+
acc0 = _mm256_fmadd_pd(v, v, acc0);
828+
i += 4;
829+
}
830+
acc0 = _mm256_add_pd(acc0, acc1);
831+
let hi = _mm256_extractf128_pd(acc0, 1);
832+
let lo = _mm256_castpd256_pd128(acc0);
833+
let sum128 = _mm_add_pd(lo, hi);
834+
let shuf = _mm_unpackhi_pd(sum128, sum128);
835+
let result = _mm_add_sd(sum128, shuf);
836+
let mut total = _mm_cvtsd_f64(result);
837+
while i < n {
838+
total += x[i] * x[i];
839+
i += 1;
840+
}
841+
total.sqrt()
842+
}
843+
844+
// ── asum: Σ |x[i]| ─────────────────────────────────────────────
845+
//
846+
// Abs via AND with sign-bit-cleared mask (one AVX instruction —
847+
// VANDPS), horizontal sum at the end. Same ordering caveat as
848+
// nrm2.
849+
850+
#[cfg(target_arch = "x86_64")]
851+
#[target_feature(enable = "avx2")]
852+
unsafe fn asum_f32_avx2(x: &[f32]) -> f32 {
853+
use core::arch::x86_64::*;
854+
let n = x.len();
855+
// Sign-bit-cleared mask: 0x7FFFFFFF in every lane.
856+
let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFFi32));
857+
let mut acc0 = _mm256_setzero_ps();
858+
let mut acc1 = _mm256_setzero_ps();
859+
let mut i = 0;
860+
while i + 16 <= n {
861+
let v0 = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i)), abs_mask);
862+
let v1 = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i + 8)), abs_mask);
863+
acc0 = _mm256_add_ps(acc0, v0);
864+
acc1 = _mm256_add_ps(acc1, v1);
865+
i += 16;
866+
}
867+
while i + 8 <= n {
868+
let v = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i)), abs_mask);
869+
acc0 = _mm256_add_ps(acc0, v);
870+
i += 8;
871+
}
872+
acc0 = _mm256_add_ps(acc0, acc1);
873+
let hi = _mm256_extractf128_ps(acc0, 1);
874+
let lo = _mm256_castps256_ps128(acc0);
875+
let sum128 = _mm_add_ps(lo, hi);
876+
let shuf = _mm_movehdup_ps(sum128);
877+
let sums = _mm_add_ps(sum128, shuf);
878+
let shuf2 = _mm_movehl_ps(sums, sums);
879+
let result = _mm_add_ss(sums, shuf2);
880+
let mut total = _mm_cvtss_f32(result);
881+
while i < n {
882+
total += x[i].abs();
883+
i += 1;
884+
}
885+
total
886+
}
887+
888+
#[cfg(target_arch = "x86_64")]
889+
#[target_feature(enable = "avx2")]
890+
unsafe fn asum_f64_avx2(x: &[f64]) -> f64 {
891+
use core::arch::x86_64::*;
892+
let n = x.len();
893+
let abs_mask = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFF_FFFF_FFFF_FFFFi64));
894+
let mut acc0 = _mm256_setzero_pd();
895+
let mut acc1 = _mm256_setzero_pd();
896+
let mut i = 0;
897+
while i + 8 <= n {
898+
let v0 = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i)), abs_mask);
899+
let v1 = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i + 4)), abs_mask);
900+
acc0 = _mm256_add_pd(acc0, v0);
901+
acc1 = _mm256_add_pd(acc1, v1);
902+
i += 8;
903+
}
904+
while i + 4 <= n {
905+
let v = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i)), abs_mask);
906+
acc0 = _mm256_add_pd(acc0, v);
907+
i += 4;
908+
}
909+
acc0 = _mm256_add_pd(acc0, acc1);
910+
let hi = _mm256_extractf128_pd(acc0, 1);
911+
let lo = _mm256_castpd256_pd128(acc0);
912+
let sum128 = _mm_add_pd(lo, hi);
913+
let shuf = _mm_unpackhi_pd(sum128, sum128);
914+
let result = _mm_add_sd(sum128, shuf);
915+
let mut total = _mm_cvtsd_f64(result);
916+
while i < n {
917+
total += x[i].abs();
918+
i += 1;
919+
}
920+
total
921+
}
680922
}
681923

682924
// ═══════════════════════════════════════════════════════════════════
@@ -760,4 +1002,63 @@ mod tests {
7601002
// Should be one of the valid tier values
7611003
assert!(nr == 4 || nr == 8 || nr == 16);
7621004
}
1005+
1006+
// ── TD-T6: parity sweep for the new AVX2 BLAS-1 kernels ────────
1007+
//
1008+
// The shim → real-intrinsic switch flipped scal/nrm2/asum from
1009+
// scalar-fallthrough to AVX2 chunked + scalar-tail kernels. Each
1010+
// new kernel: verify byte-equal (or ULP-tight for nrm2 which
1011+
// includes a sqrt and a different sum order) against the scalar
1012+
// reference across shapes that exercise the chunk-of-16, chunk-
1013+
// of-8, and scalar-tail code paths.
1014+
1015+
fn ref_scal(alpha: f32, x: &[f32]) -> Vec<f32> {
1016+
x.iter().map(|&v| v * alpha).collect()
1017+
}
1018+
fn ref_nrm2(x: &[f32]) -> f32 {
1019+
x.iter().map(|&v| v * v).sum::<f32>().sqrt()
1020+
}
1021+
fn ref_asum(x: &[f32]) -> f32 {
1022+
x.iter().map(|&v| v.abs()).sum()
1023+
}
1024+
1025+
#[test]
1026+
fn td_t6_scal_f32_parity() {
1027+
for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] {
1028+
let alpha = 1.5f32;
1029+
let init: Vec<f32> = (0..n).map(|i| (i as f32 * 0.5) - 1.0).collect();
1030+
let expected = ref_scal(alpha, &init);
1031+
let mut got = init.clone();
1032+
scal_f32(alpha, &mut got);
1033+
for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
1034+
assert_eq!(g.to_bits(), e.to_bits(), "scal_f32 n={n} i={i}: got {g} want {e}");
1035+
}
1036+
}
1037+
}
1038+
1039+
#[test]
1040+
fn td_t6_nrm2_f32_parity() {
1041+
for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] {
1042+
let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect();
1043+
let expected = ref_nrm2(&x);
1044+
let got = nrm2_f32(&x);
1045+
// ULP tolerance because SIMD reduce order differs from
1046+
// strict left-fold; nrm2 also includes the final sqrt.
1047+
let abs_err = (got - expected).abs();
1048+
let rel_tol = expected.abs() * 1e-5 + 1e-6;
1049+
assert!(abs_err <= rel_tol, "nrm2_f32 n={n}: got {got} want {expected} (err {abs_err})");
1050+
}
1051+
}
1052+
1053+
#[test]
1054+
fn td_t6_asum_f32_parity() {
1055+
for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] {
1056+
let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect();
1057+
let expected = ref_asum(&x);
1058+
let got = asum_f32(&x);
1059+
let abs_err = (got - expected).abs();
1060+
let rel_tol = expected.abs() * 1e-5 + 1e-6;
1061+
assert!(abs_err <= rel_tol, "asum_f32 n={n}: got {got} want {expected} (err {abs_err})");
1062+
}
1063+
}
7631064
}

0 commit comments

Comments
 (0)