Skip to content

Commit db00c95

Browse files
Merge pull request #561 from SKaiNET-developers/feature/jvm-matmul-route-spi
feat(matmul): route DefaultCpuOpsJvm FP32 matmul through KernelRegistry
2 parents 5f7f515 + 6d9a4be commit db00c95

1 file changed

Lines changed: 31 additions & 11 deletions

File tree

  • skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ package sk.ainet.exec.tensor.ops
33
import jdk.incubator.vector.FloatVector
44
import jdk.incubator.vector.VectorSpecies
55
import jdk.incubator.vector.VectorOperators
6+
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
7+
import sk.ainet.backend.api.kernel.KernelRegistry
8+
import sk.ainet.backend.api.kernel.KernelServiceLoader
9+
import sk.ainet.exec.kernel.ScalarMatmulKernel
610
import sk.ainet.lang.tensor.Shape
711
import sk.ainet.lang.tensor.Tensor
812
import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
@@ -32,6 +36,23 @@ internal class DefaultCpuOpsJvm(
3236

3337
private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED
3438

39+
/**
40+
* FP32 matmul kernel resolved via [KernelRegistry]. First access on a
41+
* given instance auto-installs providers via [KernelServiceLoader]
42+
* if the registry is empty; subsequent calls reuse the cached
43+
* lookup. Apps that prefer to wire their own providers can call
44+
* `KernelRegistry.register(...)` before this kernel is first accessed.
45+
* Falls back to [ScalarMatmulKernel] only when no provider reports
46+
* itself available — in practice, [PanamaVectorKernelProvider]
47+
* (priority 50) wins on JDK 21+ with the incubator module loaded.
48+
*/
49+
private val fp32MatmulKernel: Fp32MatmulKernel by lazy {
50+
if (KernelRegistry.providers().isEmpty()) {
51+
KernelServiceLoader.installAll()
52+
}
53+
KernelRegistry.bestAvailable()?.matmulFp32() ?: ScalarMatmulKernel
54+
}
55+
3556
override fun <T : DType, V> add(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
3657
vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it }
3758
return super.add(a, b)
@@ -808,17 +829,16 @@ internal class DefaultCpuOpsJvm(
808829
}
809830
}
810831

811-
// Use blocked matmul for small/medium sizes
812-
val blockedThreshold = 16 * 16 // always use blocked above tiny cases
813-
if (m >= blockedThreshold || n >= blockedThreshold || k >= blockedThreshold) {
814-
JvmVectorKernels.matmulFloatBlocked(m, k, n, aData.buffer, bData.buffer, outBuffer)
815-
val outData = DenseFloatArrayTensorData<T>(Shape(m, n), outBuffer)
816-
@Suppress("UNCHECKED_CAST")
817-
return CpuTensor(outData as TensorData<T, V>, this, a.dtype)
818-
}
819-
820-
// Fallback to simple vectorized inner-product matmul
821-
JvmVectorKernels.matmulFloat(m, k, n, aData.buffer, bData.buffer, outBuffer)
832+
// Route through the kernel SPI — the registered provider
833+
// (Panama on JDK 21+, scalar otherwise) is tile-blocked and
834+
// handles both small and large inputs in one path, so the previous
835+
// simple-vs-blocked fork is no longer needed.
836+
fp32MatmulKernel.matmul(
837+
aData.buffer, 0, k,
838+
bData.buffer, 0, n,
839+
outBuffer, 0, n,
840+
m, n, k,
841+
)
822842
val outData = DenseFloatArrayTensorData<T>(Shape(m, n), outBuffer)
823843
@Suppress("UNCHECKED_CAST")
824844
return CpuTensor(outData as TensorData<T, V>, this, a.dtype)

0 commit comments

Comments
 (0)