Skip to content

Commit 87589af

Browse files
michalharakal and claude committed
Tile-microkernel dispatch in PanamaVectorMatmulKernel (~1.7x speedup)
Port the `mnpack` recursive tile dispatch from tinyBLAS (Justine Tunney / llamafile `sgemm.cpp`) into the FP32 Panama Vector kernel. Each `(TILE_M x TILE_N)` sub-tile now dispatches into AVX2-friendly RM x RN microkernels — `gemm4x3`, `gemm2x2`, `gemm2x1`, `gemm1x2`, `gemm1x1` — where every A-row load amortizes across RN columns and every B-column load across RM rows. The largest microkernel keeps 12 FloatVector accumulators live; smaller microkernels cover residual rows and columns that don't divide evenly.

No SPI change. B pre-pack and TILE_M/TILE_N/TILE_K outer blocking unchanged. Output accumulation semantics unchanged. ScalarMatmulKernel untouched. Pure-JVM, no native code.

Measured on the engine benchmark suite (Intel i7-9750H, AVX2, JDK 21, 8 warmups + 5 measured runs, two independent runs):

  engine-kernel-matmul panama (1024^3 FP32): 14.06 -> 23.0-23.9 GFLOPS (~1.65x, CoV ~1.5%)
  engine-fp32-gemm panama (1024^3 via ctx.ops): 11.62 -> 22.3-23.0 (~1.95x, CoV ~1.5%)
  engine-kernel-matmul scalar (sanity): 1.05 -> 1.13-1.20 (unchanged)

23 GFLOPS on AVX2 single-core ~= 62% of the chip's theoretical FP32 peak, vs 38% before. Two-run reproducibility confirms the gain isn't a measurement fluke.

Parity tests extended with three new shapes (7x11x13, 17x19x255, 39x31x129) that cascade through every microkernel arm; all existing shapes still pass within max(1e-5*k, 1e-5f) tolerance.

Deferred to follow-up sessions per the implementation plan:
- Kahan-corrected FMA precision flag (SPI change)
- Cooperative ith/nth threading (SPI change)
- AVX-512 microkernel set (5x5, 5x4, ...) — needs an AVX-512 lane

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 748e92d commit 87589af

2 files changed

Lines changed: 374 additions & 26 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorMatmulKernel.kt

Lines changed: 339 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ import sk.ainet.backend.api.kernel.Fp32MatmulKernel
1717
* ([TILE_M], [TILE_N], [TILE_K]). Default 8×8×128 keeps a working
1818
* set well under L1 — eight A rows × 128 floats + eight Bᵀ rows ×
1919
* 128 floats ≈ 8 KB, within typical 32 KB L1.
20+
* - Within each (TILE_M × TILE_N) sub-tile, [mnpack] recursively
21+
* dispatches into `RM × RN` micro-kernels — `gemm4x3`, `gemm2x2`,
22+
* `gemm2x1`, `gemm1x2`, `gemm1x1`. Each micro-kernel keeps
23+
* `RM × RN` `FloatVector` accumulators in locals and amortizes
24+
* every A-row load across `RN` columns and every B-column load
25+
* across `RM` rows. This mirrors the tile-dispatch pattern from
26+
* tinyBLAS (`sgemm.cpp`, Justine Tunney / llamafile).
27+
* - On AVX2 the largest microkernel that fits inside 16 YMM registers
28+
* is `4 × 3` (12 accumulators + at most 4 A vectors + 1 B vector
29+
* live at once). Smaller microkernels cover residual rows and
30+
* columns that don't divide evenly into the larger tile shape.
2031
* - Inner reduction is a vector-width FMA accumulator
2132
* (`v.fma(w, acc)`), reduced via `reduceLanes(ADD)` once per
2233
* `(i, j)` cell per K-tile. Tail elements that don't fill a vector
@@ -59,7 +70,7 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
5970
}
6071
if (k == 0) return
6172

