Skip to content

Commit 86d1451

Browse files
committed
Add WebAssembly SIMD GEMM kernels
1 parent ddfbc64 commit 86d1451

File tree

2 files changed

+262
-2
lines changed

2 files changed

+262
-2
lines changed

kernel/wasm/KERNEL.WASM128_GENERIC

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,13 @@ DTRMMKERNEL = ../generic/trmmkernel_2x2.c
100100
CTRMMKERNEL = ../generic/ztrmmkernel_2x2.c
101101
ZTRMMKERNEL = ../generic/ztrmmkernel_2x2.c
102102

103-
SGEMMKERNEL = ../generic/gemmkernel_2x2.c
103+
SGEMMKERNEL = gemmkernel_wasm128.c
104104
SGEMMONCOPY = ../generic/gemm_ncopy_2.c
105105
SGEMMOTCOPY = ../generic/gemm_tcopy_2.c
106106
SGEMMONCOPYOBJ = sgemm_oncopy$(TSUFFIX).$(SUFFIX)
107107
SGEMMOTCOPYOBJ = sgemm_otcopy$(TSUFFIX).$(SUFFIX)
108108

109-
DGEMMKERNEL = ../generic/gemmkernel_2x2.c
109+
DGEMMKERNEL = gemmkernel_wasm128.c
110110
DGEMMONCOPY = ../generic/gemm_ncopy_2.c
111111
DGEMMOTCOPY = ../generic/gemm_tcopy_2.c
112112
DGEMMONCOPYOBJ = dgemm_oncopy$(TSUFFIX).$(SUFFIX)

