Skip to content

Commit ed78f43

Browse files
committed
sgemm: Add wasm32 SIMD128 kernel
8x8 microkernel ported from the AArch64 NEON path. Gated on target_arch="wasm32" + target_feature="simd128"; falls through to the scalar kernel otherwise.
1 parent affbea7 commit ed78f43

1 file changed

Lines changed: 175 additions & 0 deletions

File tree

src/sgemm_kernel.rs

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ struct KernelSse2;
3030
#[cfg(target_arch="aarch64")]
3131
#[cfg(has_aarch64_simd)]
3232
struct KernelNeon;
33+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
34+
struct KernelWasmSimd;
3335
struct KernelFallback;
3436

3537
type T = f32;
@@ -62,6 +64,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
6264
return selector.select(KernelNeon);
6365
}
6466
}
67+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
68+
{
69+
return selector.select(KernelWasmSimd);
70+
}
71+
#[allow(unreachable_code)]
6572
return selector.select(KernelFallback);
6673
}
6774

@@ -279,6 +286,38 @@ impl GemmKernel for KernelFallback {
279286
}
280287
}
281288

289+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
290+
impl GemmKernel for KernelWasmSimd {
291+
type Elem = T;
292+
293+
type MRTy = U8;
294+
type NRTy = U8;
295+
296+
#[inline(always)]
297+
fn align_to() -> usize { 16 }
298+
299+
#[inline(always)]
300+
fn always_masked() -> bool { false }
301+
302+
#[inline(always)]
303+
fn nc() -> usize { archparam::S_NC }
304+
#[inline(always)]
305+
fn kc() -> usize { archparam::S_KC }
306+
#[inline(always)]
307+
fn mc() -> usize { archparam::S_MC }
308+
309+
#[inline(always)]
310+
unsafe fn kernel(
311+
k: usize,
312+
alpha: T,
313+
a: *const T,
314+
b: *const T,
315+
beta: T,
316+
c: *mut T, rsc: isize, csc: isize) {
317+
kernel_target_wasm_simd(k, alpha, a, b, beta, c, rsc, csc)
318+
}
319+
}
320+
282321
// no inline for unmasked kernels
283322
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
284323
#[target_feature(enable="fma")]
@@ -692,6 +731,131 @@ unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
692731
}
693732
}
694733

