@@ -30,6 +30,8 @@ struct KernelSse2;
/// Marker type handed to the selector for the NEON kernel. Only compiled
/// for `aarch64` targets when the `has_aarch64_simd` cfg is set.
#[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
struct KernelNeon;
/// Marker type handed to the selector for the WebAssembly SIMD kernel. Only
/// compiled for `wasm32` targets built with the `simd128` target feature.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
struct KernelWasmSimd;
/// Marker type for the kernel that is selected when no other kernel was
/// returned first.
struct KernelFallback;

/// Element type of the kernels in this file.
type T = f32;
@@ -62,6 +64,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
6264 return selector. select ( KernelNeon ) ;
6365 }
6466 }
67+ #[ cfg( all( target_arch="wasm32" , target_feature="simd128" ) ) ]
68+ {
69+ return selector. select ( KernelWasmSimd ) ;
70+ }
71+ #[ allow( unreachable_code) ]
6572 return selector. select ( KernelFallback ) ;
6673}
6774
@@ -279,6 +286,38 @@ impl GemmKernel for KernelFallback {
279286 }
280287}
281288
/// `GemmKernel` implementation backed by the wasm32 `simd128` microkernel.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
impl GemmKernel for KernelWasmSimd {
    type Elem = T;

    // Tile shape of the microkernel (rows x columns).
    type MRTy = U8;
    type NRTy = U8;

    /// Returns 16, which is the size in bytes of one `v128`.
    // NOTE(review): presumably the byte alignment requested for the packing
    // buffers -- confirm against the `GemmKernel` trait documentation.
    #[inline(always)]
    fn align_to() -> usize {
        16
    }

    /// This kernel does not ask to always be run in masked mode.
    #[inline(always)]
    fn always_masked() -> bool {
        false
    }

    // Blocking parameters, taken unchanged from `archparam`.

    #[inline(always)]
    fn nc() -> usize {
        archparam::S_NC
    }

    #[inline(always)]
    fn kc() -> usize {
        archparam::S_KC
    }

    #[inline(always)]
    fn mc() -> usize {
        archparam::S_MC
    }

    /// Forwards every argument, in order, to [`kernel_target_wasm_simd`].
    ///
    /// # Safety
    ///
    /// The caller must uphold the pointer and stride requirements of
    /// [`kernel_target_wasm_simd`].
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_target_wasm_simd(k, alpha, a, b, beta, c, rsc, csc)
    }
}