62-
// Pack B^T: bt[j, kk] = b[kk, j].
73+
// Pack B^T: bt[j, kk] = b[kk, j]. Row stride in bt is k.
6374
val bt = FloatArray(n * k)
6475
for (kk in 0 until k) {
6576
val src = bOffset + kk * bStride
@@ -68,8 +79,6 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
6879
}
6980
}
7081

71-
val step = species.length()
72-
7382
var mTile = 0
7483
while (mTile < m) {
7584
val mEnd = minOf(mTile + TILE_M, m)
@@ -79,34 +88,338 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
7988
var kTile = 0
8089
while (kTile < k) {
8190
val kEnd = minOf(kTile + TILE_K, k)
82-
val kLen = kEnd - kTile
83-
val loopBound = species.loopBound(kLen)
84-
for (i in mTile until mEnd) {
85-
val aRowBase = aOffset + i * aStride + kTile
86-
val outRowBase = outOffset + i * outStride
87-
for (j in nTile until nEnd) {
88-
val btRowBase = j * k + kTile
89-
var acc = FloatVector.zero(species)
90-
var idx = 0
91-
while (idx < loopBound) {
92-
val va = FloatVector.fromArray(species, a, aRowBase + idx)
93-
val vb = FloatVector.fromArray(species, bt, btRowBase + idx)
94-
acc = va.fma(vb, acc)
95-
idx += step
96-
}
97-
var sum = acc.reduceLanes(VectorOperators.ADD)
98-
while (idx < kLen) {
99-
sum += a[aRowBase + idx] * bt[btRowBase + idx]
100-
idx++
101-
}
102-
out[outRowBase + j] += sum
103-
}
104-
}
91+
mnpack(
92+
a, aOffset, aStride,
93+
bt, k,
94+
out, outOffset, outStride,
95+
mTile, mEnd, nTile, nEnd,
96+
kTile, kEnd - kTile,
97+
)
10598
kTile = kEnd
10699
}
107100
nTile = nEnd
108101
}
109102
mTile = mEnd
110103
}
111104
}
105+
106+
/**
 * Recursive (m, n) tile dispatch over the sub-tile `[m0, m1) × [n0, n1)`.
 *
 * The residual extent is clamped to at most 4 rows × 3 columns and used to
 * pick a microkernel shape `(mc, nc)`. That microkernel runs over the
 * aligned sub-rectangle `[m0, mp) × [n0, np)`, where `mp - m0` is the
 * largest multiple of `mc` and `np - n0` the largest multiple of `nc` that
 * fit. The function then recurses on the leftover rows `[mp, m1) × [n0, np)`
 * and the leftover columns `[m0, m1) × [np, n1)`, so the three regions
 * together cover every cell of the sub-tile exactly once. Mirrors the
 * tinyBLAS `mnpack` switch but uses only the AVX2-friendly microkernel set
 * (16 vector registers).
 *
 * @param a left operand; row `i` of the current K-slice starts at
 *   `aOffset + i * aStride + kStart`.
 * @param bt packed B-transpose; row `j` of the current K-slice starts at
 *   `j * btStride + kStart`.
 * @param out destination; cell `(i, j)` lives at
 *   `outOffset + i * outStride + j` and is accumulated into with `+=`.
 * @param m0 first row (inclusive) and [m1] last row (exclusive).
 * @param n0 first column (inclusive) and [n1] last column (exclusive).
 * @param kStart start of the reduction slice; [kLen] is its length.
 * @throws IllegalStateException if the clamped shape has no dispatch entry.
 *   This cannot happen while the clamps stay at 4 and 3; it guards against
 *   a future table edit silently skipping part of the product.
 */
