@@ -3,6 +3,10 @@ package sk.ainet.exec.tensor.ops
 import jdk.incubator.vector.FloatVector
 import jdk.incubator.vector.VectorSpecies
 import jdk.incubator.vector.VectorOperators
+import sk.ainet.backend.api.kernel.Fp32MatmulKernel
+import sk.ainet.backend.api.kernel.KernelRegistry
+import sk.ainet.backend.api.kernel.KernelServiceLoader
+import sk.ainet.exec.kernel.ScalarMatmulKernel
 import sk.ainet.lang.tensor.Shape
 import sk.ainet.lang.tensor.Tensor
 import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
@@ -32,6 +36,23 @@ internal class DefaultCpuOpsJvm(
 
     private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED
 
+    /**
+     * FP32 matmul kernel resolved via [KernelRegistry]. First access on a
+     * given instance auto-installs providers via [KernelServiceLoader]
+     * if the registry is empty; subsequent reads reuse the cached
+     * lookup. Apps that prefer to wire their own providers can call
+     * `KernelRegistry.register(...)` before constructing this op set.
+     * Falls back to [ScalarMatmulKernel] only when no provider reports
+     * itself available; in practice, [PanamaVectorKernelProvider]
+     * (priority 50) wins on JDK 21+ with the incubator module loaded.
+     */
+    private val fp32MatmulKernel: Fp32MatmulKernel by lazy {
+        if (KernelRegistry.providers().isEmpty()) {
+            KernelServiceLoader.installAll()
+        }
+        KernelRegistry.bestAvailable()?.matmulFp32() ?: ScalarMatmulKernel
+    }
+
     override fun <T : DType, V> add(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
         vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it }
         return super.add(a, b)
@@ -808,17 +829,16 @@ internal class DefaultCpuOpsJvm(
             }
         }
 
-        // Use blocked matmul for small/medium sizes
-        val blockedThreshold = 16 * 16 // always use blocked above tiny cases
-        if (m >= blockedThreshold || n >= blockedThreshold || k >= blockedThreshold) {
-            JvmVectorKernels.matmulFloatBlocked(m, k, n, aData.buffer, bData.buffer, outBuffer)
-            val outData = DenseFloatArrayTensorData<T>(Shape(m, n), outBuffer)
-            @Suppress("UNCHECKED_CAST")
-            return CpuTensor(outData as TensorData<T, V>, this, a.dtype)
-        }
-
-        // Fallback to simple vectorized inner-product matmul
-        JvmVectorKernels.matmulFloat(m, k, n, aData.buffer, bData.buffer, outBuffer)
+        // Route through the kernel SPI: the registered provider
+        // (Panama on JDK 21+, scalar otherwise) is tile-blocked and
+        // handles small and large inputs in one path, so the previous
+        // simple-vs-blocked fork is no longer needed.
+        fp32MatmulKernel.matmul(
+            aData.buffer, 0, k,
+            bData.buffer, 0, n,
+            outBuffer, 0, n,
+            m, n, k,
+        )
         val outData = DenseFloatArrayTensorData<T>(Shape(m, n), outBuffer)
         @Suppress("UNCHECKED_CAST")
         return CpuTensor(outData as TensorData<T, V>, this, a.dtype)
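
Note for reviewers: the `Fp32MatmulKernel` contract itself is outside this diff. From the call site above (buffer/offset/leading-dimension triples for A, B, and C, followed by m, n, k), a minimal scalar kernel consistent with it would look roughly like the sketch below. The interface shape and parameter names are inferred from the argument order, not the actual SPI.

// Hypothetical sketch: inferred from the call site, not the real interface.
interface Fp32MatmulKernel {
    fun matmul(
        a: FloatArray, aOff: Int, lda: Int,
        b: FloatArray, bOff: Int, ldb: Int,
        c: FloatArray, cOff: Int, ldc: Int,
        m: Int, n: Int, k: Int,
    )
}

// Reference scalar kernel consistent with that contract: row-major
// C[m x n] = A[m x k] * B[k x n], honoring offsets and leading dimensions.
object ScalarMatmulKernelSketch : Fp32MatmulKernel {
    override fun matmul(
        a: FloatArray, aOff: Int, lda: Int,
        b: FloatArray, bOff: Int, ldb: Int,
        c: FloatArray, cOff: Int, ldc: Int,
        m: Int, n: Int, k: Int,
    ) {
        for (i in 0 until m) {
            for (j in 0 until n) {
                var acc = 0f
                for (p in 0 until k) {
                    acc += a[aOff + i * lda + p] * b[bOff + p * ldb + j]
                }
                c[cOff + i * ldc + j] = acc
            }
        }
    }
}

At the call site this maps to lda = k, ldb = n, ldc = n with zero offsets, matching dense row-major m x k and k x n inputs.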
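The KDoc also names a manual wiring path via `KernelRegistry.register(...)`. The provider API is not shown in this diff, so the self-contained model below (reusing the `Fp32MatmulKernel` shape sketched above) only illustrates the resolution order the lazy property relies on; the `priority` and `isAvailable()` members are assumptions inferred from the "priority 50" and "reports itself available" wording.

// Assumed, self-contained model of provider selection; member names are
// illustrative, not the real KernelRegistry/provider API.
interface KernelProviderSketch {
    val priority: Int
    fun isAvailable(): Boolean
    fun matmulFp32(): Fp32MatmulKernel
}

object RegistrySketch {
    private val providers = mutableListOf<KernelProviderSketch>()

    // Apps that bypass service loading would call this before constructing
    // the op set, mirroring the `KernelRegistry.register(...)` hook.
    fun register(p: KernelProviderSketch) { providers += p }

    fun providers(): List<KernelProviderSketch> = providers

    // Mirrors the selection the KDoc describes: among providers that report
    // themselves available, the highest priority wins; null means fall back
    // to the scalar kernel.
    fun bestAvailable(): KernelProviderSketch? =
        providers.filter { it.isAvailable() }.maxByOrNull { it.priority }
}

Priority-based selection is what lets the call site stay a single path: swapping Panama for scalar (or a future native provider) changes which provider wins, not the matmul code.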