|
| 1 | +package sk.ainet.exec.tensor.ops |
| 2 | + |
| 3 | +import kotlin.math.abs |
| 4 | +import kotlin.random.Random |
| 5 | +import kotlin.test.Test |
| 6 | +import kotlin.test.assertTrue |
| 7 | +import sk.ainet.context.DirectCpuExecutionContext |
| 8 | +import sk.ainet.exec.kernel.ScalarQ8_0MatmulKernel |
| 9 | +import sk.ainet.lang.tensor.Shape |
| 10 | +import sk.ainet.lang.tensor.Tensor |
| 11 | +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData |
| 12 | +import sk.ainet.lang.tensor.data.TensorData |
| 13 | +import sk.ainet.lang.types.FP32 |
| 14 | + |
| 15 | +/** |
| 16 | + * Integration tests for the FP32 × Q8_0 dispatch path in |
| 17 | + * [DefaultCpuOpsJvm.matmul]. Confirms that calling matmul on a |
| 18 | + * Q8_0-backed weight tensor produces the same output as the scalar |
| 19 | + * Q8_0 kernel — proving the dispatch actually routes through the |
| 20 | + * registered Q8_0 SPI kernel (or its legacy `JvmQuantizedVectorKernels` |
| 21 | + * fallback when the SPI doesn't resolve). Either path is correct; |
| 22 | + * this test pins integration, not kernel correctness (already covered |
| 23 | + * by the per-kernel parity tests in #606). |
| 24 | + */ |
| 25 | +class Q8_0MatmulDispatchTest { |
| 26 | + |
| 27 | + private val ctx = DirectCpuExecutionContext() |
| 28 | + |
| 29 | + private val blockSize = 32 |
| 30 | + private val bytesPerBlock = 34 |
| 31 | + |
| 32 | + private fun randomQ8_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray { |
| 33 | + val rng = Random(seed) |
| 34 | + val numBlocks = blocksPerInputDim * outputDim |
| 35 | + val bytes = ByteArray(numBlocks * bytesPerBlock) |
| 36 | + rng.nextBytes(bytes) |
| 37 | + for (block in 0 until numBlocks) { |
| 38 | + val base = block * bytesPerBlock |
| 39 | + // FP16 scale ≈ 7.6e-3 (low-bit FP16 0x2200) — safely finite, non-zero. |
| 40 | + bytes[base + 0] = 0x00.toByte() |
| 41 | + bytes[base + 1] = 0x22.toByte() |
| 42 | + } |
| 43 | + return bytes |
| 44 | + } |
| 45 | + |
| 46 | + private fun ScalarQ8_0_reference( |
| 47 | + input: FloatArray, weight: ByteArray, |
| 48 | + inputDim: Int, outputDim: Int, |
| 49 | + batchSize: Int, |
| 50 | + ): FloatArray { |
| 51 | + val out = FloatArray(batchSize * outputDim) |
| 52 | + for (b in 0 until batchSize) { |
| 53 | + ScalarQ8_0MatmulKernel.matmul( |
| 54 | + input, b * inputDim, |
| 55 | + weight, 0, |
| 56 | + inputDim, outputDim, |
| 57 | + out, b * outputDim, |
| 58 | + ) |
| 59 | + } |
| 60 | + return out |
| 61 | + } |
| 62 | + |
| 63 | + @Suppress("UNCHECKED_CAST") |
| 64 | + private fun q8_0Tensor(inputDim: Int, outputDim: Int, seed: Int): Tensor<FP32, Float> { |
| 65 | + val blocksPerInputDim = inputDim / blockSize |
| 66 | + val bytes = randomQ8_0Bytes(blocksPerInputDim, outputDim, seed) |
| 67 | + // Logical shape of a Q8_0 weight tensor is [inputDim, outputDim]. |
| 68 | + val data = Q8_0BlockTensorData(Shape(inputDim, outputDim), bytes) |
| 69 | + return ctx.fromData(data as TensorData<FP32, Float>, FP32::class) |
| 70 | + } |
| 71 | + |
| 72 | + private fun assertDispatchMatchesScalar( |
| 73 | + batchSize: Int, inputDim: Int, outputDim: Int, seed: Int, |
| 74 | + tolPerBlock: Float = 1e-2f, |
| 75 | + ) { |
| 76 | + val rng = Random(seed) |
| 77 | + val inputFloats = FloatArray(batchSize * inputDim) { rng.nextFloat() - 0.5f } |
| 78 | + val blocksPerInputDim = inputDim / blockSize |
| 79 | + |
| 80 | + val weightBytes = randomQ8_0Bytes(blocksPerInputDim, outputDim, seed) |
| 81 | + val weight = q8_0Tensor(inputDim, outputDim, seed).let { t -> |
| 82 | + // q8_0Tensor regenerates bytes from seed → use the SAME byte buffer |
| 83 | + // for the scalar reference path so the comparison is honest. |
| 84 | + @Suppress("UNCHECKED_CAST") |
| 85 | + val td = Q8_0BlockTensorData(Shape(inputDim, outputDim), weightBytes) as TensorData<FP32, Float> |
| 86 | + ctx.fromData(td, FP32::class) |
| 87 | + } |
| 88 | + val input = ctx.fromFloatArray<FP32, Float>( |
| 89 | + Shape(batchSize, inputDim), FP32::class, inputFloats, |
| 90 | + ) |
| 91 | + |
| 92 | + val out = ctx.ops.matmul(input, weight) |
| 93 | + val outArr = out.data.copyToFloatArray() |
| 94 | + |
| 95 | + val expected = ScalarQ8_0_reference(inputFloats, weightBytes, inputDim, outputDim, batchSize) |
| 96 | + |
| 97 | + val tol = (tolPerBlock * blocksPerInputDim.coerceAtLeast(1)).coerceAtLeast(tolPerBlock) |
| 98 | + for (i in expected.indices) { |
| 99 | + val diff = abs(expected[i] - outArr[i]) |
| 100 | + assertTrue( |
| 101 | + diff <= tol, |
| 102 | + "dispatch mismatch at $i: expected=${expected[i]} got=${outArr[i]} diff=$diff tol=$tol", |
| 103 | + ) |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + @Test |
| 108 | + fun single_batch_matmul_against_q8_0_weight_routes_correctly() { |
| 109 | + // batchSize=1 hits the optimized "no copyOfRange" branch in chooseQuantizedMatmul. |
| 110 | + assertDispatchMatchesScalar(batchSize = 1, inputDim = 128, outputDim = 64, seed = 1) |
| 111 | + } |
| 112 | + |
| 113 | + @Test |
| 114 | + fun multi_batch_matmul_against_q8_0_weight_routes_correctly() { |
| 115 | + // batchSize>1 exercises the per-row copyOfRange branch. |
| 116 | + assertDispatchMatchesScalar(batchSize = 3, inputDim = 256, outputDim = 32, seed = 2) |
| 117 | + } |
| 118 | + |
| 119 | + @Test |
| 120 | + fun llm_typical_attention_proj_matmul_routes_correctly() { |
| 121 | + // Realistic attention-projection size (matvec at dim×dim). |
| 122 | + assertDispatchMatchesScalar(batchSize = 1, inputDim = 512, outputDim = 512, seed = 3) |
| 123 | + } |
| 124 | +} |
0 commit comments