private fun mnpack(
    a: FloatArray, aOffset: Int, aStride: Int,
    bt: FloatArray, btStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m0: Int, m1: Int, n0: Int, n1: Int,
    kStart: Int, kLen: Int,
) {
    if (m1 <= m0 || n1 <= n0) return

    val rows = m1 - m0
    val cols = n1 - n0
    val rm = minOf(rows, 4)
    val rn = minOf(cols, 3)

    // Shape table keyed by the clamped residual, packed as (rm << 4) | rn.
    val mc: Int
    val nc: Int
    when ((rm shl 4) or rn) {
        0x43 -> { mc = 4; nc = 3 }
        0x42, 0x33, 0x32, 0x23, 0x22 -> { mc = 2; nc = 2 }
        0x41, 0x31, 0x21 -> { mc = 2; nc = 1 }
        0x13, 0x12 -> { mc = 1; nc = 2 }
        0x11 -> { mc = 1; nc = 1 }
        // Fail loudly: returning here would leave these cells without their contribution.
        else -> error("mnpack: no microkernel for residual shape rm=$rm, rn=$rn")
    }

    // Aligned bounds, computed once and shared by the kernel call and the recursion.
    val mp = m0 + (rows / mc) * mc
    val np = n0 + (cols / nc) * nc

    when ((mc shl 4) or nc) {
        0x43 -> gemm4x3(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            m0, mp, n0, np, kStart, kLen,
        )
        0x22 -> gemm2x2(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            m0, mp, n0, np, kStart, kLen,
        )
        0x21 -> gemm2x1(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            m0, mp, n0, np, kStart, kLen,
        )
        0x12 -> gemm1x2(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            m0, mp, n0, np, kStart, kLen,
        )
        else -> gemm1x1(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            m0, mp, n0, np, kStart, kLen,
        )
    }

    // Leftover rows below the aligned block, restricted to the aligned columns.
    if (mp < m1) {
        mnpack(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            mp, m1, n0, np, kStart, kLen,
        )
    }
    // Leftover columns to the right, spanning every row of the sub-tile.
    if (np < n1) {
        mnpack(
            a, aOffset, aStride, bt, btStride, out, outOffset, outStride,
            m0, m1, np, n1, kStart, kLen,
        )
    }
}
163+
/**
 * 4 × 3 microkernel, the largest shape in the AVX2-oriented set.
 *
 * Each group of 4 rows × 3 columns keeps 12 vector accumulators in locals.
 * Every vector step loads 4 A vectors and 3 B vectors and issues 12 FMAs.
 * After the vector loop each accumulator is lane-summed, the scalar tail of
 * the reduction (the part of `kLen` past `species.loopBound(kLen)`) is
 * added, and the 12 sums are accumulated into `out` with `+=`.
 *
 * The caller must pass a row count `(m1 - m0)` that is a multiple of 4 and
 * a column count `(n1 - n0)` that is a multiple of 3.
 */