734+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
735+
unsafe fn kernel_target_wasm_simd(k: usize, alpha: T, a: *const T, b: *const T,
736+
beta: T, c: *mut T, rsc: isize, csc: isize)
737+
{
738+
use core::arch::wasm32::*;
739+
const MR: usize = KernelWasmSimd::MR;
740+
const NR: usize = KernelWasmSimd::NR;
741+
742+
let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };
743+
744+
// Kernel 8 x 8 (a x b)
745+
// Four quadrants of 4 x 4
746+
let zero = f32x4_splat(0.);
747+
let mut ab11 = [zero; 4];
748+
let mut ab12 = [zero; 4];
749+
let mut ab21 = [zero; 4];
750+
let mut ab22 = [zero; 4];
751+
752+
// ab_ij = a_i * b_j for all i, j
753+
// (wasm SIMD has no lane-FMA; extract+splat into mul+add)
754+
macro_rules! ab_ij_equals_ai_bj {
755+
($dest:ident, $av:expr, $bv:expr) => {
756+
$dest[0] = f32x4_add($dest[0], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<0>($av))));
757+
$dest[1] = f32x4_add($dest[1], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<1>($av))));
758+
$dest[2] = f32x4_add($dest[2], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<2>($av))));
759+
$dest[3] = f32x4_add($dest[3], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<3>($av))));
760+
}
761+
}
762+
763+
for _ in 0..k {
764+
let a1 = v128_load(a as *const v128);
765+
let b1 = v128_load(b as *const v128);
766+
let a2 = v128_load(a.add(4) as *const v128);
767+
let b2 = v128_load(b.add(4) as *const v128);
768+
769+
ab_ij_equals_ai_bj!(ab11, a1, b1);
770+
ab_ij_equals_ai_bj!(ab12, a1, b2);
771+
ab_ij_equals_ai_bj!(ab21, a2, b1);
772+
ab_ij_equals_ai_bj!(ab22, a2, b2);
773+
774+
a = a.add(MR);
775+
b = b.add(NR);
776+
}
777+
778+
macro_rules! c {
779+
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
780+
}
781+
782+
// ab *= alpha
783+
let alphav = f32x4_splat(alpha);
784+
loop4!(i, ab11[i] = f32x4_mul(ab11[i], alphav));
785+
loop4!(i, ab12[i] = f32x4_mul(ab12[i], alphav));
786+
loop4!(i, ab21[i] = f32x4_mul(ab21[i], alphav));
787+
loop4!(i, ab22[i] = f32x4_mul(ab22[i], alphav));
788+
789+
// load one v128 from four pointers
790+
macro_rules! loadq_from_pointers {
791+
($p0:expr, $p1:expr, $p2:expr, $p3:expr) => ({
792+
let v = f32x4_splat(0.);
793+
let v = v128_load32_lane::<0>(v, $p0 as *const u32);
794+
let v = v128_load32_lane::<1>(v, $p1 as *const u32);
795+
let v = v128_load32_lane::<2>(v, $p2 as *const u32);
796+
let v = v128_load32_lane::<3>(v, $p3 as *const u32);
797+
v
798+
});
799+
}
800+
801+
if beta != 0. {
802+
// load existing value in C
803+
let mut c11 = [zero; 4];
804+
let mut c12 = [zero; 4];
805+
let mut c21 = [zero; 4];
806+
let mut c22 = [zero; 4];
807+
808+
if csc == 1 {
809+
loop4!(i, c11[i] = v128_load(c![i + 0, 0] as *const v128));
810+
loop4!(i, c12[i] = v128_load(c![i + 0, 4] as *const v128));
811+
loop4!(i, c21[i] = v128_load(c![i + 4, 0] as *const v128));
812+
loop4!(i, c22[i] = v128_load(c![i + 4, 4] as *const v128));
813+
} else {
814+
loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1], c![i + 0, 2], c![i + 0, 3]));
815+
loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 4], c![i + 0, 5], c![i + 0, 6], c![i + 0, 7]));
816+
loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1], c![i + 4, 2], c![i + 4, 3]));
817+
loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 4], c![i + 4, 5], c![i + 4, 6], c![i + 4, 7]));
818+
}
819+
820+
let betav = f32x4_splat(beta);
821+
// ab += β C
822+
loop4!(i, ab11[i] = f32x4_add(ab11[i], f32x4_mul(c11[i], betav)));
823+
loop4!(i, ab12[i] = f32x4_add(ab12[i], f32x4_mul(c12[i], betav)));
824+
loop4!(i, ab21[i] = f32x4_add(ab21[i], f32x4_mul(c21[i], betav)));
825+
loop4!(i, ab22[i] = f32x4_add(ab22[i], f32x4_mul(c22[i], betav)));
826+
}
827+
828+
// c <- ab
829+
// which is in full
830+
// C <- α A B (+ β C)
831+
if csc == 1 {
832+
loop4!(i, v128_store(c![i + 0, 0] as *mut v128, ab11[i]));
833+
loop4!(i, v128_store(c![i + 0, 4] as *mut v128, ab12[i]));
834+
loop4!(i, v128_store(c![i + 4, 0] as *mut v128, ab21[i]));
835+
loop4!(i, v128_store(c![i + 4, 4] as *mut v128, ab22[i]));
836+
} else {
837+
loop4!(i, v128_store32_lane::<0>(ab11[i], c![i + 0, 0] as *mut u32));
838+
loop4!(i, v128_store32_lane::<1>(ab11[i], c![i + 0, 1] as *mut u32));
839+
loop4!(i, v128_store32_lane::<2>(ab11[i], c![i + 0, 2] as *mut u32));
840+
loop4!(i, v128_store32_lane::<3>(ab11[i], c![i + 0, 3] as *mut u32));
841+
842+
loop4!(i, v128_store32_lane::<0>(ab12[i], c![i + 0, 4] as *mut u32));
843+
loop4!(i, v128_store32_lane::<1>(ab12[i], c![i + 0, 5] as *mut u32));
844+
loop4!(i, v128_store32_lane::<2>(ab12[i], c![i + 0, 6] as *mut u32));
845+
loop4!(i, v128_store32_lane::<3>(ab12[i], c![i + 0, 7] as *mut u32));
846+
847+
loop4!(i, v128_store32_lane::<0>(ab21[i], c![i + 4, 0] as *mut u32));
848+
loop4!(i, v128_store32_lane::<1>(ab21[i], c![i + 4, 1] as *mut u32));
849+
loop4!(i, v128_store32_lane::<2>(ab21[i], c![i + 4, 2] as *mut u32));
850+
loop4!(i, v128_store32_lane::<3>(ab21[i], c![i + 4, 3] as *mut u32));
851+
852+
loop4!(i, v128_store32_lane::<0>(ab22[i], c![i + 4, 4] as *mut u32));
853+
loop4!(i, v128_store32_lane::<1>(ab22[i], c![i + 4, 5] as *mut u32));
854+
loop4!(i, v128_store32_lane::<2>(ab22[i], c![i + 4, 6] as *mut u32));
855+
loop4!(i, v128_store32_lane::<3>(ab22[i], c![i + 4, 7] as *mut u32));
856+
}
857+
}
858+
695859
#[inline]
696860
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
697861
beta: T, c: *mut T, rsc: isize, csc: isize)
@@ -775,6 +939,17 @@ mod tests {
775939
}
776940
}
777941

942+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
943+
mod test_kernel_wasm {
944+
use super::test_a_kernel;
945+
use super::super::*;
946+
947+
#[test]
948+
fn wasm_simd_8x8() {
949+
test_a_kernel::<KernelWasmSimd, _>("wasm_simd_8x8");
950+
}
951+
}
952+
778953
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
779954
mod test_kernel_x86 {
780955
use super::test_a_kernel;

0 commit comments

Comments
 (0)