@@ -17,6 +17,17 @@ import sk.ainet.backend.api.kernel.Fp32MatmulKernel
1717 * ([TILE_M], [TILE_N], [TILE_K]). Default 8×8×128 keeps a working
1818 * set well under L1 — eight A rows × 128 floats + eight Bᵀ rows ×
1919 * 128 floats ≈ 8 KB, within typical 32 KB L1.
20+ * - Within each (TILE_M × TILE_N) sub-tile, [mnpack] recursively
21+ * dispatches into `RM × RN` micro-kernels — `gemm4x3`, `gemm2x2`,
22+ * `gemm2x1`, `gemm1x2`, `gemm1x1`. Each micro-kernel keeps
23+ * `RM × RN` `FloatVector` accumulators in locals and amortizes
24+ * every A-row load across `RN` columns and every B-column load
25+ * across `RM` rows. This mirrors the tile-dispatch pattern from
26+ * tinyBLAS (`sgemm.cpp`, Justine Tunney / llamafile).
27+ * - On AVX2 the largest microkernel that fits inside 16 YMM registers
28+ * is `4 × 3` (12 accumulators + at most 4 A vectors + 1 B vector
29+ * live at once). Smaller microkernels cover residual rows and
30+ * columns that don't divide evenly into the larger tile shape.
2031 * - Inner reduction is a vector-width FMA accumulator
2132 * (`v.fma(w, acc)`), reduced via `reduceLanes(ADD)` once per
2233 * `(i, j)` cell per K-tile. Tail elements that don't fill a vector
@@ -59,7 +70,7 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
5970 }
6071 if (k == 0 ) return
6172
62- // Pack B^T: bt[j, kk] = b[kk, j].
73+ // Pack B^T: bt[j, kk] = b[kk, j]. Row stride in bt is k.
6374 val bt = FloatArray (n * k)
6475 for (kk in 0 until k) {
6576 val src = bOffset + kk * bStride
@@ -68,8 +79,6 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
6879 }
6980 }
7081
71- val step = species.length()
72-
7382 var mTile = 0
7483 while (mTile < m) {
7584 val mEnd = minOf(mTile + TILE_M , m)
@@ -79,34 +88,338 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
7988 var kTile = 0
8089 while (kTile < k) {
8190 val kEnd = minOf(kTile + TILE_K , k)
82- val kLen = kEnd - kTile
83- val loopBound = species.loopBound(kLen)
84- for (i in mTile until mEnd) {
85- val aRowBase = aOffset + i * aStride + kTile
86- val outRowBase = outOffset + i * outStride
87- for (j in nTile until nEnd) {
88- val btRowBase = j * k + kTile
89- var acc = FloatVector .zero(species)
90- var idx = 0
91- while (idx < loopBound) {
92- val va = FloatVector .fromArray(species, a, aRowBase + idx)
93- val vb = FloatVector .fromArray(species, bt, btRowBase + idx)
94- acc = va.fma(vb, acc)
95- idx + = step
96- }
97- var sum = acc.reduceLanes(VectorOperators .ADD )
98- while (idx < kLen) {
99- sum + = a[aRowBase + idx] * bt[btRowBase + idx]
100- idx++
101- }
102- out [outRowBase + j] + = sum
103- }
104- }
91+ mnpack(
92+ a, aOffset, aStride,
93+ bt, k,
94+ out , outOffset, outStride,
95+ mTile, mEnd, nTile, nEnd,
96+ kTile, kEnd - kTile,
97+ )
10598 kTile = kEnd
10699 }
107100 nTile = nEnd
108101 }
109102 mTile = mEnd
110103 }
111104 }
105+
106+ /* *
107+ * Recursive (m, n) tile dispatch. Picks the largest microkernel
108+ * shape `(RM, RN)` that fits the residual `(m1-m0, n1-n0)`, calls it
109+ * over the aligned sub-rectangle `[m0..mp) × [n0..np)`, then recurses
110+ * on the residual rows `[mp..m1) × [n0..np)` and the residual columns
111+ * `[m0..m1) × [np..n1)`. Mirrors the tinyBLAS `mnpack` switch but
112+ * uses only the AVX2-friendly microkernel set (16 vector registers).
113+ */
114+ private fun mnpack (
115+ a : FloatArray , aOffset : Int , aStride : Int ,
116+ bt : FloatArray , btStride : Int ,
117+ out : FloatArray , outOffset : Int , outStride : Int ,
118+ m0 : Int , m1 : Int , n0 : Int , n1 : Int ,
119+ kStart : Int , kLen : Int ,
120+ ) {
121+ if (m1 <= m0 || n1 <= n0) return
122+
123+ val rm = minOf(m1 - m0, 4 )
124+ val rn = minOf(n1 - n0, 3 )
125+ val mc: Int
126+ val nc: Int
127+ when ((rm shl 4 ) or rn) {
128+ 0x43 -> {
129+ mc = 4 ; nc = 3
130+ gemm4x3(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
131+ m0, m0 + ((m1 - m0) / mc) * mc, n0, n0 + ((n1 - n0) / nc) * nc, kStart, kLen)
132+ }
133+ 0x42 , 0x33 , 0x32 , 0x23 , 0x22 -> {
134+ mc = 2 ; nc = 2
135+ gemm2x2(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
136+ m0, m0 + ((m1 - m0) / mc) * mc, n0, n0 + ((n1 - n0) / nc) * nc, kStart, kLen)
137+ }
138+ 0x41 , 0x31 , 0x21 -> {
139+ mc = 2 ; nc = 1
140+ gemm2x1(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
141+ m0, m0 + ((m1 - m0) / mc) * mc, n0, n0 + ((n1 - n0) / nc) * nc, kStart, kLen)
142+ }
143+ 0x13 , 0x12 -> {
144+ mc = 1 ; nc = 2
145+ gemm1x2(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
146+ m0, m0 + ((m1 - m0) / mc) * mc, n0, n0 + ((n1 - n0) / nc) * nc, kStart, kLen)
147+ }
148+ 0x11 -> {
149+ mc = 1 ; nc = 1
150+ gemm1x1(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
151+ m0, m0 + ((m1 - m0) / mc) * mc, n0, n0 + ((n1 - n0) / nc) * nc, kStart, kLen)
152+ }
153+ else -> return
154+ }
155+ val mp = m0 + ((m1 - m0) / mc) * mc
156+ val np = n0 + ((n1 - n0) / nc) * nc
157+ if (mp < m1) mnpack(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
158+ mp, m1, n0, np, kStart, kLen)
159+ if (np < n1) mnpack(a, aOffset, aStride, bt, btStride, out , outOffset, outStride,
160+ m0, m1, np, n1, kStart, kLen)
161+ }
162+
163+ /* *
164+ * Largest AVX2-friendly microkernel: 4 rows × 3 cols, 12 accumulators.
165+ * Loads 4 A vectors and 3 B vectors per `k` step, issues 12 FMAs.
166+ * Caller guarantees `(m1 - m0)` is a multiple of 4 and `(n1 - n0)` of 3.
167+ */
168+ private fun gemm4x3 (
169+ a : FloatArray , aOffset : Int , aStride : Int ,
170+ bt : FloatArray , btStride : Int ,
171+ out : FloatArray , outOffset : Int , outStride : Int ,
172+ m0 : Int , m1 : Int , n0 : Int , n1 : Int ,
173+ kStart : Int , kLen : Int ,
174+ ) {
175+ val step = species.length()
176+ val loopBound = species.loopBound(kLen)
177+ var ii = m0
178+ while (ii < m1) {
179+ val a0Base = aOffset + ii * aStride + kStart
180+ val a1Base = a0Base + aStride
181+ val a2Base = a1Base + aStride
182+ val a3Base = a2Base + aStride
183+ val outRow0 = outOffset + ii * outStride
184+ val outRow1 = outRow0 + outStride
185+ val outRow2 = outRow1 + outStride
186+ val outRow3 = outRow2 + outStride
187+ var jj = n0
188+ while (jj < n1) {
189+ val b0Base = jj * btStride + kStart
190+ val b1Base = b0Base + btStride
191+ val b2Base = b1Base + btStride
192+
193+ var c00 = FloatVector .zero(species); var c01 = FloatVector .zero(species); var c02 = FloatVector .zero(species)
194+ var c10 = FloatVector .zero(species); var c11 = FloatVector .zero(species); var c12 = FloatVector .zero(species)
195+ var c20 = FloatVector .zero(species); var c21 = FloatVector .zero(species); var c22 = FloatVector .zero(species)
196+ var c30 = FloatVector .zero(species); var c31 = FloatVector .zero(species); var c32 = FloatVector .zero(species)
197+
198+ var idx = 0
199+ while (idx < loopBound) {
200+ val va0 = FloatVector .fromArray(species, a, a0Base + idx)
201+ val va1 = FloatVector .fromArray(species, a, a1Base + idx)
202+ val va2 = FloatVector .fromArray(species, a, a2Base + idx)
203+ val va3 = FloatVector .fromArray(species, a, a3Base + idx)
204+
205+ val vb0 = FloatVector .fromArray(species, bt, b0Base + idx)
206+ c00 = va0.fma(vb0, c00); c10 = va1.fma(vb0, c10); c20 = va2.fma(vb0, c20); c30 = va3.fma(vb0, c30)
207+
208+ val vb1 = FloatVector .fromArray(species, bt, b1Base + idx)
209+ c01 = va0.fma(vb1, c01); c11 = va1.fma(vb1, c11); c21 = va2.fma(vb1, c21); c31 = va3.fma(vb1, c31)
210+
211+ val vb2 = FloatVector .fromArray(species, bt, b2Base + idx)
212+ c02 = va0.fma(vb2, c02); c12 = va1.fma(vb2, c12); c22 = va2.fma(vb2, c22); c32 = va3.fma(vb2, c32)
213+
214+ idx + = step
215+ }
216+
217+ var s00 = c00.reduceLanes(VectorOperators .ADD ); var s01 = c01.reduceLanes(VectorOperators .ADD ); var s02 = c02.reduceLanes(VectorOperators .ADD )
218+ var s10 = c10.reduceLanes(VectorOperators .ADD ); var s11 = c11.reduceLanes(VectorOperators .ADD ); var s12 = c12.reduceLanes(VectorOperators .ADD )
219+ var s20 = c20.reduceLanes(VectorOperators .ADD ); var s21 = c21.reduceLanes(VectorOperators .ADD ); var s22 = c22.reduceLanes(VectorOperators .ADD )
220+ var s30 = c30.reduceLanes(VectorOperators .ADD ); var s31 = c31.reduceLanes(VectorOperators .ADD ); var s32 = c32.reduceLanes(VectorOperators .ADD )
221+
222+ while (idx < kLen) {
223+ val av0 = a[a0Base + idx]; val av1 = a[a1Base + idx]; val av2 = a[a2Base + idx]; val av3 = a[a3Base + idx]
224+ val bv0 = bt[b0Base + idx]; val bv1 = bt[b1Base + idx]; val bv2 = bt[b2Base + idx]
225+ s00 + = av0 * bv0; s10 + = av1 * bv0; s20 + = av2 * bv0; s30 + = av3 * bv0
226+ s01 + = av0 * bv1; s11 + = av1 * bv1; s21 + = av2 * bv1; s31 + = av3 * bv1
227+ s02 + = av0 * bv2; s12 + = av1 * bv2; s22 + = av2 * bv2; s32 + = av3 * bv2
228+ idx++
229+ }
230+
231+ out [outRow0 + jj] + = s00; out [outRow0 + jj + 1 ] + = s01; out [outRow0 + jj + 2 ] + = s02
232+ out [outRow1 + jj] + = s10; out [outRow1 + jj + 1 ] + = s11; out [outRow1 + jj + 2 ] + = s12
233+ out [outRow2 + jj] + = s20; out [outRow2 + jj + 1 ] + = s21; out [outRow2 + jj + 2 ] + = s22
234+ out [outRow3 + jj] + = s30; out [outRow3 + jj + 1 ] + = s31; out [outRow3 + jj + 2 ] + = s32
235+
236+ jj + = 3
237+ }
238+ ii + = 4
239+ }
240+ }
241+
242+ /* * 2 × 2 microkernel: 4 accumulators, 2 A loads + 2 B loads + 4 FMAs per step. */
243+ private fun gemm2x2 (
244+ a : FloatArray , aOffset : Int , aStride : Int ,
245+ bt : FloatArray , btStride : Int ,
246+ out : FloatArray , outOffset : Int , outStride : Int ,
247+ m0 : Int , m1 : Int , n0 : Int , n1 : Int ,
248+ kStart : Int , kLen : Int ,
249+ ) {
250+ val step = species.length()
251+ val loopBound = species.loopBound(kLen)
252+ var ii = m0
253+ while (ii < m1) {
254+ val a0Base = aOffset + ii * aStride + kStart
255+ val a1Base = a0Base + aStride
256+ val outRow0 = outOffset + ii * outStride
257+ val outRow1 = outRow0 + outStride
258+ var jj = n0
259+ while (jj < n1) {
260+ val b0Base = jj * btStride + kStart
261+ val b1Base = b0Base + btStride
262+
263+ var c00 = FloatVector .zero(species); var c01 = FloatVector .zero(species)
264+ var c10 = FloatVector .zero(species); var c11 = FloatVector .zero(species)
265+
266+ var idx = 0
267+ while (idx < loopBound) {
268+ val va0 = FloatVector .fromArray(species, a, a0Base + idx)
269+ val va1 = FloatVector .fromArray(species, a, a1Base + idx)
270+ val vb0 = FloatVector .fromArray(species, bt, b0Base + idx)
271+ val vb1 = FloatVector .fromArray(species, bt, b1Base + idx)
272+ c00 = va0.fma(vb0, c00); c10 = va1.fma(vb0, c10)
273+ c01 = va0.fma(vb1, c01); c11 = va1.fma(vb1, c11)
274+ idx + = step
275+ }
276+
277+ var s00 = c00.reduceLanes(VectorOperators .ADD ); var s01 = c01.reduceLanes(VectorOperators .ADD )
278+ var s10 = c10.reduceLanes(VectorOperators .ADD ); var s11 = c11.reduceLanes(VectorOperators .ADD )
279+
280+ while (idx < kLen) {
281+ val av0 = a[a0Base + idx]; val av1 = a[a1Base + idx]
282+ val bv0 = bt[b0Base + idx]; val bv1 = bt[b1Base + idx]
283+ s00 + = av0 * bv0; s10 + = av1 * bv0
284+ s01 + = av0 * bv1; s11 + = av1 * bv1
285+ idx++
286+ }
287+
288+ out [outRow0 + jj] + = s00; out [outRow0 + jj + 1 ] + = s01
289+ out [outRow1 + jj] + = s10; out [outRow1 + jj + 1 ] + = s11
290+
291+ jj + = 2
292+ }
293+ ii + = 2
294+ }
295+ }
296+
297+ /* * 2 × 1 microkernel: 2 accumulators, 2 A loads + 1 B load + 2 FMAs per step. */
298+ private fun gemm2x1 (
299+ a : FloatArray , aOffset : Int , aStride : Int ,
300+ bt : FloatArray , btStride : Int ,
301+ out : FloatArray , outOffset : Int , outStride : Int ,
302+ m0 : Int , m1 : Int , n0 : Int , n1 : Int ,
303+ kStart : Int , kLen : Int ,
304+ ) {
305+ val step = species.length()
306+ val loopBound = species.loopBound(kLen)
307+ var ii = m0
308+ while (ii < m1) {
309+ val a0Base = aOffset + ii * aStride + kStart
310+ val a1Base = a0Base + aStride
311+ val outRow0 = outOffset + ii * outStride
312+ val outRow1 = outRow0 + outStride
313+ for (jj in n0 until n1) {
314+ val b0Base = jj * btStride + kStart
315+
316+ var c0 = FloatVector .zero(species)
317+ var c1 = FloatVector .zero(species)
318+
319+ var idx = 0
320+ while (idx < loopBound) {
321+ val va0 = FloatVector .fromArray(species, a, a0Base + idx)
322+ val va1 = FloatVector .fromArray(species, a, a1Base + idx)
323+ val vb = FloatVector .fromArray(species, bt, b0Base + idx)
324+ c0 = va0.fma(vb, c0); c1 = va1.fma(vb, c1)
325+ idx + = step
326+ }
327+
328+ var s0 = c0.reduceLanes(VectorOperators .ADD )
329+ var s1 = c1.reduceLanes(VectorOperators .ADD )
330+
331+ while (idx < kLen) {
332+ val bv = bt[b0Base + idx]
333+ s0 + = a[a0Base + idx] * bv
334+ s1 + = a[a1Base + idx] * bv
335+ idx++
336+ }
337+
338+ out [outRow0 + jj] + = s0
339+ out [outRow1 + jj] + = s1
340+ }
341+ ii + = 2
342+ }
343+ }
344+
345+ /* * 1 × 2 microkernel: 2 accumulators, 1 A load + 2 B loads + 2 FMAs per step. */
346+ private fun gemm1x2 (
347+ a : FloatArray , aOffset : Int , aStride : Int ,
348+ bt : FloatArray , btStride : Int ,
349+ out : FloatArray , outOffset : Int , outStride : Int ,
350+ m0 : Int , m1 : Int , n0 : Int , n1 : Int ,
351+ kStart : Int , kLen : Int ,
352+ ) {
353+ val step = species.length()
354+ val loopBound = species.loopBound(kLen)
355+ for (ii in m0 until m1) {
356+ val aBase = aOffset + ii * aStride + kStart
357+ val outRow = outOffset + ii * outStride
358+ var jj = n0
359+ while (jj < n1) {
360+ val b0Base = jj * btStride + kStart
361+ val b1Base = b0Base + btStride
362+
363+ var c0 = FloatVector .zero(species)
364+ var c1 = FloatVector .zero(species)
365+
366+ var idx = 0
367+ while (idx < loopBound) {
368+ val va = FloatVector .fromArray(species, a, aBase + idx)
369+ val vb0 = FloatVector .fromArray(species, bt, b0Base + idx)
370+ val vb1 = FloatVector .fromArray(species, bt, b1Base + idx)
371+ c0 = va.fma(vb0, c0); c1 = va.fma(vb1, c1)
372+ idx + = step
373+ }
374+
375+ var s0 = c0.reduceLanes(VectorOperators .ADD )
376+ var s1 = c1.reduceLanes(VectorOperators .ADD )
377+
378+ while (idx < kLen) {
379+ val av = a[aBase + idx]
380+ s0 + = av * bt[b0Base + idx]
381+ s1 + = av * bt[b1Base + idx]
382+ idx++
383+ }
384+
385+ out [outRow + jj] + = s0
386+ out [outRow + jj + 1 ] + = s1
387+
388+ jj + = 2
389+ }
390+ }
391+ }
392+
393+ /* * 1 × 1 microkernel: single-cell fallback. Equivalent to the pre-change inner loop. */
394+ private fun gemm1x1 (
395+ a : FloatArray , aOffset : Int , aStride : Int ,
396+ bt : FloatArray , btStride : Int ,
397+ out : FloatArray , outOffset : Int , outStride : Int ,
398+ m0 : Int , m1 : Int , n0 : Int , n1 : Int ,
399+ kStart : Int , kLen : Int ,
400+ ) {
401+ val step = species.length()
402+ val loopBound = species.loopBound(kLen)
403+ for (ii in m0 until m1) {
404+ val aBase = aOffset + ii * aStride + kStart
405+ val outRow = outOffset + ii * outStride
406+ for (jj in n0 until n1) {
407+ val bBase = jj * btStride + kStart
408+ var acc = FloatVector .zero(species)
409+ var idx = 0
410+ while (idx < loopBound) {
411+ val va = FloatVector .fromArray(species, a, aBase + idx)
412+ val vb = FloatVector .fromArray(species, bt, bBase + idx)
413+ acc = va.fma(vb, acc)
414+ idx + = step
415+ }
416+ var sum = acc.reduceLanes(VectorOperators .ADD )
417+ while (idx < kLen) {
418+ sum + = a[aBase + idx] * bt[bBase + idx]
419+ idx++
420+ }
421+ out [outRow + jj] + = sum
422+ }
423+ }
424+ }
112425}
0 commit comments