private fun gemm4x3(
    a: FloatArray, aOffset: Int, aStride: Int,
    bt: FloatArray, btStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m0: Int, m1: Int, n0: Int, n1: Int,
    kStart: Int, kLen: Int,
) {
    val lanes = species.length()
    val vectorEnd = species.loopBound(kLen)
    var row = m0
    while (row < m1) {
        // Start of the K-slice inside each of the four A rows.
        val aRow0 = aOffset + row * aStride + kStart
        val aRow1 = aRow0 + aStride
        val aRow2 = aRow1 + aStride
        val aRow3 = aRow2 + aStride
        // Start of each of the four destination rows.
        val dst0 = outOffset + row * outStride
        val dst1 = dst0 + outStride
        val dst2 = dst1 + outStride
        val dst3 = dst2 + outStride
        var col = n0
        while (col < n1) {
            // Start of the K-slice inside each of the three packed-B rows.
            val bRow0 = col * btStride + kStart
            val bRow1 = bRow0 + btStride
            val bRow2 = bRow1 + btStride

            // accRC accumulates (A row R) · (packed-B row C).
            var acc00 = FloatVector.zero(species)
            var acc01 = FloatVector.zero(species)
            var acc02 = FloatVector.zero(species)
            var acc10 = FloatVector.zero(species)
            var acc11 = FloatVector.zero(species)
            var acc12 = FloatVector.zero(species)
            var acc20 = FloatVector.zero(species)
            var acc21 = FloatVector.zero(species)
            var acc22 = FloatVector.zero(species)
            var acc30 = FloatVector.zero(species)
            var acc31 = FloatVector.zero(species)
            var acc32 = FloatVector.zero(species)

            var kk = 0
            while (kk < vectorEnd) {
                val x0 = FloatVector.fromArray(species, a, aRow0 + kk)
                val x1 = FloatVector.fromArray(species, a, aRow1 + kk)
                val x2 = FloatVector.fromArray(species, a, aRow2 + kk)
                val x3 = FloatVector.fromArray(species, a, aRow3 + kk)

                // One B vector at a time, reused across all four A rows.
                val y0 = FloatVector.fromArray(species, bt, bRow0 + kk)
                acc00 = x0.fma(y0, acc00)
                acc10 = x1.fma(y0, acc10)
                acc20 = x2.fma(y0, acc20)
                acc30 = x3.fma(y0, acc30)

                val y1 = FloatVector.fromArray(species, bt, bRow1 + kk)
                acc01 = x0.fma(y1, acc01)
                acc11 = x1.fma(y1, acc11)
                acc21 = x2.fma(y1, acc21)
                acc31 = x3.fma(y1, acc31)

                val y2 = FloatVector.fromArray(species, bt, bRow2 + kk)
                acc02 = x0.fma(y2, acc02)
                acc12 = x1.fma(y2, acc12)
                acc22 = x2.fma(y2, acc22)
                acc32 = x3.fma(y2, acc32)

                kk += lanes
            }

            // Horizontal sums of the vector accumulators.
            var sum00 = acc00.reduceLanes(VectorOperators.ADD)
            var sum01 = acc01.reduceLanes(VectorOperators.ADD)
            var sum02 = acc02.reduceLanes(VectorOperators.ADD)
            var sum10 = acc10.reduceLanes(VectorOperators.ADD)
            var sum11 = acc11.reduceLanes(VectorOperators.ADD)
            var sum12 = acc12.reduceLanes(VectorOperators.ADD)
            var sum20 = acc20.reduceLanes(VectorOperators.ADD)
            var sum21 = acc21.reduceLanes(VectorOperators.ADD)
            var sum22 = acc22.reduceLanes(VectorOperators.ADD)
            var sum30 = acc30.reduceLanes(VectorOperators.ADD)
            var sum31 = acc31.reduceLanes(VectorOperators.ADD)
            var sum32 = acc32.reduceLanes(VectorOperators.ADD)

            // Scalar tail: reduction elements that did not fill a whole vector.
            while (kk < kLen) {
                val p0 = a[aRow0 + kk]
                val p1 = a[aRow1 + kk]
                val p2 = a[aRow2 + kk]
                val p3 = a[aRow3 + kk]
                val q0 = bt[bRow0 + kk]
                val q1 = bt[bRow1 + kk]
                val q2 = bt[bRow2 + kk]
                sum00 += p0 * q0
                sum10 += p1 * q0
                sum20 += p2 * q0
                sum30 += p3 * q0
                sum01 += p0 * q1
                sum11 += p1 * q1
                sum21 += p2 * q1
                sum31 += p3 * q1
                sum02 += p0 * q2
                sum12 += p1 * q2
                sum22 += p2 * q2
                sum32 += p3 * q2
                kk++
            }

            // Accumulate into the destination, row by row.
            out[dst0 + col] += sum00
            out[dst0 + col + 1] += sum01
            out[dst0 + col + 2] += sum02
            out[dst1 + col] += sum10
            out[dst1 + col + 1] += sum11
            out[dst1 + col + 2] += sum12
            out[dst2 + col] += sum20
            out[dst2 + col + 1] += sum21
            out[dst2 + col + 2] += sum22
            out[dst3 + col] += sum30
            out[dst3 + col + 1] += sum31
            out[dst3 + col + 2] += sum32

            col += 3
        }
        row += 4
    }
}
241+
242+
/**
 * 2 × 2 microkernel: 4 vector accumulators; each vector step performs
 * 2 A loads, 2 B loads and 4 FMAs.
 *
 * Rows advance two at a time from `m0` and columns two at a time from `n0`.
 * Each of the four cells receives the lane-sum of its accumulator plus the
 * scalar tail of the reduction, added into `out` with `+=`.
 */
