@@ -1000,6 +1000,104 @@ impl BenchmarkTranscript {
10001000 }
10011001}
10021002
1003+ // ============================================================================
1004+ // simd_apply — Generic fused SIMD kernel over aligned f32 slices
1005+ // ============================================================================
1006+
1007+ use crate :: simd:: F32x16 ;
1008+
1009+ /// Apply a generic SIMD operation element-wise over two aligned f32 slices.
1010+ ///
1011+ /// Processes 16 elements per iteration using `F32x16`, with a scalar tail.
1012+ /// This is the generic fusion primitive: callers pass any `Fn(F32x16, F32x16) -> F32x16`.
1013+ ///
1014+ /// # Examples
1015+ ///
1016+ /// ```ignore
1017+ /// use ndarray::hpc::kernels::simd_apply;
1018+ /// let a = vec![1.0f32; 64];
1019+ /// let b = vec![2.0f32; 64];
1020+ /// let mut out = vec![0.0f32; 64];
1021+ /// // Fused multiply-add: a * b + a
1022+ /// simd_apply(&a, &b, &mut out, |va, vb| va.mul_add(vb, va));
1023+ /// ```
1024+ #[ inline]
1025+ pub fn simd_apply < F > ( a : & [ f32 ] , b : & [ f32 ] , out : & mut [ f32 ] , f : F )
1026+ where
1027+ F : Fn ( F32x16 , F32x16 ) -> F32x16 ,
1028+ {
1029+ let n = a. len ( ) . min ( b. len ( ) ) . min ( out. len ( ) ) ;
1030+ let mut i = 0 ;
1031+ while i + 16 <= n {
1032+ let va = F32x16 :: from_slice ( & a[ i..] ) ;
1033+ let vb = F32x16 :: from_slice ( & b[ i..] ) ;
1034+ f ( va, vb) . copy_to_slice ( & mut out[ i..] ) ;
1035+ i += 16 ;
1036+ }
1037+ // Scalar tail: extract one lane at a time
1038+ if i < n {
1039+ let tail_len = n - i;
1040+ let mut a_pad = [ 0.0f32 ; 16 ] ;
1041+ let mut b_pad = [ 0.0f32 ; 16 ] ;
1042+ a_pad[ ..tail_len] . copy_from_slice ( & a[ i..n] ) ;
1043+ b_pad[ ..tail_len] . copy_from_slice ( & b[ i..n] ) ;
1044+ let result = f ( F32x16 :: from_array ( a_pad) , F32x16 :: from_array ( b_pad) ) ;
1045+ let arr = result. to_array ( ) ;
1046+ out[ i..n] . copy_from_slice ( & arr[ ..tail_len] ) ;
1047+ }
1048+ }
1049+
1050+ /// Apply a generic unary SIMD operation element-wise over an f32 slice.
1051+ ///
1052+ /// Single-input variant of [`simd_apply`].
1053+ #[ inline]
1054+ pub fn simd_apply_unary < F > ( x : & [ f32 ] , out : & mut [ f32 ] , f : F )
1055+ where
1056+ F : Fn ( F32x16 ) -> F32x16 ,
1057+ {
1058+ let n = x. len ( ) . min ( out. len ( ) ) ;
1059+ let mut i = 0 ;
1060+ while i + 16 <= n {
1061+ let v = F32x16 :: from_slice ( & x[ i..] ) ;
1062+ f ( v) . copy_to_slice ( & mut out[ i..] ) ;
1063+ i += 16 ;
1064+ }
1065+ if i < n {
1066+ let tail_len = n - i;
1067+ let mut pad = [ 0.0f32 ; 16 ] ;
1068+ pad[ ..tail_len] . copy_from_slice ( & x[ i..n] ) ;
1069+ let result = f ( F32x16 :: from_array ( pad) ) ;
1070+ let arr = result. to_array ( ) ;
1071+ out[ i..n] . copy_from_slice ( & arr[ ..tail_len] ) ;
1072+ }
1073+ }
1074+
1075+ /// Apply a generic SIMD operation in-place: `a[i] = f(a[i], b[i])`.
1076+ #[ inline]
1077+ pub fn simd_apply_inplace < F > ( a : & mut [ f32 ] , b : & [ f32 ] , f : F )
1078+ where
1079+ F : Fn ( F32x16 , F32x16 ) -> F32x16 ,
1080+ {
1081+ let n = a. len ( ) . min ( b. len ( ) ) ;
1082+ let mut i = 0 ;
1083+ while i + 16 <= n {
1084+ let va = F32x16 :: from_slice ( & a[ i..] ) ;
1085+ let vb = F32x16 :: from_slice ( & b[ i..] ) ;
1086+ f ( va, vb) . copy_to_slice ( & mut a[ i..] ) ;
1087+ i += 16 ;
1088+ }
1089+ if i < n {
1090+ let tail_len = n - i;
1091+ let mut a_pad = [ 0.0f32 ; 16 ] ;
1092+ let mut b_pad = [ 0.0f32 ; 16 ] ;
1093+ a_pad[ ..tail_len] . copy_from_slice ( & a[ i..n] ) ;
1094+ b_pad[ ..tail_len] . copy_from_slice ( & b[ i..n] ) ;
1095+ let result = f ( F32x16 :: from_array ( a_pad) , F32x16 :: from_array ( b_pad) ) ;
1096+ let arr = result. to_array ( ) ;
1097+ a[ i..n] . copy_from_slice ( & arr[ ..tail_len] ) ;
1098+ }
1099+ }
1100+
10031101// ============================================================================
10041102// Tests
10051103// ============================================================================
@@ -1586,4 +1684,68 @@ mod tests {
15861684 assert_eq ! ( exact. sigma. level, SignificanceLevel :: Discovery ) ;
15871685 assert ! ( exact. sigma. sigma > 100.0 ) ;
15881686 }
1687+
1688+ // --- simd_apply tests ---
1689+
/// Element-wise addition over 100 elements: six full 16-lane groups plus a
/// 4-element tail, each output checked exactly.
#[test]
fn test_simd_apply_add() {
    const LEN: usize = 100;

    let lhs: Vec<f32> = (0..LEN).map(|k| k as f32).collect();
    let rhs: Vec<f32> = (0..LEN).map(|k| (k * 2) as f32).collect();
    let mut out = vec![0.0f32; LEN];

    simd_apply(&lhs, &rhs, &mut out, |x, y| x + y);

    for (i, &got) in out.iter().enumerate() {
        assert_eq!(got, (i + i * 2) as f32, "mismatch at {i}");
    }
}
1700+
/// Fused multiply-add through `simd_apply` on a length (35) that is not a
/// multiple of 16, so both the full-vector loop and the tail path run.
#[test]
fn test_simd_apply_fma() {
    const LEN: usize = 35;

    let twos = vec![2.0f32; LEN];
    let threes = vec![3.0f32; LEN];
    let mut out = vec![0.0f32; LEN];

    // Per lane: a * b + a = 2 * 3 + 2 = 8.
    simd_apply(&twos, &threes, &mut out, |va, vb| va.mul_add(vb, va));

    for (i, &got) in out.iter().enumerate() {
        assert!((got - 8.0).abs() < 1e-5, "mismatch at {i}: {}", got);
    }
}
1712+
/// Square roots of the perfect squares 1², 2², …, 50² must come back as
/// 1, 2, …, 50 (within a small tolerance).
#[test]
fn test_simd_apply_unary_sqrt() {
    let squares: Vec<f32> = (1..=50).map(|k| (k * k) as f32).collect();
    let mut roots = vec![0.0f32; squares.len()];

    simd_apply_unary(&squares, &mut roots, |v| v.sqrt());

    for (i, &got) in roots.iter().enumerate() {
        let want = (i + 1) as f32;
        assert!((got - want).abs() < 1e-4, "mismatch at {i}");
    }
}
1722+
/// In-place addition of 1.0 to the ramp 0..48 (exactly three full vectors).
#[test]
fn test_simd_apply_inplace() {
    const LEN: usize = 48;

    let mut ramp: Vec<f32> = (0..LEN).map(|k| k as f32).collect();
    let ones = vec![1.0f32; LEN];

    simd_apply_inplace(&mut ramp, &ones, |x, y| x + y);

    for (i, &got) in ramp.iter().enumerate() {
        assert_eq!(got, (i + 1) as f32);
    }
}
1732+
/// Empty inputs: the call must complete and leave the output empty.
#[test]
fn test_simd_apply_empty() {
    let lhs = Vec::<f32>::new();
    let rhs = Vec::<f32>::new();
    let mut out = Vec::<f32>::new();

    simd_apply(&lhs, &rhs, &mut out, |x, y| x + y);

    assert!(out.is_empty());
}
1741+
/// Three elements only: no full 16-lane group exists, so the whole result
/// comes from the tail path.
#[test]
fn test_simd_apply_small_tail_only() {
    let lhs = [1.0f32, 2.0, 3.0];
    let rhs = [4.0f32, 5.0, 6.0];
    let mut out = vec![0.0f32; lhs.len()];

    simd_apply(&lhs, &rhs, &mut out, |x, y| x * y);

    assert_eq!(out, [4.0, 10.0, 18.0]);
}
15891751}
0 commit comments