Skip to content

Commit 22552b1

Browse files
committed
sgemm: Add wasm32 SIMD128 kernel
8x8 microkernel ported from the AArch64 NEON path. Gated on target_arch="wasm32" + target_feature="simd128"; falls through to the scalar kernel otherwise.
1 parent affbea7 commit 22552b1

1 file changed

Lines changed: 176 additions & 0 deletions

File tree

src/sgemm_kernel.rs

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ struct KernelSse2;
3030
#[cfg(target_arch="aarch64")]
3131
#[cfg(has_aarch64_simd)]
3232
struct KernelNeon;
33+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
34+
struct KernelWasmSimd;
3335
struct KernelFallback;
3436

3537
type T = f32;
@@ -62,6 +64,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
6264
return selector.select(KernelNeon);
6365
}
6466
}
67+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
68+
{
69+
return selector.select(KernelWasmSimd);
70+
}
71+
#[allow(unreachable_code)]
6572
return selector.select(KernelFallback);
6673
}
6774

@@ -279,6 +286,38 @@ impl GemmKernel for KernelFallback {
279286
}
280287
}
281288

289+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
290+
impl GemmKernel for KernelWasmSimd {
291+
type Elem = T;
292+
293+
type MRTy = U8;
294+
type NRTy = U8;
295+
296+
#[inline(always)]
297+
fn align_to() -> usize { 16 }
298+
299+
#[inline(always)]
300+
fn always_masked() -> bool { false }
301+
302+
#[inline(always)]
303+
fn nc() -> usize { archparam::S_NC }
304+
#[inline(always)]
305+
fn kc() -> usize { archparam::S_KC }
306+
#[inline(always)]
307+
fn mc() -> usize { archparam::S_MC }
308+
309+
#[inline(always)]
310+
unsafe fn kernel(
311+
k: usize,
312+
alpha: T,
313+
a: *const T,
314+
b: *const T,
315+
beta: T,
316+
c: *mut T, rsc: isize, csc: isize) {
317+
kernel_target_wasm_simd(k, alpha, a, b, beta, c, rsc, csc)
318+
}
319+
}
320+
282321
// no inline for unmasked kernels
283322
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
284323
#[target_feature(enable="fma")]
@@ -692,6 +731,132 @@ unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
692731
}
693732
}
694733