private fun gemm2x2(
    a: FloatArray, aOffset: Int, aStride: Int,
    bt: FloatArray, btStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m0: Int, m1: Int, n0: Int, n1: Int,
    kStart: Int, kLen: Int,
) {
    val lanes = species.length()
    val vectorEnd = species.loopBound(kLen)
    var row = m0
    while (row < m1) {
        val aRow0 = aOffset + row * aStride + kStart
        val aRow1 = aRow0 + aStride
        val dst0 = outOffset + row * outStride
        val dst1 = dst0 + outStride
        var col = n0
        while (col < n1) {
            val bRow0 = col * btStride + kStart
            val bRow1 = bRow0 + btStride

            // accRC accumulates (A row R) · (packed-B row C).
            var acc00 = FloatVector.zero(species)
            var acc01 = FloatVector.zero(species)
            var acc10 = FloatVector.zero(species)
            var acc11 = FloatVector.zero(species)

            var kk = 0
            while (kk < vectorEnd) {
                val x0 = FloatVector.fromArray(species, a, aRow0 + kk)
                val x1 = FloatVector.fromArray(species, a, aRow1 + kk)
                val y0 = FloatVector.fromArray(species, bt, bRow0 + kk)
                val y1 = FloatVector.fromArray(species, bt, bRow1 + kk)
                acc00 = x0.fma(y0, acc00)
                acc10 = x1.fma(y0, acc10)
                acc01 = x0.fma(y1, acc01)
                acc11 = x1.fma(y1, acc11)
                kk += lanes
            }

            var sum00 = acc00.reduceLanes(VectorOperators.ADD)
            var sum01 = acc01.reduceLanes(VectorOperators.ADD)
            var sum10 = acc10.reduceLanes(VectorOperators.ADD)
            var sum11 = acc11.reduceLanes(VectorOperators.ADD)

            // Scalar tail: reduction elements that did not fill a whole vector.
            while (kk < kLen) {
                val p0 = a[aRow0 + kk]
                val p1 = a[aRow1 + kk]
                val q0 = bt[bRow0 + kk]
                val q1 = bt[bRow1 + kk]
                sum00 += p0 * q0
                sum10 += p1 * q0
                sum01 += p0 * q1
                sum11 += p1 * q1
                kk++
            }

            out[dst0 + col] += sum00
            out[dst0 + col + 1] += sum01
            out[dst1 + col] += sum10
            out[dst1 + col + 1] += sum11

            col += 2
        }
        row += 2
    }
}
296+
297+
/**
 * 2 × 1 microkernel: 2 vector accumulators; each vector step performs
 * 2 A loads, 1 B load and 2 FMAs.
 *
 * Rows advance two at a time from `m0`; columns are visited one by one.
 * The single B vector is shared by both rows, and each cell's lane-sum plus
 * scalar tail is added into `out` with `+=`.
 */
private fun gemm2x1(
    a: FloatArray, aOffset: Int, aStride: Int,
    bt: FloatArray, btStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m0: Int, m1: Int, n0: Int, n1: Int,
    kStart: Int, kLen: Int,
) {
    val lanes = species.length()
    val vectorEnd = species.loopBound(kLen)
    var row = m0
    while (row < m1) {
        val aRow0 = aOffset + row * aStride + kStart
        val aRow1 = aRow0 + aStride
        val dst0 = outOffset + row * outStride
        val dst1 = dst0 + outStride
        var col = n0
        while (col < n1) {
            val bRow = col * btStride + kStart

            var accTop = FloatVector.zero(species)
            var accBottom = FloatVector.zero(species)

            var kk = 0
            while (kk < vectorEnd) {
                val x0 = FloatVector.fromArray(species, a, aRow0 + kk)
                val x1 = FloatVector.fromArray(species, a, aRow1 + kk)
                val y = FloatVector.fromArray(species, bt, bRow + kk)
                accTop = x0.fma(y, accTop)
                accBottom = x1.fma(y, accBottom)
                kk += lanes
            }

            var sumTop = accTop.reduceLanes(VectorOperators.ADD)
            var sumBottom = accBottom.reduceLanes(VectorOperators.ADD)

            // Scalar tail: reduction elements that did not fill a whole vector.
            while (kk < kLen) {
                val q = bt[bRow + kk]
                sumTop += a[aRow0 + kk] * q
                sumBottom += a[aRow1 + kk] * q
                kk++
            }

            out[dst0 + col] += sumTop
            out[dst1 + col] += sumBottom
            col++
        }
        row += 2
    }
}
344+
345+
/**
 * 1 × 2 microkernel: 2 vector accumulators; each vector step performs
 * 1 A load, 2 B loads and 2 FMAs.
 *
 * Rows are visited one by one; columns advance two at a time from `n0`.
 * The single A vector is shared by both columns, and each cell's lane-sum
 * plus scalar tail is added into `out` with `+=`.
 */
