Skip to content

Commit 7189779

Browse files
committed
fix(simd): aarch64 F32x16/F64x8 use real NEON paired loads, not scalar
Burn parity item 9: F32x16/F64x8 on aarch64 previously dispatched to the scalar fallback in simd::scalar (element-wise [f32; 16] loops). Add a real NEON-backed implementation in simd_neon::aarch64_simd, modeled on the AVX2 polyfill's dual-tuple shape: F32x16 = [float32x4_t; 4] (4x vld1q_f32 / vst1q_f32 / vfmaq_f32 / vaddq_f32 etc. per op) F64x8 = [float64x2_t; 4] (4x vld1q_f64 / vst1q_f64 / vfmaq_f64) Hot-path arithmetic (add, sub, mul, div, mul_add, splat, abs, neg, sqrt, round, floor, simd_min/max, reduce_sum) compiles to one NEON instruction per 128-bit lane pair. Comparisons and bit-cast helpers round-trip through to_array, same shape as simd_avx2. simd.rs: mod scalar -> pub(crate) mod scalar (so simd_neon can pull I32x16/U32x16/U64x8 from there). aarch64 branch pulls F32x16/F64x8 from simd_neon::aarch64_simd; integer + 256-bit float types still come from scalar. Other non-x86 targets (wasm/riscv) keep full scalar fallback. simd_neon.rs: pub mod aarch64_simd (~600 LOC) plus 5 smoke tests gated on cfg(target_arch = "aarch64", test). Build: - cargo build --release --lib -p ndarray (x86_64 AVX-512): PASS - aarch64 cross-compile of just our types compiles cleanly (uses only stable core::arch::aarch64 intrinsics shipped since 1.59); full lib cross-compile blocked in this env by blake3 needing aarch64-linux-gnu-gcc which is not installed.
1 parent 888e598 commit 7189779

2 files changed

Lines changed: 661 additions & 2 deletions

File tree

src/simd.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ pub use crate::simd_avx2::{
162162
// ============================================================================
163163

164164
#[cfg(not(target_arch = "x86_64"))]
165-
mod scalar {
165+
pub(crate) mod scalar {
166166
use core::fmt;
167167
use core::ops::{
168168
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign,
@@ -939,7 +939,25 @@ mod scalar {
939939
#[allow(non_camel_case_types)] pub type f64x4 = F64x4;
940940
}
941941

942-
#[cfg(not(target_arch = "x86_64"))]
942+
// aarch64: F32x16/F64x8 come from the real NEON paired-load implementation
943+
// in simd_neon::aarch64_simd (verified 2026-04-30, agent A7 — burn parity item 9).
944+
// Integer + 256-bit float types still come from the scalar fallback; they're
945+
// not on the critical path for f32 BLAS-1 / VML kernels.
946+
#[cfg(target_arch = "aarch64")]
947+
pub use crate::simd_neon::aarch64_simd::{
948+
F32x16, F64x8, F32Mask16, F64Mask8,
949+
f32x16, f64x8,
950+
};
951+
#[cfg(target_arch = "aarch64")]
952+
pub use scalar::{
953+
U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
954+
F32x8, F64x4,
955+
u8x64, i32x16, i64x8, u32x16, u64x8,
956+
f32x8, f64x4,
957+
};
958+
959+
// Other non-x86 targets (wasm, riscv, etc.): full scalar fallback.
960+
#[cfg(all(not(target_arch = "x86_64"), not(target_arch = "aarch64")))]
943961
pub use scalar::{
944962
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
945963
F32x8, F64x4,

0 commit comments

Comments
 (0)