kernel/wasm/gemmkernel_wasm128.c

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
/***************************************************************************
2+
Copyright (c) 2026, The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
28+
#include "common.h"
29+
#include "../generic/conversion_macros.h"
30+
31+
#if defined(__wasm_simd128__)
32+
#include <wasm_simd128.h>
33+
#endif
34+
35+
#if defined(__wasm_simd128__)
36+
#ifndef DOUBLE
37+
static inline FLOAT hsum_vec(v128_t v) {
38+
return wasm_f32x4_extract_lane(v, 0) + wasm_f32x4_extract_lane(v, 1) +
39+
wasm_f32x4_extract_lane(v, 2) + wasm_f32x4_extract_lane(v, 3);
40+
}
41+
#else
42+
static inline FLOAT hsum_vec(v128_t v) {
43+
return wasm_f64x2_extract_lane(v, 0) + wasm_f64x2_extract_lane(v, 1);
44+
}
45+
#endif
46+
#endif
47+
48+
int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha, IFLOAT *ba,
49+
IFLOAT *bb, FLOAT *C, BLASLONG ldc
50+
#ifdef TRMMKERNEL
51+
,
52+
BLASLONG offset
53+
#endif
54+
)
55+
{
56+
BLASLONG i, j, k;
57+
FLOAT *C0, *C1;
58+
IFLOAT *ptrba, *ptrbb;
59+
#ifdef BGEMM
60+
float res0, res1, res2, res3;
61+
#else
62+
FLOAT res0, res1, res2, res3;
63+
#endif
64+
IFLOAT load0, load1, load2, load3, load4, load5, load6, load7;
65+
66+
for (j = 0; j < bn / 2; j += 1) {
67+
C0 = C;
68+
C1 = C0 + ldc;
69+
ptrba = ba;
70+
71+
for (i = 0; i < bm / 2; i += 1) {
72+
ptrbb = bb;
73+
res0 = 0;
74+
res1 = 0;
75+
res2 = 0;
76+
res3 = 0;
77+
78+
#if defined(__wasm_simd128__) && !defined(BGEMM)
79+
#ifndef DOUBLE
80+
{
81+
v128_t vacc00 = wasm_f32x4_splat(0.0f);
82+
v128_t vacc10 = wasm_f32x4_splat(0.0f);
83+
v128_t vacc01 = wasm_f32x4_splat(0.0f);
84+
v128_t vacc11 = wasm_f32x4_splat(0.0f);
85+
86+
k = 0;
87+
for (; k + 4 <= bk; k += 4) {
88+
v128_t va01 = wasm_v128_load(ptrba);
89+
v128_t va23 = wasm_v128_load(ptrba + 4);
90+
v128_t vb01 = wasm_v128_load(ptrbb);
91+
v128_t vb23 = wasm_v128_load(ptrbb + 4);
92+
93+
v128_t vrow0 =
94+
wasm_i32x4_shuffle(va01, va23, 0, 2, 4, 6);
95+
v128_t vrow1 =
96+
wasm_i32x4_shuffle(va01, va23, 1, 3, 5, 7);
97+
v128_t vcol0 =
98+
wasm_i32x4_shuffle(vb01, vb23, 0, 2, 4, 6);
99+
v128_t vcol1 =
100+
wasm_i32x4_shuffle(vb01, vb23, 1, 3, 5, 7);
101+
102+
vacc00 = wasm_f32x4_add(
103+
vacc00, wasm_f32x4_mul(vrow0, vcol0));
104+
vacc10 = wasm_f32x4_add(
105+
vacc10, wasm_f32x4_mul(vrow1, vcol0));
106+
vacc01 = wasm_f32x4_add(
107+
vacc01, wasm_f32x4_mul(vrow0, vcol1));
108+
vacc11 = wasm_f32x4_add(
109+
vacc11, wasm_f32x4_mul(vrow1, vcol1));
110+
111+
ptrba += 8;
112+
ptrbb += 8;
113+
}
114+
115+
res0 += hsum_vec(vacc00);
116+
res1 += hsum_vec(vacc10);
117+
res2 += hsum_vec(vacc01);
118+
res3 += hsum_vec(vacc11);
119+
}
120+
#else
121+
{
122+
v128_t vacc00 = wasm_f64x2_splat(0.0);
123+
v128_t vacc10 = wasm_f64x2_splat(0.0);
124+
v128_t vacc01 = wasm_f64x2_splat(0.0);
125+
v128_t vacc11 = wasm_f64x2_splat(0.0);
126+
127+
for (k = 0; k + 2 <= bk; k += 2) {
128+
v128_t va01 = wasm_v128_load(ptrba);
129+
v128_t va23 = wasm_v128_load(ptrba + 2);
130+
v128_t vb01 = wasm_v128_load(ptrbb);
131+
v128_t vb23 = wasm_v128_load(ptrbb + 2);
132+
133+
v128_t vrow0 =
134+
wasm_i64x2_shuffle(va01, va23, 0, 2);
135+
v128_t vrow1 =
136+
wasm_i64x2_shuffle(va01, va23, 1, 3);
137+
v128_t vcol0 =
138+
wasm_i64x2_shuffle(vb01, vb23, 0, 2);
139+
v128_t vcol1 =
140+
wasm_i64x2_shuffle(vb01, vb23, 1, 3);
141+
142+
vacc00 = wasm_f64x2_add(
143+
vacc00, wasm_f64x2_mul(vrow0, vcol0));
144+
vacc10 = wasm_f64x2_add(
145+
vacc10, wasm_f64x2_mul(vrow1, vcol0));
146+
vacc01 = wasm_f64x2_add(
147+
vacc01, wasm_f64x2_mul(vrow0, vcol1));
148+
vacc11 = wasm_f64x2_add(
149+
vacc11, wasm_f64x2_mul(vrow1, vcol1));
150+
151+
ptrba += 4;
152+
ptrbb += 4;
153+
}
154+
155+
res0 += hsum_vec(vacc00);
156+
res1 += hsum_vec(vacc10);
157+
res2 += hsum_vec(vacc01);
158+
res3 += hsum_vec(vacc11);
159+
}
160+
#endif
161+
#else
162+
k = 0;
163+
#endif
164+
165+
for (; k < bk; k += 1) {
166+
load0 = ptrba[2 * 0 + 0];
167+
load1 = ptrbb[2 * 0 + 0];
168+
res0 = res0 + TO_F32(load0) * TO_F32(load1);
169+
load2 = ptrba[2 * 0 + 1];
170+
res1 = res1 + TO_F32(load2) * TO_F32(load1);
171+
load3 = ptrbb[2 * 0 + 1];
172+
res2 = res2 + TO_F32(load0) * TO_F32(load3);
173+
res3 = res3 + TO_F32(load2) * TO_F32(load3);
174+
ptrba = ptrba + 2;
175+
ptrbb = ptrbb + 2;
176+
}
177+
178+
res0 = res0 * ALPHA;
179+
C0[0] = TO_OUTPUT(TO_F32(C0[0]) + res0);
180+
res1 = res1 * ALPHA;
181+
C0[1] = TO_OUTPUT(TO_F32(C0[1]) + res1);
182+
res2 = res2 * ALPHA;
183+
C1[0] = TO_OUTPUT(TO_F32(C1[0]) + res2);
184+
res3 = res3 * ALPHA;
185+
C1[1] = TO_OUTPUT(TO_F32(C1[1]) + res3);
186+
C0 = C0 + 2;
187+
C1 = C1 + 2;
188+
}
189+
190+
for (i = 0; i < (bm & 1); i += 1) {
191+
ptrbb = bb;
192+
res0 = 0;
193+
res1 = 0;
194+
for (k = 0; k < bk; k += 1) {
195+
load0 = ptrba[0 + 0];
196+
load1 = ptrbb[2 * 0 + 0];
197+
res0 = res0 + TO_F32(load0) * TO_F32(load1);
198+
load2 = ptrbb[2 * 0 + 1];
199+
res1 = res1 + TO_F32(load0) * TO_F32(load2);
200+
ptrba = ptrba + 1;
201+
ptrbb = ptrbb + 2;
202+
}
203+
res0 = res0 * ALPHA;
204+
C0[0] = TO_OUTPUT(TO_F32(C0[0]) + res0);
205+
res1 = res1 * ALPHA;
206+
C1[0] = TO_OUTPUT(TO_F32(C1[0]) + res1);
207+
C0 = C0 + 1;
208+
C1 = C1 + 1;
209+
}
210+
211+
k = (bk << 1);
212+
bb = bb + k;
213+
i = (ldc << 1);
214+
C = C + i;
215+
}
216+
217+
for (j = 0; j < (bn & 1); j += 1) {
218+
C0 = C;
219+
ptrba = ba;
220+
for (i = 0; i < bm / 2; i += 1) {
221+
ptrbb = bb;
222+
res0 = 0;
223+
res1 = 0;
224+
for (k = 0; k < bk; k += 1) {
225+
load0 = ptrba[2 * 0 + 0];
226+
load1 = ptrbb[0 + 0];
227+
res0 = res0 + TO_F32(load0) * TO_F32(load1);
228+
load2 = ptrba[2 * 0 + 1];
229+
res1 = res1 + TO_F32(load2) * TO_F32(load1);
230+
ptrba = ptrba + 2;
231+
ptrbb = ptrbb + 1;
232+
}
233+
res0 = res0 * ALPHA;
234+
C0[0] = TO_OUTPUT(TO_F32(C0[0]) + res0);
235+
res1 = res1 * ALPHA;
236+
C0[1] = TO_OUTPUT(TO_F32(C0[1]) + res1);
237+
C0 = C0 + 2;
238+
}
239+
240+
for (i = 0; i < (bm & 1); i += 1) {
241+
ptrbb = bb;
242+
res0 = 0;
243+
for (k = 0; k < bk; k += 1) {
244+
load0 = ptrba[0 + 0];
245+
load1 = ptrbb[0 + 0];
246+
res0 = res0 + TO_F32(load0) * TO_F32(load1);
247+
ptrba = ptrba + 1;
248+
ptrbb = ptrbb + 1;
249+
}
250+
res0 = res0 * ALPHA;
251+
C0[0] = TO_OUTPUT(TO_F32(C0[0]) + res0);
252+
C0 = C0 + 1;
253+
}
254+
k = bk;
255+
bb = bb + k;
256+
i = ldc;
257+
C = C + i;
258+
}
259+
return 0;
260+
}

0 commit comments

Comments
 (0)