private fun gemm1x2(
    a: FloatArray, aOffset: Int, aStride: Int,
    bt: FloatArray, btStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m0: Int, m1: Int, n0: Int, n1: Int,
    kStart: Int, kLen: Int,
) {
    val lanes = species.length()
    val vectorEnd = species.loopBound(kLen)
    var row = m0
    while (row < m1) {
        val aRow = aOffset + row * aStride + kStart
        val dst = outOffset + row * outStride
        var col = n0
        while (col < n1) {
            val bRow0 = col * btStride + kStart
            val bRow1 = bRow0 + btStride

            var accLeft = FloatVector.zero(species)
            var accRight = FloatVector.zero(species)

            var kk = 0
            while (kk < vectorEnd) {
                val x = FloatVector.fromArray(species, a, aRow + kk)
                val y0 = FloatVector.fromArray(species, bt, bRow0 + kk)
                val y1 = FloatVector.fromArray(species, bt, bRow1 + kk)
                accLeft = x.fma(y0, accLeft)
                accRight = x.fma(y1, accRight)
                kk += lanes
            }

            var sumLeft = accLeft.reduceLanes(VectorOperators.ADD)
            var sumRight = accRight.reduceLanes(VectorOperators.ADD)

            // Scalar tail: reduction elements that did not fill a whole vector.
            while (kk < kLen) {
                val p = a[aRow + kk]
                sumLeft += p * bt[bRow0 + kk]
                sumRight += p * bt[bRow1 + kk]
                kk++
            }

            out[dst + col] += sumLeft
            out[dst + col + 1] += sumRight

            col += 2
        }
        row++
    }
}
392+
393+
/**
 * 1 × 1 microkernel: single-cell fallback with one vector accumulator.
 *
 * For every `(row, col)` cell in `[m0, m1) × [n0, n1)` it computes the dot
 * product of the A row and the packed-B row over the `kLen`-long reduction
 * slice starting at `kStart` (vector FMA loop, lane-sum, then scalar tail)
 * and adds the result into `out` with `+=`.
 */
private fun gemm1x1(
    a: FloatArray, aOffset: Int, aStride: Int,
    bt: FloatArray, btStride: Int,
    out: FloatArray, outOffset: Int, outStride: Int,
    m0: Int, m1: Int, n0: Int, n1: Int,
    kStart: Int, kLen: Int,
) {
    val lanes = species.length()
    val vectorEnd = species.loopBound(kLen)
    var row = m0
    while (row < m1) {
        val aRow = aOffset + row * aStride + kStart
        val dst = outOffset + row * outStride
        var col = n0
        while (col < n1) {
            val bRow = col * btStride + kStart
            var partial = FloatVector.zero(species)
            var kk = 0
            while (kk < vectorEnd) {
                val x = FloatVector.fromArray(species, a, aRow + kk)
                val y = FloatVector.fromArray(species, bt, bRow + kk)
                partial = x.fma(y, partial)
                kk += lanes
            }
            var dot = partial.reduceLanes(VectorOperators.ADD)
            // Scalar tail: reduction elements that did not fill a whole vector.
            while (kk < kLen) {
                dot += a[aRow + kk] * bt[bRow + kk]
                kk++
            }
            out[dst + col] += dot
            col++
        }
        row++
    }
}
112425
}

0 commit comments

Comments
 (0)