Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package sk.ainet.backend.api.kernel

/**
* FP32 matrix multiplication kernel: `C(m, n) = A(m, k) · B(k, n)` in
* row-major layout.
*
* This is a thin SPI between high-level tensor ops and the actual
* numeric kernel that does the FLOPs. It exists so a SIMD-accelerated
* `matmul` can be plugged in without re-implementing the rest of an
* op-level backend, and so a hand-written kernel can be tested against
* a scalar reference.
*
* Strides are in **floats** (not bytes) and let callers pass sub-blocks
* of larger arrays without copying. For a contiguous matrix of shape
* `(m, k)`, `aStride == k`. For a sub-block, `aStride` is the leading
* dimension of the *parent* matrix.
*
* Implementations must NOT mutate `a` or `b`. They MAY assume the
* arrays do not alias each other or `out`. Implementations MUST fully
* overwrite the `m × n` block of `out` they're responsible for —
* accumulator semantics are caller-controlled (e.g. zero `out` first if
* you want C = A·B; pre-fill `out` if you want C += A·B and the kernel
* is fused for that — no fused-accumulate kernel is in scope yet).
*/
public interface Fp32MatmulKernel {
/**
* @param a left operand `(m, k)`, row-major, with stride `aStride` along
* the leading (row) dimension.
* @param aOffset element offset into [a] where the (0, 0) entry lives.
* @param aStride distance in floats between consecutive rows of [a].
* For a contiguous matrix this equals `k`.
* @param b right operand `(k, n)`, row-major, with stride `bStride`.
* @param bOffset element offset into [b].
* @param bStride distance in floats between consecutive rows of [b].
* For a contiguous matrix this equals `n`.
* @param out output `(m, n)`, row-major, with stride `outStride`.
* @param outOffset element offset into [out].
* @param outStride distance in floats between consecutive rows of [out].
* For a contiguous matrix this equals `n`.
* @param m number of rows of A and C.
* @param n number of columns of B and C.
* @param k contraction dimension (cols of A == rows of B).
*/
public fun matmul(
a: FloatArray, aOffset: Int, aStride: Int,
b: FloatArray, bOffset: Int, bStride: Int,
out: FloatArray, outOffset: Int, outStride: Int,
m: Int, n: Int, k: Int
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package sk.ainet.backend.api.kernel

/**
* Provider for a family of related numeric kernels (matmul, SDPA, ...).
*
* A backend (Panama Vector, native FFM, IREE, Metal, ...) bundles its
* kernels behind a single provider so callers only need to know the
* top-level provider name. Providers self-report whether they are
* available on the current platform / runtime — e.g. a Panama Vector
* provider returns `false` from [isAvailable] on a JDK that doesn't
* have the incubator module loaded.
*
* Lookup rules:
* - Higher [priority] wins when multiple providers report
* [isAvailable] = `true`. Providers should rank themselves by
* expected performance: scalar ≈ 0, Panama Vector ≈ 50, hand-tuned
* native ≈ 100.
* - Each per-kernel accessor returns `null` when the provider does not
* carry that kernel, so callers can fall through to a lower-priority
* provider.
*/
public interface KernelProvider {
/** Stable, human-readable identifier. */
public val name: String

/**
* Relative ranking versus other providers. Higher = preferred when
* available. The scalar reference uses `0`; SIMD-accelerated
* providers should use a larger value.
*/
public val priority: Int

/**
* Reports whether this provider's kernels can run in the current
* process. Expensive checks (probing CPU features, loading native
* libraries) should be done once and cached.
*/
public fun isAvailable(): Boolean

/**
* FP32 matmul kernel exposed by this provider, or `null` if this
* provider does not specialize matmul.
*/
public fun matmulFp32(): Fp32MatmulKernel?
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package sk.ainet.backend.api.kernel

/**
* Process-wide registry of [KernelProvider] instances.
*
* Backends that ship a [KernelProvider] register it via [register] at
* load time. Callers that need a kernel ask [bestAvailable] (or
* [find] with a name) for the highest-priority provider that reports
* itself available, then pull the specific kernel they need from the
* provider's accessors.
*
* The registry is plain manual registration today — JVM auto-discovery
* via `java.util.ServiceLoader` can be layered on in a follow-up PR
* once a second concrete provider exists (Panama Vector). Callers that
* want a guaranteed scalar fallback can pin
* `sk.ainet.exec.kernel.ScalarKernelProvider` directly without going
* through the registry.
*
* Thread safety: [register] is not thread-safe. Call it during
* single-threaded startup or guard with your own lock.
*/
public object KernelRegistry {
private val providers: MutableList<KernelProvider> = mutableListOf()

/**
* Register a provider. Re-registering the same instance is a no-op.
*/
public fun register(provider: KernelProvider) {
if (providers.any { it === provider }) return
providers.add(provider)
providers.sortByDescending { it.priority }
}

/** All registered providers, sorted by priority descending. */
public fun providers(): List<KernelProvider> = providers.toList()

/** Find a provider by name (case-insensitive), or `null`. */
public fun find(name: String): KernelProvider? =
providers.firstOrNull { it.name.equals(name, ignoreCase = true) }

/**
* Highest-priority [isAvailable] provider, or `null` if none is
* registered or available. Callers that absolutely need a kernel
* should use the explicit scalar fallback instead.
*/
public fun bestAvailable(): KernelProvider? =
providers.firstOrNull { it.isAvailable() }

/** Names of all currently-available providers. */
public fun availableNames(): List<String> =
providers.filter { it.isAvailable() }.map { it.name }

/** Test/diagnostic helper. Removes all registered providers. */
public fun clearForTesting() {
providers.clear()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package sk.ainet.exec.kernel

import sk.ainet.backend.api.kernel.Fp32MatmulKernel
import sk.ainet.backend.api.kernel.KernelProvider

/**
* Scalar (non-SIMD) [KernelProvider] — always available, lowest
* priority. Acts as the correctness reference and the guaranteed
* fallback when no accelerated provider is registered.
*
* Callers can pin this provider directly when they want deterministic
* scalar arithmetic without registry interaction (useful in tests):
*
* ```kotlin
* val kernel = ScalarKernelProvider.matmulFp32()!!
* kernel.matmul(a, 0, k, b, 0, n, out, 0, n, m, n, k)
* ```
*/
public object ScalarKernelProvider : KernelProvider {
override val name: String = "scalar"
override val priority: Int = 0
override fun isAvailable(): Boolean = true
override fun matmulFp32(): Fp32MatmulKernel = ScalarMatmulKernel
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package sk.ainet.exec.kernel

import sk.ainet.backend.api.kernel.Fp32MatmulKernel

/**
* Scalar reference implementation of [Fp32MatmulKernel] — three nested
* loops, no SIMD. Always available on every KMP target. Used as:
*
* - The correctness reference that accelerated kernels (Panama, native)
* must match bit-for-bit (within FP order tolerance).
* - A guaranteed fallback when no accelerated provider is registered or
* available.
*
* Performance is modest (no vectorization, no cache-blocking), so
* production code should layer a Panama or native provider on top via
* [sk.ainet.backend.api.kernel.KernelRegistry].
*/
public object ScalarMatmulKernel : Fp32MatmulKernel {
override fun matmul(
a: FloatArray, aOffset: Int, aStride: Int,
b: FloatArray, bOffset: Int, bStride: Int,
out: FloatArray, outOffset: Int, outStride: Int,
m: Int, n: Int, k: Int
) {
require(m >= 0 && n >= 0 && k >= 0) {
"ScalarMatmulKernel: m, n, k must be non-negative; got m=$m n=$n k=$k"
}
if (m == 0 || n == 0) return
// k == 0 → C = 0; the strides may still be > 0 so we need to
// explicitly zero the output block.
if (k == 0) {
for (i in 0 until m) {
val rowOff = outOffset + i * outStride
for (j in 0 until n) out[rowOff + j] = 0f
}
return
}
for (i in 0 until m) {
val aRowOff = aOffset + i * aStride
val outRowOff = outOffset + i * outStride
for (j in 0 until n) {
var sum = 0f
for (kk in 0 until k) {
sum += a[aRowOff + kk] * b[bOffset + kk * bStride + j]
}
out[outRowOff + j] = sum
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package sk.ainet.exec.kernel

import kotlin.test.AfterTest
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertSame
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.KernelRegistry

class KernelRegistryTest {

@BeforeTest
fun setUp() = KernelRegistry.clearForTesting()

@AfterTest
fun tearDown() = KernelRegistry.clearForTesting()

@Test
fun emptyRegistryHasNoBest() {
assertNull(KernelRegistry.bestAvailable())
assertEquals(emptyList(), KernelRegistry.availableNames())
}

@Test
fun scalarRegistersAndIsBest() {
KernelRegistry.register(ScalarKernelProvider)
assertSame(ScalarKernelProvider, KernelRegistry.bestAvailable())
assertEquals(listOf("scalar"), KernelRegistry.availableNames())
}

@Test
fun higherPriorityWins() {
val fast = object : KernelProvider {
override val name = "fake-fast"
override val priority = 50
override fun isAvailable() = true
override fun matmulFp32(): Fp32MatmulKernel = ScalarMatmulKernel
}
KernelRegistry.register(ScalarKernelProvider)
KernelRegistry.register(fast)
assertSame(fast, KernelRegistry.bestAvailable())
}

@Test
fun unavailableProviderIsSkipped() {
val pretender = object : KernelProvider {
override val name = "pretender"
override val priority = 100
override fun isAvailable() = false
override fun matmulFp32(): Fp32MatmulKernel = ScalarMatmulKernel
}
KernelRegistry.register(pretender)
KernelRegistry.register(ScalarKernelProvider)
// pretender outranks scalar but isn't available — scalar wins.
assertSame(ScalarKernelProvider, KernelRegistry.bestAvailable())
// availableNames excludes pretender.
assertEquals(listOf("scalar"), KernelRegistry.availableNames())
}

@Test
fun findByNameIsCaseInsensitive() {
KernelRegistry.register(ScalarKernelProvider)
assertSame(ScalarKernelProvider, KernelRegistry.find("scalar"))
assertSame(ScalarKernelProvider, KernelRegistry.find("Scalar"))
assertSame(ScalarKernelProvider, KernelRegistry.find("SCALAR"))
assertNull(KernelRegistry.find("unknown"))
}

@Test
fun reRegisteringSameInstanceIsNoOp() {
KernelRegistry.register(ScalarKernelProvider)
KernelRegistry.register(ScalarKernelProvider)
assertEquals(1, KernelRegistry.providers().size)
}
}
Loading
Loading