Skip to content

Commit 4823154

Browse files
Merge pull request #608 from SKaiNET-developers/feature/q8-0-dispatch
Q8_0 matmul: route DefaultCpuOpsJvm.chooseQuantizedMatmul through KernelRegistry SPI
2 parents ae4d657 + d55b790 commit 4823154

2 files changed

Lines changed: 160 additions & 8 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import sk.ainet.backend.api.kernel.Fp32MatmulKernel
77
import sk.ainet.backend.api.kernel.KernelRegistry
88
import sk.ainet.backend.api.kernel.KernelServiceLoader
99
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
10+
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
1011
import sk.ainet.exec.kernel.ScalarMatmulKernel
1112
import sk.ainet.lang.tensor.Shape
1213
import sk.ainet.lang.tensor.Tensor
@@ -71,6 +72,23 @@ internal class DefaultCpuOpsJvm(
7172
?.matmulQ4K()
7273
}
7374

75+
/**
76+
* Q8_0 kernel resolved via [KernelRegistry], lazily initialized on
77+
* first quantized matmul call. Mirrors [q4kMatmulKernel] — auto-
78+
* installs ServiceLoader-discovered providers when the registry is
79+
* empty, returns `null` if no provider carries a Q8_0 kernel.
80+
* Caller falls back to [JvmQuantizedVectorKernels.matmulQ8_0Vec],
81+
* preserving the legacy code path when the SPI doesn't resolve.
82+
*/
83+
private val q8_0MatmulKernel: Q8_0MatmulKernel? by lazy {
84+
if (KernelRegistry.providers().isEmpty()) {
85+
KernelServiceLoader.installAll()
86+
}
87+
KernelRegistry.providers()
88+
.firstOrNull { it.isAvailable() && it.matmulQ8_0() != null }
89+
?.matmulQ8_0()
90+
}
91+
7492
override fun <T : DType, V> add(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
7593
vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it }
7694
return super.add(a, b)
@@ -439,17 +457,27 @@ internal class DefaultCpuOpsJvm(
439457
return when (bData) {
440458
is Q8_0TensorData -> {
441459
val outBuffer = FloatArray(batchSize * outputDim)
460+
val spiKernel = q8_0MatmulKernel
442461
for (batch in 0 until batchSize) {
443462
val batchInput = if (batchSize == 1) inputBuffer
444463
else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim)
445-
JvmQuantizedVectorKernels.matmulQ8_0Vec(
446-
batchInput,
447-
bData.packedData,
448-
inputDim,
449-
outputDim,
450-
outBuffer,
451-
batch * outputDim,
452-
)
464+
if (spiKernel != null) {
465+
spiKernel.matmul(
466+
batchInput, 0,
467+
bData.packedData, 0,
468+
inputDim, outputDim,
469+
outBuffer, batch * outputDim,
470+
)
471+
} else {
472+
JvmQuantizedVectorKernels.matmulQ8_0Vec(
473+
batchInput,
474+
bData.packedData,
475+
inputDim,
476+
outputDim,
477+
outBuffer,
478+
batch * outputDim,
479+
)
480+
}
453481
}
454482
val outData = DenseFloatArrayTensorData<T>(Shape(batchSize, outputDim), outBuffer)
455483
@Suppress("UNCHECKED_CAST")
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.math.abs
4+
import kotlin.random.Random
5+
import kotlin.test.Test
6+
import kotlin.test.assertTrue
7+
import sk.ainet.context.DirectCpuExecutionContext
8+
import sk.ainet.exec.kernel.ScalarQ8_0MatmulKernel
9+
import sk.ainet.lang.tensor.Shape
10+
import sk.ainet.lang.tensor.Tensor
11+
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
12+
import sk.ainet.lang.tensor.data.TensorData
13+
import sk.ainet.lang.types.FP32
14+
15+
/**
16+
* Integration tests for the FP32 × Q8_0 dispatch path in
17+
* [DefaultCpuOpsJvm.matmul]. Confirms that calling matmul on a
18+
* Q8_0-backed weight tensor produces the same output as the scalar
19+
* Q8_0 kernel — proving the dispatch actually routes through the
20+
* registered Q8_0 SPI kernel (or its legacy `JvmQuantizedVectorKernels`
21+
* fallback when the SPI doesn't resolve). Either path is correct;
22+
* this test pins integration, not kernel correctness (already covered
23+
* by the per-kernel parity tests in #606).
24+
*/
25+
class Q8_0MatmulDispatchTest {
26+
27+
private val ctx = DirectCpuExecutionContext()
28+
29+
private val blockSize = 32
30+
private val bytesPerBlock = 34
31+
32+
private fun randomQ8_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray {
33+
val rng = Random(seed)
34+
val numBlocks = blocksPerInputDim * outputDim
35+
val bytes = ByteArray(numBlocks * bytesPerBlock)
36+
rng.nextBytes(bytes)
37+
for (block in 0 until numBlocks) {
38+
val base = block * bytesPerBlock
39+
// FP16 scale ≈ 7.6e-3 (low-bit FP16 0x2200) — safely finite, non-zero.
40+
bytes[base + 0] = 0x00.toByte()
41+
bytes[base + 1] = 0x22.toByte()
42+
}
43+
return bytes
44+
}
45+
46+
private fun ScalarQ8_0_reference(
47+
input: FloatArray, weight: ByteArray,
48+
inputDim: Int, outputDim: Int,
49+
batchSize: Int,
50+
): FloatArray {
51+
val out = FloatArray(batchSize * outputDim)
52+
for (b in 0 until batchSize) {
53+
ScalarQ8_0MatmulKernel.matmul(
54+
input, b * inputDim,
55+
weight, 0,
56+
inputDim, outputDim,
57+
out, b * outputDim,
58+
)
59+
}
60+
return out
61+
}
62+
63+
@Suppress("UNCHECKED_CAST")
64+
private fun q8_0Tensor(inputDim: Int, outputDim: Int, seed: Int): Tensor<FP32, Float> {
65+
val blocksPerInputDim = inputDim / blockSize
66+
val bytes = randomQ8_0Bytes(blocksPerInputDim, outputDim, seed)
67+
// Logical shape of a Q8_0 weight tensor is [inputDim, outputDim].
68+
val data = Q8_0BlockTensorData(Shape(inputDim, outputDim), bytes)
69+
return ctx.fromData(data as TensorData<FP32, Float>, FP32::class)
70+
}
71+
72+
private fun assertDispatchMatchesScalar(
73+
batchSize: Int, inputDim: Int, outputDim: Int, seed: Int,
74+
tolPerBlock: Float = 1e-2f,
75+
) {
76+
val rng = Random(seed)
77+
val inputFloats = FloatArray(batchSize * inputDim) { rng.nextFloat() - 0.5f }
78+
val blocksPerInputDim = inputDim / blockSize
79+
80+
val weightBytes = randomQ8_0Bytes(blocksPerInputDim, outputDim, seed)
81+
val weight = q8_0Tensor(inputDim, outputDim, seed).let { t ->
82+
// q8_0Tensor regenerates bytes from seed → use the SAME byte buffer
83+
// for the scalar reference path so the comparison is honest.
84+
@Suppress("UNCHECKED_CAST")
85+
val td = Q8_0BlockTensorData(Shape(inputDim, outputDim), weightBytes) as TensorData<FP32, Float>
86+
ctx.fromData(td, FP32::class)
87+
}
88+
val input = ctx.fromFloatArray<FP32, Float>(
89+
Shape(batchSize, inputDim), FP32::class, inputFloats,
90+
)
91+
92+
val out = ctx.ops.matmul(input, weight)
93+
val outArr = out.data.copyToFloatArray()
94+
95+
val expected = ScalarQ8_0_reference(inputFloats, weightBytes, inputDim, outputDim, batchSize)
96+
97+
val tol = (tolPerBlock * blocksPerInputDim.coerceAtLeast(1)).coerceAtLeast(tolPerBlock)
98+
for (i in expected.indices) {
99+
val diff = abs(expected[i] - outArr[i])
100+
assertTrue(
101+
diff <= tol,
102+
"dispatch mismatch at $i: expected=${expected[i]} got=${outArr[i]} diff=$diff tol=$tol",
103+
)
104+
}
105+
}
106+
107+
@Test
108+
fun single_batch_matmul_against_q8_0_weight_routes_correctly() {
109+
// batchSize=1 hits the optimized "no copyOfRange" branch in chooseQuantizedMatmul.
110+
assertDispatchMatchesScalar(batchSize = 1, inputDim = 128, outputDim = 64, seed = 1)
111+
}
112+
113+
@Test
114+
fun multi_batch_matmul_against_q8_0_weight_routes_correctly() {
115+
// batchSize>1 exercises the per-row copyOfRange branch.
116+
assertDispatchMatchesScalar(batchSize = 3, inputDim = 256, outputDim = 32, seed = 2)
117+
}
118+
119+
@Test
120+
fun llm_typical_attention_proj_matmul_routes_correctly() {
121+
// Realistic attention-projection size (matvec at dim×dim).
122+
assertDispatchMatchesScalar(batchSize = 1, inputDim = 512, outputDim = 512, seed = 3)
123+
}
124+
}

0 commit comments

Comments
 (0)