@@ -3,11 +3,13 @@ package sk.ainet.exec.tensor.ops
33import jdk.incubator.vector.FloatVector
44import jdk.incubator.vector.VectorSpecies
55import jdk.incubator.vector.VectorOperators
6+ import sk.ainet.backend.api.kernel.Bf16MatmulKernel
67import sk.ainet.backend.api.kernel.Fp32MatmulKernel
78import sk.ainet.backend.api.kernel.KernelRegistry
89import sk.ainet.backend.api.kernel.KernelServiceLoader
910import sk.ainet.backend.api.kernel.Q4KMatmulKernel
1011import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
12+ import sk.ainet.exec.kernel.ScalarBf16MatmulKernel
1113import sk.ainet.exec.kernel.ScalarMatmulKernel
1214import sk.ainet.lang.tensor.Shape
1315import sk.ainet.lang.tensor.Tensor
@@ -17,6 +19,7 @@ import sk.ainet.lang.tensor.data.MemorySegmentBackedData
1719import sk.ainet.lang.tensor.data.MemorySegmentTensorData
1820import sk.ainet.lang.tensor.data.Q4MemorySegmentMarker
1921import sk.ainet.lang.tensor.data.Q4MemorySegmentTensorData
22+ import sk.ainet.lang.tensor.data.Bf16TensorData
2023import sk.ainet.lang.tensor.data.Q8_0TensorData
2124import sk.ainet.lang.tensor.data.Q8MemorySegmentMarker
2225import sk.ainet.lang.tensor.data.Q8MemorySegmentTensorData
@@ -89,6 +92,26 @@ internal class DefaultCpuOpsJvm(
8992 ?.matmulQ8_0()
9093 }
9194
95+ /* *
96+ * BF16 matmul kernel resolved via [KernelRegistry]. Unlike the Q4_K
97+ * and Q8_0 lookups (nullable, with legacy `JvmQuantizedVectorKernels`
98+ * fallbacks), BF16 has no pre-SPI implementation in this codebase —
99+ * the scalar SPI kernel is the floor. We mirror [fp32MatmulKernel]'s
100+ * pattern: non-null, picks the highest-priority provider that carries
101+ * a BF16 kernel (native FFM at 100, Panama Vector at 50), falls back
102+ * to [ScalarBf16MatmulKernel] when no SIMD provider reports
103+ * availability (e.g. tests that explicitly clear the registry).
104+ */
105+ private val bf16MatmulKernel: Bf16MatmulKernel by lazy {
106+ if (KernelRegistry .providers().isEmpty()) {
107+ KernelServiceLoader .installAll()
108+ }
109+ KernelRegistry .providers()
110+ .firstOrNull { it.isAvailable() && it.matmulBf16() != null }
111+ ?.matmulBf16()
112+ ? : ScalarBf16MatmulKernel
113+ }
114+
92115 override fun <T : DType , V > add (a : Tensor <T , V >, b : Tensor <T , V >): Tensor <T , V > {
93116 vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it }
94117 return super .add(a, b)
@@ -511,6 +534,21 @@ internal class DefaultCpuOpsJvm(
511534 @Suppress(" UNCHECKED_CAST" )
512535 CpuTensor (outData as TensorData <T , V >, this , a.dtype)
513536 }
537+ is Bf16TensorData -> {
538+ // BF16 is dense (not block-quantized) and the kernel SPI is a
539+ // full SGEMM with `(m, n, k)` strides — no per-batch loop needed,
540+ // unlike the matvec-shaped Q4_K / Q8_0 / Q6_K branches.
541+ val outBuffer = FloatArray (batchSize * outputDim)
542+ bf16MatmulKernel.matmul(
543+ inputBuffer, 0 , inputDim,
544+ bData.packedData, 0 , outputDim * Bf16TensorData .BYTES_PER_ELEMENT ,
545+ outBuffer, 0 , outputDim,
546+ batchSize, outputDim, inputDim,
547+ )
548+ val outData = DenseFloatArrayTensorData <T >(Shape (batchSize, outputDim), outBuffer)
549+ @Suppress(" UNCHECKED_CAST" )
550+ CpuTensor (outData as TensorData <T , V >, this , a.dtype)
551+ }
514552 is Q6_KTensorData -> {
515553 val outBuffer = FloatArray (batchSize * outputDim)
516554 for (batch in 0 until batchSize) {
0 commit comments