Skip to content

Commit be35795

Browse files
committed
fix(simd): aarch64 F32x16/F64x8 use real NEON paired loads, not scalar
Burn parity item 9: F32x16/F64x8 on aarch64 previously dispatched to the scalar fallback in simd::scalar (element-wise [f32; 16] loops). Add a real NEON-backed implementation in simd_neon::aarch64_simd, modeled on the AVX2 polyfill's dual-tuple shape: F32x16 = [float32x4_t; 4] (4x vld1q_f32 / vst1q_f32 / vfmaq_f32 / vaddq_f32 etc. per op) F64x8 = [float64x2_t; 4] (4x vld1q_f64 / vst1q_f64 / vfmaq_f64) Hot-path arithmetic (add, sub, mul, div, mul_add, splat, abs, neg, sqrt, round, floor, simd_min/max, reduce_sum) compiles to one NEON instruction per 128-bit lane pair. Comparisons and bit-cast helpers round-trip through to_array, same shape as simd_avx2. simd.rs: mod scalar -> pub(crate) mod scalar (so simd_neon can pull I32x16/U32x16/U64x8 from there). aarch64 branch pulls F32x16/F64x8 from simd_neon::aarch64_simd; integer + 256-bit float types still come from scalar. Other non-x86 targets (wasm/riscv) keep full scalar fallback. simd_neon.rs: pub mod aarch64_simd (~600 LOC) plus 5 smoke tests gated on cfg(target_arch = "aarch64", test). Build: - cargo build --release --lib -p ndarray (x86_64 AVX-512): PASS - aarch64 cross-compile of just our types compiles cleanly (uses only stable core::arch::aarch64 intrinsics shipped since 1.59); full lib cross-compile blocked in this env by blake3 needing aarch64-linux-gnu-gcc which is not installed.
1 parent 4eca4e0 commit be35795

2 files changed

Lines changed: 661 additions & 2 deletions

File tree

src/simd.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ pub use crate::simd_avx2::{
237237
// ============================================================================
238238

239239
#[cfg(not(target_arch = "x86_64"))]
240-
mod scalar {
240+
pub(crate) mod scalar {
241241
use core::fmt;
242242
use core::ops::{
243243
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign,
@@ -1014,7 +1014,25 @@ mod scalar {
10141014
#[allow(non_camel_case_types)] pub type f64x4 = F64x4;
10151015
}
10161016

1017-
#[cfg(not(target_arch = "x86_64"))]
1017+
// aarch64: F32x16/F64x8 come from the real NEON paired-load implementation
1018+
// in simd_neon::aarch64_simd (verified 2026-04-30, agent A7 — burn parity item 9).
1019+
// Integer + 256-bit float types still come from the scalar fallback; they're
1020+
// not on the critical path for f32 BLAS-1 / VML kernels.
1021+
#[cfg(target_arch = "aarch64")]
1022+
pub use crate::simd_neon::aarch64_simd::{
1023+
F32x16, F64x8, F32Mask16, F64Mask8,
1024+
f32x16, f64x8,
1025+
};
1026+
#[cfg(target_arch = "aarch64")]
1027+
pub use scalar::{
1028+
U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
1029+
F32x8, F64x4,
1030+
u8x64, i32x16, i64x8, u32x16, u64x8,
1031+
f32x8, f64x4,
1032+
};
1033+
1034+
// Other non-x86 targets (wasm, riscv, etc.): full scalar fallback.
1035+
#[cfg(all(not(target_arch = "x86_64"), not(target_arch = "aarch64")))]
10181036
pub use scalar::{
10191037
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
10201038
F32x8, F64x4,

0 commit comments

Comments
 (0)