Skip to content

Commit 3313647

Browse files
michalharakalclaude
andcommitted
feat(kernel): add Panama Vector FP32 matmul provider (priority 50)
Implements `PanamaVectorMatmulKernel` (jdk.incubator.vector, FloatVector + fma + reduceLanes) and `PanamaVectorKernelProvider` against the kernel SPI from PR #554. Picks up automatically over `ScalarKernelProvider` once registered, and respects the existing `-Dskainet.cpu.vector.enabled=false` kill switch. Closes the M5 "Panama-first" half of the JVM perf milestone plan. Routing `DefaultCpuOpsJvm.matmul` through the SPI and adding a ServiceLoader-based auto-registration are deferred to follow-ups so this PR stays focused on the kernel itself. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a453045 commit 3313647

4 files changed

Lines changed: 352 additions & 0 deletions

File tree

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package sk.ainet.exec.kernel
2+
3+
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
4+
import sk.ainet.backend.api.kernel.KernelProvider
5+
import sk.ainet.exec.tensor.ops.JvmCpuBackendConfig
6+
7+
/**
8+
* JVM Vector API (`jdk.incubator.vector`) [KernelProvider]. Available
9+
* when the runtime is JDK 21+, the incubator module is loaded
10+
* (`--add-modules jdk.incubator.vector`), and the
11+
* `skainet.cpu.vector.enabled` kill switch hasn't been flipped to
12+
* `false`.
13+
*
14+
* Priority is `50` — above [ScalarKernelProvider] (`0`) and below a
15+
* future hand-tuned native provider (`100`). Concrete kernels are
16+
* exposed via the per-kernel accessors; today only [matmulFp32] is
17+
* specialized — other accessors fall back to `null` so callers can
18+
* cascade to a lower-priority provider.
19+
*
20+
* Registration is **manual** (per the kernel-SPI contract today): the
21+
* runtime that wants this provider must call
22+
* `KernelRegistry.register(PanamaVectorKernelProvider)` at startup.
23+
* Auto-registration via `ServiceLoader` will be layered on once a
24+
* second concrete JVM provider exists.
25+
*/
26+
public object PanamaVectorKernelProvider : KernelProvider {
27+
override val name: String = "panama-vector"
28+
override val priority: Int = 50
29+
30+
private val cachedAvailable: Boolean by lazy {
31+
isJdk21Plus() && isVectorApiClassLoaded()
32+
}
33+
34+
override fun isAvailable(): Boolean =
35+
cachedAvailable && JvmCpuBackendConfig.vectorEnabled
36+
37+
override fun matmulFp32(): Fp32MatmulKernel? =
38+
if (isAvailable()) PanamaVectorMatmulKernel else null
39+
40+
private fun isVectorApiClassLoaded(): Boolean = runCatching {
41+
Class.forName("jdk.incubator.vector.FloatVector")
42+
Class.forName("jdk.incubator.vector.VectorSpecies")
43+
true
44+
}.getOrElse { false }
45+
46+
private fun isJdk21Plus(): Boolean {
47+
val runtimeFeature = runCatching {
48+
val runtimeClass = Class.forName("java.lang.Runtime")
49+
val versionMethod = runtimeClass.getMethod("version")
50+
val versionObj = versionMethod.invoke(Runtime.getRuntime())
51+
val featureMethod = versionObj.javaClass.getMethod("feature")
52+
featureMethod.invoke(versionObj) as Int
53+
}.getOrNull()
54+
if (runtimeFeature != null) return runtimeFeature >= 21
55+
56+
val spec = System.getProperty("java.specification.version") ?: return false
57+
return spec.toIntOrNull()?.let { it >= 21 } ?: run {
58+
val major = spec.split('.', '-').firstOrNull()?.toIntOrNull() ?: return@run false
59+
major >= 21
60+
}
61+
}
62+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package sk.ainet.exec.kernel
2+
3+
import jdk.incubator.vector.FloatVector
4+
import jdk.incubator.vector.VectorOperators
5+
import jdk.incubator.vector.VectorSpecies
6+
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
7+
8+
/**
9+
* SIMD reference [Fp32MatmulKernel] implemented on the JDK Vector API
10+
* (JEP 338+, `jdk.incubator.vector`). Produces results that match
11+
* [ScalarMatmulKernel] within FP-rounding tolerance.
12+
*
13+
* Strategy:
14+
* - Pack `B` into a transposed buffer `bt` of shape `(n, k)` so the
15+
* inner reduction streams contiguously over `k` for both operands —
16+
* `a[i, kk]` walks one row of `A` and `bt[j, kk]` walks one row of
17+
* the packed transpose.
18+
* - Inner loop is a vector-width FMA accumulator (`v.fma(w, acc)`),
19+
* reduced once per `(i, j)` pair via `reduceLanes(ADD)`.
20+
* - Tail elements that don't fill a vector lane are handled in scalar.
21+
*
22+
* The B-pack is `O(n * k)` floats per call; that's cheap relative to
23+
* the `O(m * n * k)` FLOPs but still allocates each invocation. A
24+
* scratch-pool integration is out of scope for this kernel and lives
25+
* one layer up (see `ScratchPool` SPI in `skainet-lang-core`).
26+
*
27+
* Caller contract is identical to [Fp32MatmulKernel]: strides are in
28+
* floats, `out` is fully overwritten in the `m × n` block, and `k == 0`
29+
* zeros the output block.
30+
*/
31+
public object PanamaVectorMatmulKernel : Fp32MatmulKernel {
32+
private val species: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED
33+
34+
override fun matmul(
35+
a: FloatArray, aOffset: Int, aStride: Int,
36+
b: FloatArray, bOffset: Int, bStride: Int,
37+
out: FloatArray, outOffset: Int, outStride: Int,
38+
m: Int, n: Int, k: Int,
39+
) {
40+
require(m >= 0 && n >= 0 && k >= 0) {
41+
"PanamaVectorMatmulKernel: m, n, k must be non-negative; got m=$m n=$n k=$k"
42+
}
43+
if (m == 0 || n == 0) return
44+
if (k == 0) {
45+
for (i in 0 until m) {
46+
val rowOff = outOffset + i * outStride
47+
for (j in 0 until n) out[rowOff + j] = 0f
48+
}
49+
return
50+
}
51+
52+
// Pack B^T: bt[j, kk] = b[kk, j].
53+
val bt = FloatArray(n * k)
54+
for (kk in 0 until k) {
55+
val src = bOffset + kk * bStride
56+
for (j in 0 until n) {
57+
bt[j * k + kk] = b[src + j]
58+
}
59+
}
60+
61+
val step = species.length()
62+
val loopBound = species.loopBound(k)
63+
64+
for (i in 0 until m) {
65+
val aRow = aOffset + i * aStride
66+
val outRow = outOffset + i * outStride
67+
for (j in 0 until n) {
68+
val btRow = j * k
69+
var acc = FloatVector.zero(species)
70+
var idx = 0
71+
while (idx < loopBound) {
72+
val va = FloatVector.fromArray(species, a, aRow + idx)
73+
val vb = FloatVector.fromArray(species, bt, btRow + idx)
74+
acc = va.fma(vb, acc)
75+
idx += step
76+
}
77+
var sum = acc.reduceLanes(VectorOperators.ADD)
78+
while (idx < k) {
79+
sum += a[aRow + idx] * bt[btRow + idx]
80+
idx++
81+
}
82+
out[outRow + j] = sum
83+
}
84+
}
85+
}
86+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package sk.ainet.exec.kernel
2+
3+
import kotlin.test.AfterTest
4+
import kotlin.test.BeforeTest
5+
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertSame
8+
import kotlin.test.assertTrue
9+
import sk.ainet.backend.api.kernel.KernelRegistry
10+
11+
class PanamaVectorKernelProviderTest {
12+
13+
@BeforeTest
14+
fun setUp() = KernelRegistry.clearForTesting()
15+
16+
@AfterTest
17+
fun tearDown() = KernelRegistry.clearForTesting()
18+
19+
@Test
20+
fun providerHasExpectedNameAndPriority() {
21+
assertEquals("panama-vector", PanamaVectorKernelProvider.name)
22+
assertEquals(50, PanamaVectorKernelProvider.priority)
23+
}
24+
25+
@Test
26+
fun isAvailableOnTestJdk() {
27+
// The cpu-backend test suite runs on JDK 21+ with the incubator
28+
// module on the module path (see jvm-cpu-jmh build script and the
29+
// project's JDK requirement). Vector should be available here.
30+
assertTrue(
31+
PanamaVectorKernelProvider.isAvailable(),
32+
"expected Panama provider to be available on the test JDK",
33+
)
34+
}
35+
36+
@Test
37+
fun matmulFp32IsTheVectorKernelWhenAvailable() {
38+
assertSame(PanamaVectorMatmulKernel, PanamaVectorKernelProvider.matmulFp32())
39+
}
40+
41+
@Test
42+
fun beatsScalarInRegistryWhenBothRegistered() {
43+
KernelRegistry.register(ScalarKernelProvider)
44+
KernelRegistry.register(PanamaVectorKernelProvider)
45+
// Higher priority wins.
46+
assertSame(PanamaVectorKernelProvider, KernelRegistry.bestAvailable())
47+
assertEquals(
48+
listOf("panama-vector", "scalar"),
49+
KernelRegistry.availableNames(),
50+
)
51+
}
52+
53+
@Test
54+
fun killSwitchDisablesProvider() {
55+
val key = "skainet.cpu.vector.enabled"
56+
val previous = System.getProperty(key)
57+
try {
58+
System.setProperty(key, "false")
59+
assertEquals(false, PanamaVectorKernelProvider.isAvailable())
60+
assertEquals(null, PanamaVectorKernelProvider.matmulFp32())
61+
} finally {
62+
if (previous == null) System.clearProperty(key)
63+
else System.setProperty(key, previous)
64+
}
65+
}
66+
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package sk.ainet.exec.kernel
2+
3+
import kotlin.math.abs
4+
import kotlin.random.Random
5+
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertFailsWith
8+
import kotlin.test.assertTrue
9+
10+
/**
11+
* Parity tests for [PanamaVectorMatmulKernel]. Every case runs the
12+
* Panama kernel and the [ScalarMatmulKernel] reference on the same
13+
* inputs and asserts the outputs agree within FP-rounding tolerance
14+
* (FMA + reordered reduction can differ from a left-to-right scalar
15+
* sum at the last few ULP).
16+
*
17+
* Tolerance scales with the contraction dimension `k`: each summand
18+
* carries up to ~`eps * |a|*|b|` rounding error, and we accumulate `k`
19+
* of them. `1e-5 * k` is comfortable for the inputs used here
20+
* (clamped to `[-0.5, 0.5]`).
21+
*/
22+
class PanamaVectorMatmulKernelTest {
23+
24+
private fun assertParity(
25+
m: Int, n: Int, k: Int,
26+
a: FloatArray, aOffset: Int, aStride: Int,
27+
b: FloatArray, bOffset: Int, bStride: Int,
28+
outStride: Int,
29+
) {
30+
val outScalar = FloatArray(m * outStride)
31+
val outPanama = FloatArray(m * outStride)
32+
ScalarMatmulKernel.matmul(
33+
a, aOffset, aStride,
34+
b, bOffset, bStride,
35+
outScalar, 0, outStride,
36+
m, n, k,
37+
)
38+
PanamaVectorMatmulKernel.matmul(
39+
a, aOffset, aStride,
40+
b, bOffset, bStride,
41+
outPanama, 0, outStride,
42+
m, n, k,
43+
)
44+
val tol = (1e-5f * k.coerceAtLeast(1)).coerceAtLeast(1e-5f)
45+
assertEquals(outScalar.size, outPanama.size, "length mismatch")
46+
for (i in outScalar.indices) {
47+
val diff = abs(outScalar[i] - outPanama[i])
48+
assertTrue(
49+
diff <= tol,
50+
"mismatch at $i: scalar=${outScalar[i]} panama=${outPanama[i]} diff=$diff tol=$tol",
51+
)
52+
}
53+
}
54+
55+
@Test
56+
fun small_2x3x4_contiguous_matches_scalar() {
57+
val a = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f) // [2, 4]
58+
val b = FloatArray(4 * 3) { it.toFloat() } // [4, 3]
59+
assertParity(m = 2, n = 3, k = 4, a = a, aOffset = 0, aStride = 4, b = b, bOffset = 0, bStride = 3, outStride = 3)
60+
}
61+
62+
@Test
63+
fun random_8x16x32_matches_scalar() {
64+
val rng = Random(42)
65+
val a = FloatArray(8 * 32) { rng.nextFloat() - 0.5f }
66+
val b = FloatArray(32 * 16) { rng.nextFloat() - 0.5f }
67+
assertParity(m = 8, n = 16, k = 32, a = a, aOffset = 0, aStride = 32, b = b, bOffset = 0, bStride = 16, outStride = 16)
68+
}
69+
70+
@Test
71+
fun non_aligned_k_exercises_tail_loop() {
72+
// k = 23 is not a multiple of any common vector lane count (4, 8, 16),
73+
// so this forces the scalar tail loop to run.
74+
val rng = Random(1234)
75+
val m = 5; val n = 7; val k = 23
76+
val a = FloatArray(m * k) { rng.nextFloat() - 0.5f }
77+
val b = FloatArray(k * n) { rng.nextFloat() - 0.5f }
78+
assertParity(m = m, n = n, k = k, a = a, aOffset = 0, aStride = k, b = b, bOffset = 0, bStride = n, outStride = n)
79+
}
80+
81+
@Test
82+
fun strided_a_sub_block_matches_scalar() {
83+
// Parent A is [4, 8]; take rows 1..2 as a 2×8 sub-block.
84+
val parentA = FloatArray(4 * 8) { it.toFloat() }
85+
val b = FloatArray(8 * 3) { (it + 1).toFloat() }
86+
assertParity(
87+
m = 2, n = 3, k = 8,
88+
a = parentA, aOffset = 1 * 8, aStride = 8,
89+
b = b, bOffset = 0, bStride = 3,
90+
outStride = 3,
91+
)
92+
}
93+
94+
@Test
95+
fun large_irregular_31x17x23_matches_scalar() {
96+
val rng = Random(7)
97+
val m = 31; val n = 17; val k = 23
98+
val a = FloatArray(m * k) { rng.nextFloat() - 0.5f }
99+
val b = FloatArray(k * n) { rng.nextFloat() - 0.5f }
100+
assertParity(m = m, n = n, k = k, a = a, aOffset = 0, aStride = k, b = b, bOffset = 0, bStride = n, outStride = n)
101+
}
102+
103+
@Test
104+
fun zero_m_or_n_no_op() {
105+
val out = FloatArray(5) { 7f }
106+
PanamaVectorMatmulKernel.matmul(
107+
FloatArray(0), 0, 0,
108+
FloatArray(0), 0, 0,
109+
out, 0, 0,
110+
m = 0, n = 5, k = 0,
111+
)
112+
for (v in out) assertEquals(7f, v, "out should be unchanged when m == 0")
113+
}
114+
115+
@Test
116+
fun zero_k_zeros_output() {
117+
val out = FloatArray(2 * 3) { 9f }
118+
PanamaVectorMatmulKernel.matmul(
119+
FloatArray(0), 0, 0,
120+
FloatArray(0), 0, 0,
121+
out, 0, 3,
122+
m = 2, n = 3, k = 0,
123+
)
124+
for (v in out) assertEquals(0f, v, "out block should be zeroed when k == 0")
125+
}
126+
127+
@Test
128+
fun rejects_negative_dimensions() {
129+
assertFailsWith<IllegalArgumentException> {
130+
PanamaVectorMatmulKernel.matmul(
131+
FloatArray(0), 0, 0,
132+
FloatArray(0), 0, 0,
133+
FloatArray(0), 0, 0,
134+
m = -1, n = 1, k = 1,
135+
)
136+
}
137+
}
138+
}

0 commit comments

Comments
 (0)