Skip to content

Commit ee22828

Browse files
michalharakal and claude
committed
perf(kernel): cache-block PanamaVectorMatmulKernel (8x8x128 tiles)
Ports the (m, n, k)-tile blocking pattern from JvmVectorKernels.matmulFloatBlocked into the SPI kernel: 8x8 output tiles, 128-wide K-stripes. Output is zeroed once up front and the K-tile loop accumulates via `+=`, which keeps the contract "fully overwrite the m x n block" intact and avoids the gnarly "init only on first tile" gating in the original blocked kernel. Closes the perf gap that #558 flagged between the SPI kernel and the existing production blocked path. After this change the SPI kernel matches or beats the production path within JMH noise — routing DefaultCpuOpsJvm.matmul through KernelRegistry won't show a regression any more. KernelMatmulBench (JDK 21.0.10, M-series macOS): size scalar panama speedup prior panama (simple) 256 9.77ms 1.13ms 8.61x 1.36ms (-16%) 512 81.55ms 9.47ms 8.62x 13.62ms (-30%) 1024 865.54ms 79.88ms 10.83x 118.24ms (-32%) vs production MatmulBench (vector=true, blas=false) same run: size SPI tiled production blocked delta 256 1.13ms 1.24ms SPI 8.5% faster 512 9.47ms 10.38ms SPI 8.8% faster 1024 79.88ms 78.32ms SPI 2% slower (within noise) Existing parity tests (PanamaVectorMatmulKernelTest, including the 31x17x23 randomized case that exercises partial tiles in all three dims) pass unchanged within the 1e-5*k tolerance. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b3b0c05 commit ee22828

1 file changed

Lines changed: 59 additions & 33 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorMatmulKernel.kt

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,20 @@ import sk.ainet.backend.api.kernel.Fp32MatmulKernel
1212
*
1313
* Strategy:
1414
* - Pack `B` into a transposed buffer `bt` of shape `(n, k)` so the
15-
* inner reduction streams contiguously over `k` for both operands —
16-
* `a[i, kk]` walks one row of `A` and `bt[j, kk]` walks one row of
17-
* the packed transpose.
18-
* - Inner loop is a vector-width FMA accumulator (`v.fma(w, acc)`),
19-
* reduced once per `(i, j)` pair via `reduceLanes(ADD)`.
20-
* - Tail elements that don't fill a vector lane are handled in scalar.
15+
* inner reduction streams contiguously over `k` for both operands.
16+
* - Cache-block the `(m, n, k)` iteration space with tiles
17+
* ([TILE_M], [TILE_N], [TILE_K]). Default 8×8×128 keeps a working
18+
* set well under L1 — eight A rows × 128 floats + eight Bᵀ rows ×
19+
* 128 floats ≈ 8 KB, within typical 32 KB L1.
20+
* - Inner reduction is a vector-width FMA accumulator
21+
* (`v.fma(w, acc)`), reduced via `reduceLanes(ADD)` once per
22+
* `(i, j)` cell per K-tile. Tail elements that don't fill a vector
23+
* lane are handled in scalar.
24+
* - Output is zeroed once up front; per-tile work accumulates via `+=`
25+
* so the K-loop can split across multiple tiles cleanly.
2126
*
22-
* The B-pack is `O(n * k)` floats per call; that's cheap relative to
23-
* the `O(m * n * k)` FLOPs but still allocates each invocation. A
27+
* The B-pack is `O(n * k)` floats per call; cheap relative to the
28+
* `O(m * n * k)` FLOPs but still allocates each invocation. A
2429
* scratch-pool integration is out of scope for this kernel and lives
2530
* one layer up (see `ScratchPool` SPI in `skainet-lang-core`).
2631
*
@@ -31,6 +36,10 @@ import sk.ainet.backend.api.kernel.Fp32MatmulKernel
3136
public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
3237
private val species: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED
3338

39+
private const val TILE_M = 8
40+
private const val TILE_N = 8
41+
private const val TILE_K = 128
42+
3443
override fun matmul(
3544
a: FloatArray, aOffset: Int, aStride: Int,
3645
b: FloatArray, bOffset: Int, bStride: Int,
@@ -41,13 +50,14 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
4150
"PanamaVectorMatmulKernel: m, n, k must be non-negative; got m=$m n=$n k=$k"
4251
}
4352
if (m == 0 || n == 0) return
44-
if (k == 0) {
45-
for (i in 0 until m) {
46-
val rowOff = outOffset + i * outStride
47-
for (j in 0 until n) out[rowOff + j] = 0f
48-
}
49-
return
53+
// Zero the m×n output block once. The K-tile loop accumulates
54+
// via `+=`, so the contract "fully overwrite the output block"
55+
// is preserved even when k == 0 (no tile loop runs).
56+
for (i in 0 until m) {
57+
val rowOff = outOffset + i * outStride
58+
for (j in 0 until n) out[rowOff + j] = 0f
5059
}
60+
if (k == 0) return
5161

5262
// Pack B^T: bt[j, kk] = b[kk, j].
5363
val bt = FloatArray(n * k)
@@ -59,28 +69,44 @@ public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
5969
}
6070

6171
val step = species.length()
62-
val loopBound = species.loopBound(k)
6372

64-
for (i in 0 until m) {
65-
val aRow = aOffset + i * aStride
66-
val outRow = outOffset + i * outStride
67-
for (j in 0 until n) {
68-
val btRow = j * k
69-
var acc = FloatVector.zero(species)
70-
var idx = 0
71-
while (idx < loopBound) {
72-
val va = FloatVector.fromArray(species, a, aRow + idx)
73-
val vb = FloatVector.fromArray(species, bt, btRow + idx)
74-
acc = va.fma(vb, acc)
75-
idx += step
76-
}
77-
var sum = acc.reduceLanes(VectorOperators.ADD)
78-
while (idx < k) {
79-
sum += a[aRow + idx] * bt[btRow + idx]
80-
idx++
73+
var mTile = 0
74+
while (mTile < m) {
75+
val mEnd = minOf(mTile + TILE_M, m)
76+
var nTile = 0
77+
while (nTile < n) {
78+
val nEnd = minOf(nTile + TILE_N, n)
79+
var kTile = 0
80+
while (kTile < k) {
81+
val kEnd = minOf(kTile + TILE_K, k)
82+
val kLen = kEnd - kTile
83+
val loopBound = species.loopBound(kLen)
84+
for (i in mTile until mEnd) {
85+
val aRowBase = aOffset + i * aStride + kTile
86+
val outRowBase = outOffset + i * outStride
87+
for (j in nTile until nEnd) {
88+
val btRowBase = j * k + kTile
89+
var acc = FloatVector.zero(species)
90+
var idx = 0
91+
while (idx < loopBound) {
92+
val va = FloatVector.fromArray(species, a, aRowBase + idx)
93+
val vb = FloatVector.fromArray(species, bt, btRowBase + idx)
94+
acc = va.fma(vb, acc)
95+
idx += step
96+
}
97+
var sum = acc.reduceLanes(VectorOperators.ADD)
98+
while (idx < kLen) {
99+
sum += a[aRowBase + idx] * bt[btRowBase + idx]
100+
idx++
101+
}
102+
out[outRowBase + j] += sum
103+
}
104+
}
105+
kTile = kEnd
81106
}
82-
out[outRow + j] = sum
107+
nTile = nEnd
83108
}
109+
mTile = mEnd
84110
}
85111
}
86112
}

0 commit comments

Comments
 (0)