Skip to content

Commit aadb8b2

Browse files
Merge pull request #614 from SKaiNET-developers/feature/bf16-dispatch
BF16 dispatch chain (Phase 3/3): wire Bf16TensorData dispatch in DefaultCpuOpsJvm
2 parents 4e55ded + c4de7fb commit aadb8b2

2 files changed

Lines changed: 140 additions & 0 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ package sk.ainet.exec.tensor.ops
33
import jdk.incubator.vector.FloatVector
44
import jdk.incubator.vector.VectorSpecies
55
import jdk.incubator.vector.VectorOperators
6+
import sk.ainet.backend.api.kernel.Bf16MatmulKernel
67
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
78
import sk.ainet.backend.api.kernel.KernelRegistry
89
import sk.ainet.backend.api.kernel.KernelServiceLoader
910
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
1011
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
12+
import sk.ainet.exec.kernel.ScalarBf16MatmulKernel
1113
import sk.ainet.exec.kernel.ScalarMatmulKernel
1214
import sk.ainet.lang.tensor.Shape
1315
import sk.ainet.lang.tensor.Tensor
@@ -17,6 +19,7 @@ import sk.ainet.lang.tensor.data.MemorySegmentBackedData
1719
import sk.ainet.lang.tensor.data.MemorySegmentTensorData
1820
import sk.ainet.lang.tensor.data.Q4MemorySegmentMarker
1921
import sk.ainet.lang.tensor.data.Q4MemorySegmentTensorData
22+
import sk.ainet.lang.tensor.data.Bf16TensorData
2023
import sk.ainet.lang.tensor.data.Q8_0TensorData
2124
import sk.ainet.lang.tensor.data.Q8MemorySegmentMarker
2225
import sk.ainet.lang.tensor.data.Q8MemorySegmentTensorData
@@ -89,6 +92,26 @@ internal class DefaultCpuOpsJvm(
8992
?.matmulQ8_0()
9093
}
9194