320+
282321// no inline for unmasked kernels
283322#[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
284323#[ target_feature( enable="fma" ) ]
@@ -692,6 +731,132 @@ unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
692731 }
693732}
694733
/// Single-precision 8x8 GEMM microkernel for wasm32 with `simd128`.
///
/// Computes `C <- alpha * A * B + beta * C` for one 8x8 tile. The tile is
/// held in 16 `v128` accumulators arranged as four 4x4 quadrants. Each of the
/// `k` steps loads 8 floats from `a` and 8 floats from `b`, accumulates their
/// outer product, and then advances `a` by `MR` and `b` by `NR` elements.
///
/// When `beta == 0` the tile of C is never read, only overwritten.
///
/// # Safety
///
/// * `a` and `b` must be readable for 8 `f32` values at each of the `k`
///   steps (the pointers advance by `MR` / `NR` elements per step).
/// * `c.offset(rsc * i + csc * j)` must be writable for all `i, j` in
///   `0..8`, and also readable when `beta != 0`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
unsafe fn kernel_target_wasm_simd(
    k: usize,
    alpha: T,
    a: *const T,
    b: *const T,
    beta: T,
    c: *mut T,
    rsc: isize,
    csc: isize,
) {
    use core::arch::wasm32::*;
    const MR: usize = KernelWasmSimd::MR;
    const NR: usize = KernelWasmSimd::NR;

    // With a unit row stride the operands and the two strides trade places,
    // i.e. the transposed tile is computed. After the swap it is `csc` that
    // equals 1, so the contiguous load/store paths below are taken.
    let (mut a_ptr, mut b_ptr, rsc, csc) = match rsc {
        1 => (b, a, csc, rsc),
        _ => (a, b, rsc, csc),
    };

    // Address of tile element (row, col) of C under the effective strides.
    macro_rules! c_at {
        ($row:expr, $col:expr) => {
            c.offset(rsc * $row as isize + csc * $col as isize)
        };
    }

    // acc[i] += bv * splat(av[i]) for i in 0..4. There is no multiply-by-lane
    // instruction in use here, so each lane of `av` is extracted and
    // broadcast before the multiply and the add.
    macro_rules! outer_product_acc {
        ($acc:ident, $av:expr, $bv:expr) => {{
            let lane0 = f32x4_splat(f32x4_extract_lane::<0>($av));
            $acc[0] = f32x4_add($acc[0], f32x4_mul($bv, lane0));
            let lane1 = f32x4_splat(f32x4_extract_lane::<1>($av));
            $acc[1] = f32x4_add($acc[1], f32x4_mul($bv, lane1));
            let lane2 = f32x4_splat(f32x4_extract_lane::<2>($av));
            $acc[2] = f32x4_add($acc[2], f32x4_mul($bv, lane2));
            let lane3 = f32x4_splat(f32x4_extract_lane::<3>($av));
            $acc[3] = f32x4_add($acc[3], f32x4_mul($bv, lane3));
        }};
    }

    // Assemble one v128 from four independently addressed 32-bit values.
    macro_rules! gather4 {
        ($p0:expr, $p1:expr, $p2:expr, $p3:expr) => {{
            let lanes = f32x4_splat(0.);
            let lanes = v128_load32_lane::<0>(lanes, $p0 as *const u32);
            let lanes = v128_load32_lane::<1>(lanes, $p1 as *const u32);
            let lanes = v128_load32_lane::<2>(lanes, $p2 as *const u32);
            let lanes = v128_load32_lane::<3>(lanes, $p3 as *const u32);
            lanes
        }};
    }

    // acc[r] *= s for the four rows of a quadrant.
    macro_rules! scale_rows {
        ($acc:ident, $s:expr) => {
            loop4!(r, $acc[r] = f32x4_mul($acc[r], $s))
        };
    }

    // acc[r] += x[r] * s for the four rows of a quadrant.
    macro_rules! axpy_rows {
        ($acc:ident, $x:ident, $s:expr) => {
            loop4!(r, $acc[r] = f32x4_add($acc[r], f32x4_mul($x[r], $s)))
        };
    }

    // Read the quadrant of C whose top-left element is (r0, c0), one vector
    // load per row.
    macro_rules! load_rows_contiguous {
        ($dst:ident, $r0:expr, $c0:expr) => {
            loop4!(r, $dst[r] = v128_load(c_at!(r + $r0, $c0) as *const v128))
        };
    }

    // Same quadrant read, but element by element through the column stride.
    macro_rules! load_rows_strided {
        ($dst:ident, $r0:expr, $c0:expr) => {
            loop4!(r, $dst[r] = gather4!(
                c_at!(r + $r0, $c0 + 0),
                c_at!(r + $r0, $c0 + 1),
                c_at!(r + $r0, $c0 + 2),
                c_at!(r + $r0, $c0 + 3)
            ))
        };
    }

    // Write a quadrant back, one vector store per row.
    macro_rules! store_rows_contiguous {
        ($src:ident, $r0:expr, $c0:expr) => {
            loop4!(r, v128_store(c_at!(r + $r0, $c0) as *mut v128, $src[r]))
        };
    }

    // Write a quadrant back lane by lane: all four rows of column c0 first,
    // then column c0 + 1, and so on.
    macro_rules! store_rows_strided {
        ($src:ident, $r0:expr, $c0:expr) => {{
            loop4!(r, v128_store32_lane::<0>($src[r], c_at!(r + $r0, $c0 + 0) as *mut u32));
            loop4!(r, v128_store32_lane::<1>($src[r], c_at!(r + $r0, $c0 + 1) as *mut u32));
            loop4!(r, v128_store32_lane::<2>($src[r], c_at!(r + $r0, $c0 + 2) as *mut u32));
            loop4!(r, v128_store32_lane::<3>($src[r], c_at!(r + $r0, $c0 + 3) as *mut u32));
        }};
    }

    // Accumulators: accRC covers rows 4R..4R+4 and columns 4C..4C+4.
    let zeros = f32x4_splat(0.);
    let mut acc00 = [zeros; 4];
    let mut acc01 = [zeros; 4];
    let mut acc10 = [zeros; 4];
    let mut acc11 = [zeros; 4];

    for _ in 0..k {
        // Lower and upper four floats of the current slice of each panel.
        let a_lo = v128_load(a_ptr.cast::<v128>());
        let a_hi = v128_load(a_ptr.add(4).cast::<v128>());
        let b_lo = v128_load(b_ptr.cast::<v128>());
        let b_hi = v128_load(b_ptr.add(4).cast::<v128>());

        outer_product_acc!(acc00, a_lo, b_lo);
        outer_product_acc!(acc01, a_lo, b_hi);
        outer_product_acc!(acc10, a_hi, b_lo);
        outer_product_acc!(acc11, a_hi, b_hi);

        a_ptr = a_ptr.add(MR);
        b_ptr = b_ptr.add(NR);
    }

    // acc *= alpha
    let alphav = f32x4_splat(alpha);
    scale_rows!(acc00, alphav);
    scale_rows!(acc01, alphav);
    scale_rows!(acc10, alphav);
    scale_rows!(acc11, alphav);

    // acc += beta * C. Skipped entirely for beta == 0, so C is not read then.
    if beta != 0. {
        let mut c00 = [zeros; 4];
        let mut c01 = [zeros; 4];
        let mut c10 = [zeros; 4];
        let mut c11 = [zeros; 4];

        if csc == 1 {
            load_rows_contiguous!(c00, 0, 0);
            load_rows_contiguous!(c01, 0, 4);
            load_rows_contiguous!(c10, 4, 0);
            load_rows_contiguous!(c11, 4, 4);
        } else {
            load_rows_strided!(c00, 0, 0);
            load_rows_strided!(c01, 0, 4);
            load_rows_strided!(c10, 4, 0);
            load_rows_strided!(c11, 4, 4);
        }

        let betav = f32x4_splat(beta);
        axpy_rows!(acc00, c00, betav);
        axpy_rows!(acc01, c01, betav);
        axpy_rows!(acc10, c10, betav);
        axpy_rows!(acc11, c11, betav);
    }

    // C <- alpha A B (+ beta C)
    if csc == 1 {
        store_rows_contiguous!(acc00, 0, 0);
        store_rows_contiguous!(acc01, 0, 4);
        store_rows_contiguous!(acc10, 4, 0);
        store_rows_contiguous!(acc11, 4, 4);
    } else {
        store_rows_strided!(acc00, 0, 0);
        store_rows_strided!(acc01, 0, 4);
        store_rows_strided!(acc10, 4, 0);
        store_rows_strided!(acc11, 4, 4);
    }
}

859+
695860#[ inline]
696861unsafe fn kernel_fallback_impl ( k : usize , alpha : T , a : * const T , b : * const T ,
697862 beta : T , c : * mut T , rsc : isize , csc : isize )
@@ -775,6 +940,17 @@ mod tests {
775940 }
776941 }
777942
// Tests for the wasm32 `simd128` kernel; built only when that kernel is.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
mod test_kernel_wasm {
    use super::super::*;
    use super::test_a_kernel;

    /// Runs the shared kernel test driver against `KernelWasmSimd`.
    #[test]
    fn wasm_simd_8x8() {
        test_a_kernel::<KernelWasmSimd, _>("wasm_simd_8x8");
    }
}
953+
778954 #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
779955 mod test_kernel_x86 {
780956 use super :: test_a_kernel;
0 commit comments