Skip to content

Commit 586db77

Browse files
Merge pull request #709 from SKaiNET-developers/feature/708-q5-packed-matmul-kernels
feat(backend-cpu): packed Q5_1 / Q5_0 matmul kernels + lazy transpose
2 parents 0fe0b2c + d9f760f commit 586db77

6 files changed

Lines changed: 575 additions & 0 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
3131
import sk.ainet.lang.tensor.data.Q4_KTensorData
3232
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
3333
import sk.ainet.lang.tensor.data.Q6_KTensorData
34+
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
35+
import sk.ainet.lang.tensor.data.Q5_1TensorData
36+
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
37+
import sk.ainet.lang.tensor.data.Q5_0TensorData
3438
import sk.ainet.lang.tensor.data.TensorData
3539
import sk.ainet.lang.types.DType
3640
import sk.ainet.lang.types.FP16
@@ -224,6 +228,21 @@ internal class DefaultCpuOpsJvm(
224228
@Suppress("UNCHECKED_CAST")
225229
return newTensor(transposed as TensorData<T, V>, tensor.dtype, tensor)
226230
}
231+
// Q5_1 / Q5_0 packed bytes use a row-major `[out, in]` layout that the
232+
// `matmulQ5_1Vec` / `matmulQ5_0Vec` kernels index by output row, so the
233+
// transpose is a pure shape swap — the same bytes give the right values
234+
// under the swapped shape (lets `ops.matmul(x, ops.transpose(W))` run
235+
// without a dequant round-trip).
236+
if (data is Q5_1TensorData) {
237+
val transposed = Q5_1BlockTensorData(Shape(cols, rows), data.packedData)
238+
@Suppress("UNCHECKED_CAST")
239+
return newTensor(transposed as TensorData<T, V>, tensor.dtype, tensor)
240+
}
241+
if (data is Q5_0TensorData) {
242+
val transposed = Q5_0BlockTensorData(Shape(cols, rows), data.packedData)
243+
@Suppress("UNCHECKED_CAST")
244+
return newTensor(transposed as TensorData<T, V>, tensor.dtype, tensor)
245+
}
227246
// MemorySegment FP32 fast path: physical transpose via SIMD.
228247
// Uses Arena.ofAuto() so the result segment is reclaimed by GC
229248
// when the wrapping Tensor is no longer reachable. Earlier
@@ -558,6 +577,32 @@ internal class DefaultCpuOpsJvm(
558577
@Suppress("UNCHECKED_CAST")
559578
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
560579
}
580+
is Q5_1TensorData -> {
581+
val outBuffer = FloatArray(batchSize * outputDim)
582+
for (batch in 0 until batchSize) {
583+
val batchInput = if (batchSize == 1) inputBuffer
584+
else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim)
585+
JvmQuantizedVectorKernels.matmulQ5_1Vec(
586+
batchInput, bData.packedData, inputDim, outputDim, outBuffer, batch * outputDim,
587+
)
588+
}
589+
val outData = DenseFloatArrayTensorData<T>(Shape(batchSize, outputDim), outBuffer)
590+
@Suppress("UNCHECKED_CAST")
591+
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
592+
}
593+
is Q5_0TensorData -> {
594+
val outBuffer = FloatArray(batchSize * outputDim)
595+
for (batch in 0 until batchSize) {
596+
val batchInput = if (batchSize == 1) inputBuffer
597+
else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim)
598+
JvmQuantizedVectorKernels.matmulQ5_0Vec(
599+
batchInput, bData.packedData, inputDim, outputDim, outBuffer, batch * outputDim,
600+
)
601+
}
602+
val outData = DenseFloatArrayTensorData<T>(Shape(batchSize, outputDim), outBuffer)
603+
@Suppress("UNCHECKED_CAST")
604+
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
605+
}
561606
is Q4_KTensorData -> {
562607
val outBuffer = FloatArray(batchSize * outputDim)
563608
val spiKernel = q4kMatmulKernel

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,4 +909,98 @@ internal object JvmQuantizedVectorKernels {
909909
output[outputOffset + o] = accVec.reduceLanes(VectorOperators.ADD) + accScalar
910910
}
911911
}
912+
913+
/**
 * Q5_1 matrix-vector multiply: `output = input · Wᵀ` for a packed Q5_1 weight.
 *
 * Packed weights use the row-major `[outputDim, inputDim]` layout: output row `o`
 * owns `ceil(inputDim / 32)` contiguous 24-byte blocks. Each block holds 32 weights:
 *
 * - bytes `0..1`: scale `d`, fp16 little-endian
 * - bytes `2..3`: offset `m`, fp16 little-endian
 * - bytes `4..7`: `qh`, a little-endian 32-bit word whose bit `i` is the fifth
 *   (high) bit of the code for element `i`
 * - bytes `8..23`: `qs`, two 4-bit codes per byte — the low nibble of byte `j` is
 *   element `j`, the high nibble is element `j + 16`
 *
 * A weight is reconstructed as `w = d * (code + (highBit shl 4)) + m`, the formula
 * the original author attributes to `DequantOps.dequantQ5_1FromBytes`. The kernel is
 * scalar and never materialises the dequantized weights.
 *
 * When `inputDim` is not a multiple of 32 the last block of every row is only
 * partially used: elements at or beyond `inputDim` are skipped instead of being
 * read from `input`.
 *
 * @param input activation vector; the first `inputDim` entries are read
 * @param packedWeights packed Q5_1 blocks, `outputDim` rows back to back
 * @param inputDim number of logical weights per output row
 * @param outputDim number of output rows to compute
 * @param output destination; entries `outputOffset until outputOffset + outputDim` are overwritten
 * @param outputOffset index in `output` of the first result
 */
