|
| 1 | +package sk.ainet.exec.tensor.ops |
| 2 | + |
| 3 | +import kotlin.random.Random |
| 4 | +import kotlin.test.Test |
| 5 | +import kotlin.test.assertEquals |
| 6 | +import kotlin.test.assertTrue |
| 7 | +import sk.ainet.context.DirectCpuExecutionContext |
| 8 | +import sk.ainet.lang.tensor.Shape |
| 9 | +import sk.ainet.lang.tensor.Tensor |
| 10 | +import sk.ainet.lang.tensor.data.Q5_0BlockTensorData |
| 11 | +import sk.ainet.lang.tensor.data.Q5_1BlockTensorData |
| 12 | +import sk.ainet.lang.tensor.data.TensorData |
| 13 | +import sk.ainet.lang.types.FP32 |
| 14 | + |
| 15 | +/** |
| 16 | + * Validates the packed Q5_1 / Q5_0 matmul kernels + lazy transpose: feeding a packed |
| 17 | + * weight through `ops.matmul(x, ops.transpose(W))` must match feeding the FP32-dequantized |
| 18 | + * weight through the same path. The FP32 reference is dequantized inline (independent of the |
| 19 | + * `Q5_*BlockTensorData.dequantizeBlock` code under test), matching ggml / `DequantOps`. |
| 20 | + */ |
| 21 | +class Q5MatmulDispatchTest { |
| 22 | + |
| 23 | + private val ctx = DirectCpuExecutionContext() |
| 24 | + |
| 25 | + private fun f16(v: Float): Int { |
| 26 | + // float -> IEEE half bits (round-to-nearest-even, good enough for test weights) |
| 27 | + val bits = v.toRawBits() |
| 28 | + val sign = (bits ushr 16) and 0x8000 |
| 29 | + var expo = ((bits ushr 23) and 0xFF) - 127 + 15 |
| 30 | + val mant = bits and 0x7FFFFF |
| 31 | + if (expo <= 0) return sign // flush tiny to signed zero |
| 32 | + if (expo >= 31) return sign or 0x7C00 // inf |
| 33 | + return sign or (expo shl 10) or (mant ushr 13) |
| 34 | + } |
| 35 | + |
| 36 | + private fun halfToFloat(h: Int): Float { |
| 37 | + val sign = (h and 0x8000) shl 16 |
| 38 | + val exp = (h and 0x7C00) shr 10 |
| 39 | + val mant = h and 0x03FF |
| 40 | + return when (exp) { |
| 41 | + 0 -> Float.fromBits(sign) // (subnormals flushed by f16() above) |
| 42 | + 31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13)) |
| 43 | + else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13)) |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + // --- Q5_1: 24 bytes/block (d, m, qh[4], qs[16]) --------------------------------------- |
| 48 | + |
| 49 | + private fun randomQ5_1Block(rng: Random, out: ByteArray, off: Int) { |
| 50 | + val d = f16(0.02f + rng.nextFloat() * 0.05f) |
| 51 | + val m = f16(-0.3f + rng.nextFloat() * 0.6f) |
| 52 | + out[off] = (d and 0xFF).toByte(); out[off + 1] = ((d ushr 8) and 0xFF).toByte() |
| 53 | + out[off + 2] = (m and 0xFF).toByte(); out[off + 3] = ((m ushr 8) and 0xFF).toByte() |
| 54 | + for (k in 0 until 4) out[off + 4 + k] = rng.nextInt(256).toByte() // qh |
| 55 | + for (k in 0 until 16) out[off + 8 + k] = rng.nextInt(256).toByte() // qs |
| 56 | + } |
| 57 | + |
| 58 | + private fun dequantQ5_1Block(b: ByteArray, off: Int, dst: FloatArray, dstOff: Int) { |
| 59 | + val d = halfToFloat(((b[off + 1].toInt() and 0xFF) shl 8) or (b[off].toInt() and 0xFF)) |
| 60 | + val m = halfToFloat(((b[off + 3].toInt() and 0xFF) shl 8) or (b[off + 2].toInt() and 0xFF)) |
| 61 | + val qh = intArrayOf(b[off + 4].toInt() and 0xFF, b[off + 5].toInt() and 0xFF, b[off + 6].toInt() and 0xFF, b[off + 7].toInt() and 0xFF) |
| 62 | + for (j in 0 until 16) { |
| 63 | + val q = b[off + 8 + j].toInt() and 0xFF |
| 64 | + val lo = q and 0x0F; val hi = q ushr 4 |
| 65 | + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 |
| 66 | + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 |
| 67 | + dst[dstOff + j] = d * (lo + (bitLo shl 4)) + m |
| 68 | + dst[dstOff + 16 + j] = d * (hi + (bitHi shl 4)) + m |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + // --- Q5_0: 22 bytes/block (d, qh[4], qs[16]), symmetric -16 -------------------------- |
| 73 | + |
| 74 | + private fun randomQ5_0Block(rng: Random, out: ByteArray, off: Int) { |
| 75 | + val d = f16(0.02f + rng.nextFloat() * 0.05f) |
| 76 | + out[off] = (d and 0xFF).toByte(); out[off + 1] = ((d ushr 8) and 0xFF).toByte() |
| 77 | + for (k in 0 until 4) out[off + 2 + k] = rng.nextInt(256).toByte() |
| 78 | + for (k in 0 until 16) out[off + 6 + k] = rng.nextInt(256).toByte() |
| 79 | + } |
| 80 | + |
| 81 | + private fun dequantQ5_0Block(b: ByteArray, off: Int, dst: FloatArray, dstOff: Int) { |
| 82 | + val d = halfToFloat(((b[off + 1].toInt() and 0xFF) shl 8) or (b[off].toInt() and 0xFF)) |
| 83 | + val qh = intArrayOf(b[off + 2].toInt() and 0xFF, b[off + 3].toInt() and 0xFF, b[off + 4].toInt() and 0xFF, b[off + 5].toInt() and 0xFF) |
| 84 | + for (j in 0 until 16) { |
| 85 | + val q = b[off + 6 + j].toInt() and 0xFF |
| 86 | + val lo = q and 0x0F; val hi = q ushr 4 |
| 87 | + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 |
| 88 | + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 |
| 89 | + dst[dstOff + j] = d * (lo + (bitLo shl 4) - 16) |
| 90 | + dst[dstOff + 16 + j] = d * (hi + (bitHi shl 4) - 16) |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + private fun assertPackedMatchesFp32( |
| 95 | + encoding: String, inputDim: Int, outputDim: Int, batchSize: Int, seed: Int, |
| 96 | + ) { |
| 97 | + val rng = Random(seed) |
| 98 | + val blocksPerRow = inputDim / 32 |
| 99 | + val bytesPerBlock = if (encoding == "Q5_1") 24 else 22 |
| 100 | + val bytes = ByteArray(outputDim * blocksPerRow * bytesPerBlock) |
| 101 | + val wf = FloatArray(outputDim * inputDim) // row-major [out, in] |
| 102 | + for (o in 0 until outputDim) { |
| 103 | + for (blk in 0 until blocksPerRow) { |
| 104 | + val off = (o * blocksPerRow + blk) * bytesPerBlock |
| 105 | + val dstOff = o * inputDim + blk * 32 |
| 106 | + if (encoding == "Q5_1") { randomQ5_1Block(rng, bytes, off); dequantQ5_1Block(bytes, off, wf, dstOff) } |
| 107 | + else { randomQ5_0Block(rng, bytes, off); dequantQ5_0Block(bytes, off, wf, dstOff) } |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + val packed: Tensor<FP32, Float> = if (encoding == "Q5_1") |
| 112 | + ctx.fromData(Q5_1BlockTensorData(Shape(outputDim, inputDim), bytes) as TensorData<FP32, Float>, FP32::class) |
| 113 | + else |
| 114 | + ctx.fromData(Q5_0BlockTensorData(Shape(outputDim, inputDim), bytes) as TensorData<FP32, Float>, FP32::class) |
| 115 | + val fp32 = ctx.fromFloatArray<FP32, Float>(Shape(outputDim, inputDim), FP32::class, wf) |
| 116 | + |
| 117 | + val input = ctx.fromFloatArray<FP32, Float>( |
| 118 | + Shape(batchSize, inputDim), FP32::class, FloatArray(batchSize * inputDim) { (rng.nextFloat() - 0.5f) }, |
| 119 | + ) |
| 120 | + val outPacked = ctx.ops.matmul(input, ctx.ops.transpose(packed)).data.copyToFloatArray() |
| 121 | + val outFp32 = ctx.ops.matmul(input, ctx.ops.transpose(fp32)).data.copyToFloatArray() |
| 122 | + |
| 123 | + assertEquals(outFp32.size, outPacked.size, "$encoding output size") |
| 124 | + var maxErr = 0f |
| 125 | + for (i in outFp32.indices) maxErr = maxOf(maxErr, kotlin.math.abs(outFp32[i] - outPacked[i])) |
| 126 | + assertTrue(maxErr < 1e-3f, "$encoding packed matmul deviates from FP32 dequant: maxErr=$maxErr") |
| 127 | + } |
| 128 | + |
| 129 | + @Test fun q5_1_matmul_matches_fp32_dequant_single_batch() = |
| 130 | + assertPackedMatchesFp32("Q5_1", inputDim = 128, outputDim = 64, batchSize = 1, seed = 1) |
| 131 | + |
| 132 | + @Test fun q5_1_matmul_matches_fp32_dequant_multi_batch() = |
| 133 | + assertPackedMatchesFp32("Q5_1", inputDim = 256, outputDim = 96, batchSize = 3, seed = 2) |
| 134 | + |
| 135 | + @Test fun q5_0_matmul_matches_fp32_dequant_single_batch() = |
| 136 | + assertPackedMatchesFp32("Q5_0", inputDim = 128, outputDim = 64, batchSize = 1, seed = 3) |
| 137 | + |
| 138 | + @Test fun q5_0_matmul_matches_fp32_dequant_multi_batch() = |
| 139 | + assertPackedMatchesFp32("Q5_0", inputDim = 192, outputDim = 48, batchSize = 2, seed = 4) |
| 140 | +} |
0 commit comments