95+
/**
96+
* BF16 matmul kernel resolved via [KernelRegistry]. Unlike the Q4_K
97+
* and Q8_0 lookups (nullable, with legacy `JvmQuantizedVectorKernels`
98+
* fallbacks), BF16 has no pre-SPI implementation in this codebase —
99+
* the scalar SPI kernel is the floor. We mirror [fp32MatmulKernel]'s
100+
* pattern: non-null, picks the highest-priority provider that carries
101+
* a BF16 kernel (native FFM at 100, Panama Vector at 50), falls back
102+
* to [ScalarBf16MatmulKernel] when no SIMD provider reports
103+
* availability (e.g. tests that explicitly clear the registry).
104+
*/
105+
private val bf16MatmulKernel: Bf16MatmulKernel by lazy {
106+
if (KernelRegistry.providers().isEmpty()) {
107+
KernelServiceLoader.installAll()
108+
}
109+
KernelRegistry.providers()
110+
.firstOrNull { it.isAvailable() && it.matmulBf16() != null }
111+
?.matmulBf16()
112+
?: ScalarBf16MatmulKernel
113+
}
114+
92115
override fun <T : DType, V> add(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
93116
vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it }
94117
return super.add(a, b)
@@ -511,6 +534,21 @@ internal class DefaultCpuOpsJvm(
511534
@Suppress("UNCHECKED_CAST")
512535
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
513536
}
537+
is Bf16TensorData -> {
538+
// BF16 is dense (not block-quantized) and the kernel SPI is a
539+
// full SGEMM with `(m, n, k)` strides — no per-batch loop needed,
540+
// unlike the matvec-shaped Q4_K / Q8_0 / Q6_K branches.
541+
val outBuffer = FloatArray(batchSize * outputDim)
542+
bf16MatmulKernel.matmul(
543+
inputBuffer, 0, inputDim,
544+
bData.packedData, 0, outputDim * Bf16TensorData.BYTES_PER_ELEMENT,
545+
outBuffer, 0, outputDim,
546+
batchSize, outputDim, inputDim,
547+
)
548+
val outData = DenseFloatArrayTensorData<T>(Shape(batchSize, outputDim), outBuffer)
549+
@Suppress("UNCHECKED_CAST")
550+
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
551+
}
514552
is Q6_KTensorData -> {
515553
val outBuffer = FloatArray(batchSize * outputDim)
516554
for (batch in 0 until batchSize) {
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.math.abs
4+
import kotlin.random.Random
5+
import kotlin.test.Test
6+
import kotlin.test.assertTrue
7+
import sk.ainet.context.DirectCpuExecutionContext
8+
import sk.ainet.exec.kernel.ScalarBf16MatmulKernel
9+
import sk.ainet.lang.tensor.Shape
10+
import sk.ainet.lang.tensor.Tensor
11+
import sk.ainet.lang.tensor.data.Bf16DenseTensorData
12+
import sk.ainet.lang.tensor.data.TensorData
13+
import sk.ainet.lang.types.FP32
14+
15+
/**
16+
* Integration tests for the FP32 × BF16 dispatch path in
17+
* [DefaultCpuOpsJvm.matmul]. Confirms `ops.matmul` against a
18+
* `Bf16DenseTensorData` weight produces the same output as the
19+
* `ScalarBf16MatmulKernel` reference — proving the new `is
20+
* Bf16TensorData ->` branch in `chooseQuantizedMatmul` is reached
21+
* and routes through the BF16 SPI (or its scalar fallback when no
22+
* SIMD provider resolves). Mirrors `Q8_0MatmulDispatchTest` in
23+
* shape and coverage.
24+
*/
25+
class Bf16MatmulDispatchTest {
26+
27+
private val ctx = DirectCpuExecutionContext()
28+
29+
/** BF16 has 7 mantissa bits — accumulated error scales with `k`. */
30+
private val bf16TolPerK = 1e-2f
31+
32+
/** Truncate FP32 → BF16 (high 16 bits, zero rounding), pack LE bytes. */
33+
private fun fp32ToBf16Bytes(values: FloatArray): ByteArray {
34+
val out = ByteArray(values.size * 2)
35+
for (i in values.indices) {
36+
val bf16 = (values[i].toRawBits() ushr 16) and 0xFFFF
37+
out[i * 2] = (bf16 and 0xFF).toByte()
38+
out[i * 2 + 1] = ((bf16 ushr 8) and 0xFF).toByte()
39+
}
40+
return out
41+
}
42+
43+
@Suppress("UNCHECKED_CAST")
44+
private fun bf16Weight(inputDim: Int, outputDim: Int, seed: Int): Pair<Tensor<FP32, Float>, ByteArray> {
45+
val rng = Random(seed)
46+
val values = FloatArray(inputDim * outputDim) { rng.nextFloat() - 0.5f }
47+
val bytes = fp32ToBf16Bytes(values)
48+
val data = Bf16DenseTensorData(Shape(inputDim, outputDim), bytes) as TensorData<FP32, Float>
49+
return ctx.fromData(data, FP32::class) to bytes
50+
}
51+
52+
private fun scalarReference(
53+
input: FloatArray, weightBytes: ByteArray,
54+
m: Int, n: Int, k: Int,
55+
): FloatArray {
56+
val out = FloatArray(m * n)
57+
ScalarBf16MatmulKernel.matmul(
58+
input, 0, k,
59+
weightBytes, 0, n * 2,
60+
out, 0, n,
61+
m, n, k,
62+
)
63+
return out
64+
}
65+
66+
private fun assertDispatchMatchesScalar(
67+
m: Int, k: Int, n: Int, seed: Int,
68+
) {
69+
val rng = Random(seed)
70+
val inputFloats = FloatArray(m * k) { rng.nextFloat() - 0.5f }
71+
val (weight, weightBytes) = bf16Weight(k, n, seed)
72+
val input = ctx.fromFloatArray<FP32, Float>(Shape(m, k), FP32::class, inputFloats)
73+
74+
val out = ctx.ops.matmul(input, weight)
75+
val outArr = out.data.copyToFloatArray()
76+
val expected = scalarReference(inputFloats, weightBytes, m, n, k)
77+
78+
val tol = (bf16TolPerK * k.coerceAtLeast(1)).coerceAtLeast(bf16TolPerK)
79+
for (i in expected.indices) {
80+
val diff = abs(expected[i] - outArr[i])
81+
assertTrue(
82+
diff <= tol,
83+
"BF16 dispatch mismatch at $i: expected=${expected[i]} got=${outArr[i]} diff=$diff tol=$tol",
84+
)
85+
}
86+
}
87+
88+
@Test
89+
fun single_batch_matmul_against_bf16_weight_routes_correctly() {
90+
assertDispatchMatchesScalar(m = 1, k = 128, n = 64, seed = 1)
91+
}
92+
93+
@Test
94+
fun multi_batch_matmul_against_bf16_weight_routes_correctly() {
95+
assertDispatchMatchesScalar(m = 3, k = 256, n = 32, seed = 2)
96+
}
97+
98+
@Test
99+
fun llm_typical_attention_proj_matmul_routes_correctly() {
100+
assertDispatchMatchesScalar(m = 1, k = 512, n = 512, seed = 3)
101+
}
102+
}

0 commit comments

Comments
 (0)