fun matmulQ5_1Vec(
    input: FloatArray,
    packedWeights: ByteArray,
    inputDim: Int,
    outputDim: Int,
    output: FloatArray,
    outputOffset: Int = 0,
) {
    val bytesPerBlock = 24
    val blocksPerInputDim = (inputDim + 31) / 32
    for (o in 0 until outputDim) {
        var acc = 0f
        val rowBase = o * blocksPerInputDim * bytesPerBlock
        for (blk in 0 until blocksPerInputDim) {
            val base = rowBase + blk * bytesPerBlock
            val d = halfToFloat(((packedWeights[base + 1].toInt() and 0xFF) shl 8) or (packedWeights[base].toInt() and 0xFF))
            val m = halfToFloat(((packedWeights[base + 3].toInt() and 0xFF) shl 8) or (packedWeights[base + 2].toInt() and 0xFF))
            // Assemble the four high-bit bytes into one little-endian word: bit i of
            // `qh` is the fifth bit of element i. Same bits as indexing the bytes
            // individually, without allocating an array for every block.
            val qh = (packedWeights[base + 4].toInt() and 0xFF) or
                ((packedWeights[base + 5].toInt() and 0xFF) shl 8) or
                ((packedWeights[base + 6].toInt() and 0xFF) shl 16) or
                ((packedWeights[base + 7].toInt() and 0xFF) shl 24)
            val qsBase = base + 8
            val inBase = blk * 32
            for (j in 0 until 16) {
                val q = packedWeights[qsBase + j].toInt() and 0xFF
                val lo = q and 0x0F
                val hi = q ushr 4
                val bitLo = (qh ushr j) and 0x01
                val bitHi = (qh ushr (j + 16)) and 0x01
                val wLo = d * (lo + (bitLo shl 4)) + m
                val wHi = d * (hi + (bitHi shl 4)) + m
                val idxLo = inBase + j
                val idxHi = idxLo + 16
                if (idxHi < inputDim) {
                    // Both halves of the pair are real elements (always true for full blocks).
                    acc += input[idxLo] * wLo + input[idxHi] * wHi
                } else if (idxLo < inputDim) {
                    // Partial tail block: only the low-nibble element lies inside the row.
                    acc += input[idxLo] * wLo
                }
            }
        }
        output[outputOffset + o] = acc
    }
}
961+
962+
/**
 * Q5_0 matrix-vector multiply: `output = input · Wᵀ` for a packed Q5_0 weight.
 *
 * Packed weights use the row-major `[outputDim, inputDim]` layout: output row `o`
 * owns `ceil(inputDim / 32)` contiguous 22-byte blocks. Each block holds 32 weights:
 *
 * - bytes `0..1`: scale `d`, fp16 little-endian
 * - bytes `2..5`: `qh`, a little-endian 32-bit word whose bit `i` is the fifth
 *   (high) bit of the code for element `i`
 * - bytes `6..21`: `qs`, two 4-bit codes per byte — the low nibble of byte `j` is
 *   element `j`, the high nibble is element `j + 16`
 *
 * A weight is reconstructed as `w = d * (code + (highBit shl 4) - 16)`, the formula
 * the original author attributes to `DequantOps.dequantQ5_0FromBytes`. The kernel is
 * scalar and never materialises the dequantized weights.
 *
 * When `inputDim` is not a multiple of 32 the last block of every row is only
 * partially used: elements at or beyond `inputDim` are skipped instead of being
 * read from `input`.
 *
 * @param input activation vector; the first `inputDim` entries are read
 * @param packedWeights packed Q5_0 blocks, `outputDim` rows back to back
 * @param inputDim number of logical weights per output row
 * @param outputDim number of output rows to compute
 * @param output destination; entries `outputOffset until outputOffset + outputDim` are overwritten
 * @param outputOffset index in `output` of the first result
 */