734+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
735+
unsafe fn kernel_target_wasm_simd(k: usize, alpha: T, a: *const T, b: *const T,
736+
beta: T, c: *mut T, rsc: isize, csc: isize)
737+
{
738+
use core::arch::wasm32::*;
739+
const MR: usize = KernelWasmSimd::MR;
740+
const NR: usize = KernelWasmSimd::NR;
741+
742+
let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };
743+
744+
// 8x8 microkernel with 16 v128 accumulators (4 quadrants of 4x4 each).
745+
// Mirrors the AArch64 NEON kernel structure: per k step we read 8 floats
746+
// of A (one packed column slice) and 8 floats of B (one packed row slice),
747+
// then accumulate the outer product into ab[ij].
748+
let zero = f32x4_splat(0.);
749+
let mut ab11 = [zero; 4];
750+
let mut ab12 = [zero; 4];
751+
let mut ab21 = [zero; 4];
752+
let mut ab22 = [zero; 4];
753+
754+
// dest[i] += b * splat(a[i]) for i in 0..4, given a v128 a-vec and a
755+
// v128 b-vec. wasm SIMD has no lane-broadcast-fma op, so we extract+splat.
756+
// Cranelift compiles `f32x4_splat(extract_lane::<L>(v))` to a single
757+
// shuffle on aarch64 hosts, and the add+mul pair to fmla.
758+
macro_rules! ab_ij_equals_ai_bj {
759+
($dest:ident, $av:expr, $bv:expr) => {
760+
$dest[0] = f32x4_add($dest[0], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<0>($av))));
761+
$dest[1] = f32x4_add($dest[1], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<1>($av))));
762+
$dest[2] = f32x4_add($dest[2], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<2>($av))));
763+
$dest[3] = f32x4_add($dest[3], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<3>($av))));
764+
}
765+
}
766+
767+
for _ in 0..k {
768+
let a1 = v128_load(a as *const v128);
769+
let b1 = v128_load(b as *const v128);
770+
let a2 = v128_load(a.add(4) as *const v128);
771+
let b2 = v128_load(b.add(4) as *const v128);
772+
773+
ab_ij_equals_ai_bj!(ab11, a1, b1);
774+
ab_ij_equals_ai_bj!(ab12, a1, b2);
775+
ab_ij_equals_ai_bj!(ab21, a2, b1);
776+
ab_ij_equals_ai_bj!(ab22, a2, b2);
777+
778+
a = a.add(MR);
779+
b = b.add(NR);
780+
}
781+
782+
macro_rules! c {
783+
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
784+
}
785+
786+
// ab *= alpha
787+
let alphav = f32x4_splat(alpha);
788+
loop4!(i, ab11[i] = f32x4_mul(ab11[i], alphav));
789+
loop4!(i, ab12[i] = f32x4_mul(ab12[i], alphav));
790+
loop4!(i, ab21[i] = f32x4_mul(ab21[i], alphav));
791+
loop4!(i, ab22[i] = f32x4_mul(ab22[i], alphav));
792+
793+
// Build a v128 by gathering four scalars from arbitrary pointers.
794+
macro_rules! loadq_from_pointers {
795+
($p0:expr, $p1:expr, $p2:expr, $p3:expr) => ({
796+
let v = f32x4_splat(0.);
797+
let v = v128_load32_lane::<0>(v, $p0 as *const u32);
798+
let v = v128_load32_lane::<1>(v, $p1 as *const u32);
799+
let v = v128_load32_lane::<2>(v, $p2 as *const u32);
800+
let v = v128_load32_lane::<3>(v, $p3 as *const u32);
801+
v
802+
});
803+
}
804+
805+
if beta != 0. {
806+
let mut c11 = [zero; 4];
807+
let mut c12 = [zero; 4];
808+
let mut c21 = [zero; 4];
809+
let mut c22 = [zero; 4];
810+
811+
if csc == 1 {
812+
loop4!(i, c11[i] = v128_load(c![i + 0, 0] as *const v128));
813+
loop4!(i, c12[i] = v128_load(c![i + 0, 4] as *const v128));
814+
loop4!(i, c21[i] = v128_load(c![i + 4, 0] as *const v128));
815+
loop4!(i, c22[i] = v128_load(c![i + 4, 4] as *const v128));
816+
} else {
817+
loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1], c![i + 0, 2], c![i + 0, 3]));
818+
loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 4], c![i + 0, 5], c![i + 0, 6], c![i + 0, 7]));
819+
loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1], c![i + 4, 2], c![i + 4, 3]));
820+
loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 4], c![i + 4, 5], c![i + 4, 6], c![i + 4, 7]));
821+
}
822+
823+
let betav = f32x4_splat(beta);
824+
// ab += β C
825+
loop4!(i, ab11[i] = f32x4_add(ab11[i], f32x4_mul(c11[i], betav)));
826+
loop4!(i, ab12[i] = f32x4_add(ab12[i], f32x4_mul(c12[i], betav)));
827+
loop4!(i, ab21[i] = f32x4_add(ab21[i], f32x4_mul(c21[i], betav)));
828+
loop4!(i, ab22[i] = f32x4_add(ab22[i], f32x4_mul(c22[i], betav)));
829+
}
830+
831+
// C <- α A B (+ β C)
832+
if csc == 1 {
833+
loop4!(i, v128_store(c![i + 0, 0] as *mut v128, ab11[i]));
834+
loop4!(i, v128_store(c![i + 0, 4] as *mut v128, ab12[i]));
835+
loop4!(i, v128_store(c![i + 4, 0] as *mut v128, ab21[i]));
836+
loop4!(i, v128_store(c![i + 4, 4] as *mut v128, ab22[i]));
837+
} else {
838+
loop4!(i, v128_store32_lane::<0>(ab11[i], c![i + 0, 0] as *mut u32));
839+
loop4!(i, v128_store32_lane::<1>(ab11[i], c![i + 0, 1] as *mut u32));
840+
loop4!(i, v128_store32_lane::<2>(ab11[i], c![i + 0, 2] as *mut u32));
841+
loop4!(i, v128_store32_lane::<3>(ab11[i], c![i + 0, 3] as *mut u32));
842+
843+
loop4!(i, v128_store32_lane::<0>(ab12[i], c![i + 0, 4] as *mut u32));
844+
loop4!(i, v128_store32_lane::<1>(ab12[i], c![i + 0, 5] as *mut u32));
845+
loop4!(i, v128_store32_lane::<2>(ab12[i], c![i + 0, 6] as *mut u32));
846+
loop4!(i, v128_store32_lane::<3>(ab12[i], c![i + 0, 7] as *mut u32));
847+
848+
loop4!(i, v128_store32_lane::<0>(ab21[i], c![i + 4, 0] as *mut u32));
849+
loop4!(i, v128_store32_lane::<1>(ab21[i], c![i + 4, 1] as *mut u32));
850+
loop4!(i, v128_store32_lane::<2>(ab21[i], c![i + 4, 2] as *mut u32));
851+
loop4!(i, v128_store32_lane::<3>(ab21[i], c![i + 4, 3] as *mut u32));
852+
853+
loop4!(i, v128_store32_lane::<0>(ab22[i], c![i + 4, 4] as *mut u32));
854+
loop4!(i, v128_store32_lane::<1>(ab22[i], c![i + 4, 5] as *mut u32));
855+
loop4!(i, v128_store32_lane::<2>(ab22[i], c![i + 4, 6] as *mut u32));
856+
loop4!(i, v128_store32_lane::<3>(ab22[i], c![i + 4, 7] as *mut u32));
857+
}
858+
}
859+
695860
#[inline]
696861
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
697862
beta: T, c: *mut T, rsc: isize, csc: isize)
@@ -775,6 +940,17 @@ mod tests {
775940
}
776941
}
777942

943+
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
944+
mod test_kernel_wasm {
945+
use super::test_a_kernel;
946+
use super::super::*;
947+
948+
#[test]
949+
fn wasm_simd_8x8() {
950+
test_a_kernel::<KernelWasmSimd, _>("wasm_simd_8x8");
951+
}
952+
}
953+
778954
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
779955
mod test_kernel_x86 {
780956
use super::test_a_kernel;

0 commit comments

Comments
 (0)