@@ -30,6 +30,8 @@ struct KernelSse2;
3030#[ cfg( target_arch="aarch64" ) ]
3131#[ cfg( has_aarch64_simd) ]
3232struct KernelNeon ;
33+ #[ cfg( all( target_arch="wasm32" , target_feature="simd128" ) ) ]
34+ struct KernelWasmSimd ;
3335struct KernelFallback ;
3436
3537type T = f32 ;
@@ -62,6 +64,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
6264 return selector. select ( KernelNeon ) ;
6365 }
6466 }
67+ #[ cfg( all( target_arch="wasm32" , target_feature="simd128" ) ) ]
68+ {
69+ return selector. select ( KernelWasmSimd ) ;
70+ }
71+ #[ allow( unreachable_code) ]
6572 return selector. select ( KernelFallback ) ;
6673}
6774
@@ -279,6 +286,38 @@ impl GemmKernel for KernelFallback {
279286 }
280287}
281288
289+ #[ cfg( all( target_arch="wasm32" , target_feature="simd128" ) ) ]
290+ impl GemmKernel for KernelWasmSimd {
291+ type Elem = T ;
292+
293+ type MRTy = U8 ;
294+ type NRTy = U8 ;
295+
296+ #[ inline( always) ]
297+ fn align_to ( ) -> usize { 16 }
298+
299+ #[ inline( always) ]
300+ fn always_masked ( ) -> bool { false }
301+
302+ #[ inline( always) ]
303+ fn nc ( ) -> usize { archparam:: S_NC }
304+ #[ inline( always) ]
305+ fn kc ( ) -> usize { archparam:: S_KC }
306+ #[ inline( always) ]
307+ fn mc ( ) -> usize { archparam:: S_MC }
308+
309+ #[ inline( always) ]
310+ unsafe fn kernel (
311+ k : usize ,
312+ alpha : T ,
313+ a : * const T ,
314+ b : * const T ,
315+ beta : T ,
316+ c : * mut T , rsc : isize , csc : isize ) {
317+ kernel_target_wasm_simd ( k, alpha, a, b, beta, c, rsc, csc)
318+ }
319+ }
320+
282321// no inline for unmasked kernels
283322#[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
284323#[ target_feature( enable="fma" ) ]
@@ -692,6 +731,131 @@ unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
692731 }
693732}
694733
734+ #[ cfg( all( target_arch="wasm32" , target_feature="simd128" ) ) ]
735+ unsafe fn kernel_target_wasm_simd ( k : usize , alpha : T , a : * const T , b : * const T ,
736+ beta : T , c : * mut T , rsc : isize , csc : isize )
737+ {
738+ use core:: arch:: wasm32:: * ;
739+ const MR : usize = KernelWasmSimd :: MR ;
740+ const NR : usize = KernelWasmSimd :: NR ;
741+
742+ let ( mut a, mut b, rsc, csc) = if rsc == 1 { ( b, a, csc, rsc) } else { ( a, b, rsc, csc) } ;
743+
744+ // Kernel 8 x 8 (a x b)
745+ // Four quadrants of 4 x 4
746+ let zero = f32x4_splat ( 0. ) ;
747+ let mut ab11 = [ zero; 4 ] ;
748+ let mut ab12 = [ zero; 4 ] ;
749+ let mut ab21 = [ zero; 4 ] ;
750+ let mut ab22 = [ zero; 4 ] ;
751+
752+ // ab_ij = a_i * b_j for all i, j
753+ // (wasm SIMD has no lane-FMA; extract+splat into mul+add)
754+ macro_rules! ab_ij_equals_ai_bj {
755+ ( $dest: ident, $av: expr, $bv: expr) => {
756+ $dest[ 0 ] = f32x4_add( $dest[ 0 ] , f32x4_mul( $bv, f32x4_splat( f32x4_extract_lane:: <0 >( $av) ) ) ) ;
757+ $dest[ 1 ] = f32x4_add( $dest[ 1 ] , f32x4_mul( $bv, f32x4_splat( f32x4_extract_lane:: <1 >( $av) ) ) ) ;
758+ $dest[ 2 ] = f32x4_add( $dest[ 2 ] , f32x4_mul( $bv, f32x4_splat( f32x4_extract_lane:: <2 >( $av) ) ) ) ;
759+ $dest[ 3 ] = f32x4_add( $dest[ 3 ] , f32x4_mul( $bv, f32x4_splat( f32x4_extract_lane:: <3 >( $av) ) ) ) ;
760+ }
761+ }
762+
763+ for _ in 0 ..k {
764+ let a1 = v128_load ( a as * const v128 ) ;
765+ let b1 = v128_load ( b as * const v128 ) ;
766+ let a2 = v128_load ( a. add ( 4 ) as * const v128 ) ;
767+ let b2 = v128_load ( b. add ( 4 ) as * const v128 ) ;
768+
769+ ab_ij_equals_ai_bj ! ( ab11, a1, b1) ;
770+ ab_ij_equals_ai_bj ! ( ab12, a1, b2) ;
771+ ab_ij_equals_ai_bj ! ( ab21, a2, b1) ;
772+ ab_ij_equals_ai_bj ! ( ab22, a2, b2) ;
773+
774+ a = a. add ( MR ) ;
775+ b = b. add ( NR ) ;
776+ }
777+
778+ macro_rules! c {
779+ ( $i: expr, $j: expr) => ( c. offset( rsc * $i as isize + csc * $j as isize ) ) ;
780+ }
781+
782+ // ab *= alpha
783+ let alphav = f32x4_splat ( alpha) ;
784+ loop4 ! ( i, ab11[ i] = f32x4_mul( ab11[ i] , alphav) ) ;
785+ loop4 ! ( i, ab12[ i] = f32x4_mul( ab12[ i] , alphav) ) ;
786+ loop4 ! ( i, ab21[ i] = f32x4_mul( ab21[ i] , alphav) ) ;
787+ loop4 ! ( i, ab22[ i] = f32x4_mul( ab22[ i] , alphav) ) ;
788+
789+ // load one v128 from four pointers
790+ macro_rules! loadq_from_pointers {
791+ ( $p0: expr, $p1: expr, $p2: expr, $p3: expr) => ( {
792+ let v = f32x4_splat( 0. ) ;
793+ let v = v128_load32_lane:: <0 >( v, $p0 as * const u32 ) ;
794+ let v = v128_load32_lane:: <1 >( v, $p1 as * const u32 ) ;
795+ let v = v128_load32_lane:: <2 >( v, $p2 as * const u32 ) ;
796+ let v = v128_load32_lane:: <3 >( v, $p3 as * const u32 ) ;
797+ v
798+ } ) ;
799+ }
800+
801+ if beta != 0. {
802+ // load existing value in C
803+ let mut c11 = [ zero; 4 ] ;
804+ let mut c12 = [ zero; 4 ] ;
805+ let mut c21 = [ zero; 4 ] ;
806+ let mut c22 = [ zero; 4 ] ;
807+
808+ if csc == 1 {
809+ loop4 ! ( i, c11[ i] = v128_load( c![ i + 0 , 0 ] as * const v128) ) ;
810+ loop4 ! ( i, c12[ i] = v128_load( c![ i + 0 , 4 ] as * const v128) ) ;
811+ loop4 ! ( i, c21[ i] = v128_load( c![ i + 4 , 0 ] as * const v128) ) ;
812+ loop4 ! ( i, c22[ i] = v128_load( c![ i + 4 , 4 ] as * const v128) ) ;
813+ } else {
814+ loop4 ! ( i, c11[ i] = loadq_from_pointers!( c![ i + 0 , 0 ] , c![ i + 0 , 1 ] , c![ i + 0 , 2 ] , c![ i + 0 , 3 ] ) ) ;
815+ loop4 ! ( i, c12[ i] = loadq_from_pointers!( c![ i + 0 , 4 ] , c![ i + 0 , 5 ] , c![ i + 0 , 6 ] , c![ i + 0 , 7 ] ) ) ;
816+ loop4 ! ( i, c21[ i] = loadq_from_pointers!( c![ i + 4 , 0 ] , c![ i + 4 , 1 ] , c![ i + 4 , 2 ] , c![ i + 4 , 3 ] ) ) ;
817+ loop4 ! ( i, c22[ i] = loadq_from_pointers!( c![ i + 4 , 4 ] , c![ i + 4 , 5 ] , c![ i + 4 , 6 ] , c![ i + 4 , 7 ] ) ) ;
818+ }
819+
820+ let betav = f32x4_splat ( beta) ;
821+ // ab += β C
822+ loop4 ! ( i, ab11[ i] = f32x4_add( ab11[ i] , f32x4_mul( c11[ i] , betav) ) ) ;
823+ loop4 ! ( i, ab12[ i] = f32x4_add( ab12[ i] , f32x4_mul( c12[ i] , betav) ) ) ;
824+ loop4 ! ( i, ab21[ i] = f32x4_add( ab21[ i] , f32x4_mul( c21[ i] , betav) ) ) ;
825+ loop4 ! ( i, ab22[ i] = f32x4_add( ab22[ i] , f32x4_mul( c22[ i] , betav) ) ) ;
826+ }
827+
828+ // c <- ab
829+ // which is in full
830+ // C <- α A B (+ β C)
831+ if csc == 1 {
832+ loop4 ! ( i, v128_store( c![ i + 0 , 0 ] as * mut v128, ab11[ i] ) ) ;
833+ loop4 ! ( i, v128_store( c![ i + 0 , 4 ] as * mut v128, ab12[ i] ) ) ;
834+ loop4 ! ( i, v128_store( c![ i + 4 , 0 ] as * mut v128, ab21[ i] ) ) ;
835+ loop4 ! ( i, v128_store( c![ i + 4 , 4 ] as * mut v128, ab22[ i] ) ) ;
836+ } else {
837+ loop4 ! ( i, v128_store32_lane:: <0 >( ab11[ i] , c![ i + 0 , 0 ] as * mut u32 ) ) ;
838+ loop4 ! ( i, v128_store32_lane:: <1 >( ab11[ i] , c![ i + 0 , 1 ] as * mut u32 ) ) ;
839+ loop4 ! ( i, v128_store32_lane:: <2 >( ab11[ i] , c![ i + 0 , 2 ] as * mut u32 ) ) ;
840+ loop4 ! ( i, v128_store32_lane:: <3 >( ab11[ i] , c![ i + 0 , 3 ] as * mut u32 ) ) ;
841+
842+ loop4 ! ( i, v128_store32_lane:: <0 >( ab12[ i] , c![ i + 0 , 4 ] as * mut u32 ) ) ;
843+ loop4 ! ( i, v128_store32_lane:: <1 >( ab12[ i] , c![ i + 0 , 5 ] as * mut u32 ) ) ;
844+ loop4 ! ( i, v128_store32_lane:: <2 >( ab12[ i] , c![ i + 0 , 6 ] as * mut u32 ) ) ;
845+ loop4 ! ( i, v128_store32_lane:: <3 >( ab12[ i] , c![ i + 0 , 7 ] as * mut u32 ) ) ;
846+
847+ loop4 ! ( i, v128_store32_lane:: <0 >( ab21[ i] , c![ i + 4 , 0 ] as * mut u32 ) ) ;
848+ loop4 ! ( i, v128_store32_lane:: <1 >( ab21[ i] , c![ i + 4 , 1 ] as * mut u32 ) ) ;
849+ loop4 ! ( i, v128_store32_lane:: <2 >( ab21[ i] , c![ i + 4 , 2 ] as * mut u32 ) ) ;
850+ loop4 ! ( i, v128_store32_lane:: <3 >( ab21[ i] , c![ i + 4 , 3 ] as * mut u32 ) ) ;
851+
852+ loop4 ! ( i, v128_store32_lane:: <0 >( ab22[ i] , c![ i + 4 , 4 ] as * mut u32 ) ) ;
853+ loop4 ! ( i, v128_store32_lane:: <1 >( ab22[ i] , c![ i + 4 , 5 ] as * mut u32 ) ) ;
854+ loop4 ! ( i, v128_store32_lane:: <2 >( ab22[ i] , c![ i + 4 , 6 ] as * mut u32 ) ) ;
855+ loop4 ! ( i, v128_store32_lane:: <3 >( ab22[ i] , c![ i + 4 , 7 ] as * mut u32 ) ) ;
856+ }
857+ }
858+
695859#[ inline]
696860unsafe fn kernel_fallback_impl ( k : usize , alpha : T , a : * const T , b : * const T ,
697861 beta : T , c : * mut T , rsc : isize , csc : isize )
@@ -775,6 +939,17 @@ mod tests {
775939 }
776940 }
777941
942+ #[ cfg( all( target_arch="wasm32" , target_feature="simd128" ) ) ]
943+ mod test_kernel_wasm {
944+ use super :: test_a_kernel;
945+ use super :: super :: * ;
946+
947+ #[ test]
948+ fn wasm_simd_8x8 ( ) {
949+ test_a_kernel :: < KernelWasmSimd , _ > ( "wasm_simd_8x8" ) ;
950+ }
951+ }
952+
778953 #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
779954 mod test_kernel_x86 {
780955 use super :: test_a_kernel;
0 commit comments