fun matmulQ5_0Vec(
    input: FloatArray,
    packedWeights: ByteArray,
    inputDim: Int,
    outputDim: Int,
    output: FloatArray,
    outputOffset: Int = 0,
) {
    val bytesPerBlock = 22
    val blocksPerInputDim = (inputDim + 31) / 32
    for (o in 0 until outputDim) {
        var acc = 0f
        val rowBase = o * blocksPerInputDim * bytesPerBlock
        for (blk in 0 until blocksPerInputDim) {
            val base = rowBase + blk * bytesPerBlock
            val d = halfToFloat(((packedWeights[base + 1].toInt() and 0xFF) shl 8) or (packedWeights[base].toInt() and 0xFF))
            // Assemble the four high-bit bytes into one little-endian word: bit i of
            // `qh` is the fifth bit of element i. Same bits as indexing the bytes
            // individually, without allocating an array for every block.
            val qh = (packedWeights[base + 2].toInt() and 0xFF) or
                ((packedWeights[base + 3].toInt() and 0xFF) shl 8) or
                ((packedWeights[base + 4].toInt() and 0xFF) shl 16) or
                ((packedWeights[base + 5].toInt() and 0xFF) shl 24)
            val qsBase = base + 6
            val inBase = blk * 32
            for (j in 0 until 16) {
                val q = packedWeights[qsBase + j].toInt() and 0xFF
                val lo = q and 0x0F
                val hi = q ushr 4
                val bitLo = (qh ushr j) and 0x01
                val bitHi = (qh ushr (j + 16)) and 0x01
                // Codes are 5-bit unsigned (0..31); subtracting 16 recentres them on zero.
                val wLo = d * (lo + (bitLo shl 4) - 16)
                val wHi = d * (hi + (bitHi shl 4) - 16)
                val idxLo = inBase + j
                val idxHi = idxLo + 16
                if (idxHi < inputDim) {
                    // Both halves of the pair are real elements (always true for full blocks).
                    acc += input[idxLo] * wLo + input[idxHi] * wHi
                } else if (idxLo < inputDim) {
                    // Partial tail block: only the low-nibble element lies inside the row.
                    acc += input[idxLo] * wLo
                }
            }
        }
        output[outputOffset + o] = acc
    }
}
9121006
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.random.Random
4+
import kotlin.test.Test
5+
import kotlin.test.assertEquals
6+
import kotlin.test.assertTrue
7+
import sk.ainet.context.DirectCpuExecutionContext
8+
import sk.ainet.lang.tensor.Shape
9+
import sk.ainet.lang.tensor.Tensor
10+
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
11+
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
12+
import sk.ainet.lang.tensor.data.TensorData
13+
import sk.ainet.lang.types.FP32
14+
15+
/**
16+
* Validates the packed Q5_1 / Q5_0 matmul kernels + lazy transpose: feeding a packed
17+
* weight through `ops.matmul(x, ops.transpose(W))` must match feeding the FP32-dequantized
18+
* weight through the same path. The FP32 reference is dequantized inline (independent of the
19+
* `Q5_*BlockTensorData.dequantizeBlock` code under test), matching ggml / `DequantOps`.
20+
*/
21+
class Q5MatmulDispatchTest {
22+
23+
private val ctx = DirectCpuExecutionContext()
24+
25+
private fun f16(v: Float): Int {
26+
// float -> IEEE half bits (round-to-nearest-even, good enough for test weights)
27+
val bits = v.toRawBits()
28+
val sign = (bits ushr 16) and 0x8000
29+
var expo = ((bits ushr 23) and 0xFF) - 127 + 15
30+
val mant = bits and 0x7FFFFF
31+
if (expo <= 0) return sign // flush tiny to signed zero
32+
if (expo >= 31) return sign or 0x7C00 // inf
33+
return sign or (expo shl 10) or (mant ushr 13)
34+
}
35+
36+
private fun halfToFloat(h: Int): Float {
37+
val sign = (h and 0x8000) shl 16
38+
val exp = (h and 0x7C00) shr 10
39+
val mant = h and 0x03FF
40+
return when (exp) {
41+
0 -> Float.fromBits(sign) // (subnormals flushed by f16() above)
42+
31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13))
43+
else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13))
44+
}
45+
}
46+
47+
// --- Q5_1: 24 bytes/block (d, m, qh[4], qs[16]) ---------------------------------------
48+
49+
private fun randomQ5_1Block(rng: Random, out: ByteArray, off: Int) {
50+
val d = f16(0.02f + rng.nextFloat() * 0.05f)
51+
val m = f16(-0.3f + rng.nextFloat() * 0.6f)
52+
out[off] = (d and 0xFF).toByte(); out[off + 1] = ((d ushr 8) and 0xFF).toByte()
53+
out[off + 2] = (m and 0xFF).toByte(); out[off + 3] = ((m ushr 8) and 0xFF).toByte()
54+
for (k in 0 until 4) out[off + 4 + k] = rng.nextInt(256).toByte() // qh
55+
for (k in 0 until 16) out[off + 8 + k] = rng.nextInt(256).toByte() // qs
56+
}
57+
58+
private fun dequantQ5_1Block(b: ByteArray, off: Int, dst: FloatArray, dstOff: Int) {
59+
val d = halfToFloat(((b[off + 1].toInt() and 0xFF) shl 8) or (b[off].toInt() and 0xFF))
60+
val m = halfToFloat(((b[off + 3].toInt() and 0xFF) shl 8) or (b[off + 2].toInt() and 0xFF))
61+
val qh = intArrayOf(b[off + 4].toInt() and 0xFF, b[off + 5].toInt() and 0xFF, b[off + 6].toInt() and 0xFF, b[off + 7].toInt() and 0xFF)
62+
for (j in 0 until 16) {
63+
val q = b[off + 8 + j].toInt() and 0xFF
64+
val lo = q and 0x0F; val hi = q ushr 4
65+
val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01
66+
val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01
67+
dst[dstOff + j] = d * (lo + (bitLo shl 4)) + m
68+
dst[dstOff + 16 + j] = d * (hi + (bitHi shl 4)) + m
69+
}
70+
}
71+
72+
// --- Q5_0: 22 bytes/block (d, qh[4], qs[16]), symmetric -16 --------------------------
73+
74+
private fun randomQ5_0Block(rng: Random, out: ByteArray, off: Int) {
75+
val d = f16(0.02f + rng.nextFloat() * 0.05f)
76+
out[off] = (d and 0xFF).toByte(); out[off + 1] = ((d ushr 8) and 0xFF).toByte()
77+
for (k in 0 until 4) out[off + 2 + k] = rng.nextInt(256).toByte()
78+
for (k in 0 until 16) out[off + 6 + k] = rng.nextInt(256).toByte()
79+
}
80+
81+
private fun dequantQ5_0Block(b: ByteArray, off: Int, dst: FloatArray, dstOff: Int) {
82+
val d = halfToFloat(((b[off + 1].toInt() and 0xFF) shl 8) or (b[off].toInt() and 0xFF))
83+
val qh = intArrayOf(b[off + 2].toInt() and 0xFF, b[off + 3].toInt() and 0xFF, b[off + 4].toInt() and 0xFF, b[off + 5].toInt() and 0xFF)
84+
for (j in 0 until 16) {
85+
val q = b[off + 6 + j].toInt() and 0xFF
86+
val lo = q and 0x0F; val hi = q ushr 4
87+
val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01
88+
val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01
89+
dst[dstOff + j] = d * (lo + (bitLo shl 4) - 16)
90+
dst[dstOff + 16 + j] = d * (hi + (bitHi shl 4) - 16)
91+
}
92+
}
93+
94+
private fun assertPackedMatchesFp32(
95+
encoding: String, inputDim: Int, outputDim: Int, batchSize: Int, seed: Int,
96+
) {
97+
val rng = Random(seed)
98+
val blocksPerRow = inputDim / 32
99+
val bytesPerBlock = if (encoding == "Q5_1") 24 else 22
100+
val bytes = ByteArray(outputDim * blocksPerRow * bytesPerBlock)
101+
val wf = FloatArray(outputDim * inputDim) // row-major [out, in]
102+
for (o in 0 until outputDim) {
103+
for (blk in 0 until blocksPerRow) {
104+
val off = (o * blocksPerRow + blk) * bytesPerBlock
105+
val dstOff = o * inputDim + blk * 32
106+
if (encoding == "Q5_1") { randomQ5_1Block(rng, bytes, off); dequantQ5_1Block(bytes, off, wf, dstOff) }
107+
else { randomQ5_0Block(rng, bytes, off); dequantQ5_0Block(bytes, off, wf, dstOff) }
108+
}
109+
}
110+
111+
val packed: Tensor<FP32, Float> = if (encoding == "Q5_1")
112+
ctx.fromData(Q5_1BlockTensorData(Shape(outputDim, inputDim), bytes) as TensorData<FP32, Float>, FP32::class)
113+
else
114+
ctx.fromData(Q5_0BlockTensorData(Shape(outputDim, inputDim), bytes) as TensorData<FP32, Float>, FP32::class)
115+
val fp32 = ctx.fromFloatArray<FP32, Float>(Shape(outputDim, inputDim), FP32::class, wf)
116+
117+
val input = ctx.fromFloatArray<FP32, Float>(
118+
Shape(batchSize, inputDim), FP32::class, FloatArray(batchSize * inputDim) { (rng.nextFloat() - 0.5f) },
119+
)
120+
val outPacked = ctx.ops.matmul(input, ctx.ops.transpose(packed)).data.copyToFloatArray()
121+
val outFp32 = ctx.ops.matmul(input, ctx.ops.transpose(fp32)).data.copyToFloatArray()
122+
123+
assertEquals(outFp32.size, outPacked.size, "$encoding output size")
124+
var maxErr = 0f
125+
for (i in outFp32.indices) maxErr = maxOf(maxErr, kotlin.math.abs(outFp32[i] - outPacked[i]))
126+
assertTrue(maxErr < 1e-3f, "$encoding packed matmul deviates from FP32 dequant: maxErr=$maxErr")
127+
}
128+
129+
@Test fun q5_1_matmul_matches_fp32_dequant_single_batch() =
130+
assertPackedMatchesFp32("Q5_1", inputDim = 128, outputDim = 64, batchSize = 1, seed = 1)
131+
132+
@Test fun q5_1_matmul_matches_fp32_dequant_multi_batch() =
133+
assertPackedMatchesFp32("Q5_1", inputDim = 256, outputDim = 96, batchSize = 3, seed = 2)
134+
135+
@Test fun q5_0_matmul_matches_fp32_dequant_single_batch() =
136+
assertPackedMatchesFp32("Q5_0", inputDim = 128, outputDim = 64, batchSize = 1, seed = 3)
137+
138+
@Test fun q5_0_matmul_matches_fp32_dequant_multi_batch() =
139+
assertPackedMatchesFp32("Q5_0", inputDim = 192, outputDim = 48, batchSize = 2, seed = 4)
140+
}

0 commit comments

Comments
 (0)