Skip to content

Commit 9cc73aa

Browse files
Merge pull request #562 from SKaiNET-developers/feature/jvm-q4k-simd-spi
feat(kernel): SIMD-fused Q4_K matmul kernel + Q4KMatmulKernel SPI
2 parents db00c95 + 8df65b8 commit 9cc73aa

7 files changed

Lines changed: 497 additions & 8 deletions

File tree

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package sk.ainet.bench
2+
3+
import java.util.concurrent.TimeUnit
4+
import kotlin.random.Random
5+
import org.openjdk.jmh.annotations.Benchmark
6+
import org.openjdk.jmh.annotations.BenchmarkMode
7+
import org.openjdk.jmh.annotations.Level
8+
import org.openjdk.jmh.annotations.Mode
9+
import org.openjdk.jmh.annotations.OutputTimeUnit
10+
import org.openjdk.jmh.annotations.Param
11+
import org.openjdk.jmh.annotations.Scope
12+
import org.openjdk.jmh.annotations.Setup
13+
import org.openjdk.jmh.annotations.State
14+
import sk.ainet.exec.kernel.PanamaVectorQ4KMatmulKernel
15+
16+
/**
 * F32-input × Q4_K-weight matmul benchmark for the SIMD-fused Panama
 * kernel ([PanamaVectorQ4KMatmulKernel]) at typical LLM matmul shapes
 * for Gemma 4 E2B Q4_K_M:
 *  - 1024 x 1024 — small attention projection
 *  - 4096 x 4096 — hidden→hidden / FFN gate
 *  - 4096 x 1024 — hidden→KV slice
 *
 * Each `inputDim` must be a multiple of 256 (the Q4_K block size).
 * The packed layout is input-block-major (`(blockIdx * outputDim + o) * 144`).
 *
 * Parity against the prior `JvmQuantizedVectorKernels.matmulQ4_KVec`
 * partial-vec implementation is covered by `PanamaVectorQ4KMatmulKernelTest`,
 * which exercises both code paths; that legacy kernel is `internal`,
 * which keeps it out of this cross-module bench harness.
 */
@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
open class QuantizedMatmulBench {

    @Param("1024-1024", "4096-1024", "4096-4096")
    var shape: String = "4096-4096"

    private var inputDim: Int = 0
    private var outputDim: Int = 0
    private lateinit var input: FloatArray
    private lateinit var packedWeights: ByteArray
    private lateinit var output: FloatArray

    @Setup(Level.Trial)
    fun setup() {
        // Shape params are encoded as "inputDim-outputDim".
        val (rows, cols) = shape.split("-").map(String::toInt)
        inputDim = rows
        outputDim = cols
        require(inputDim % 256 == 0) { "inputDim must be multiple of 256, got $inputDim" }

        val totalBlocks = (inputDim / 256) * outputDim
        packedWeights = ByteArray(totalBlocks * 144)
        Random(42).nextBytes(packedWeights)
        // Pin each block's d / dMin halves to 1.0f16 (LE bytes 0x00 0x3C) so
        // dequantized magnitudes stay finite across steady-state iterations.
        repeat(totalBlocks) { block ->
            val base = block * 144
            packedWeights[base] = 0x00.toByte()
            packedWeights[base + 1] = 0x3C.toByte()
            packedWeights[base + 2] = 0x00.toByte()
            packedWeights[base + 3] = 0x3C.toByte()
        }
        input = FloatArray(inputDim) { ((it % 251) - 125).toFloat() / 127f }
        output = FloatArray(outputDim)
    }

    @Benchmark
    fun matmul_q4k_panama(): FloatArray {
        PanamaVectorQ4KMatmulKernel.matmul(
            input = input, inputOffset = 0,
            weight = packedWeights, weightByteOffset = 0,
            inputDim = inputDim, outputDim = outputDim,
            output = output, outputOffset = 0,
        )
        return output
    }
}

skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,14 @@ public interface KernelProvider {
4242
* provider does not specialize matmul.
4343
*/
4444
public fun matmulFp32(): Fp32MatmulKernel?
45+
46+
/**
 * F32 × Q4_K matmul kernel exposed by this provider, or `null` if
 * this provider does not specialize Q4_K.
 *
 * The default implementation returns `null` so providers that
 * pre-date this accessor (e.g. older custom providers and the
 * scalar reference) keep compiling without change — callers then
 * cascade to a lower-priority provider that does carry the kernel.
 */
public fun matmulQ4K(): Q4KMatmulKernel? = null
4555
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package sk.ainet.backend.api.kernel
2+
3+
/**
 * F32 input × Q4_K-packed weights matrix-vector multiply, in canonical
 * ggml super-block layout.
 *
 *   output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j])
 *   for j ∈ [0, inputDim), o ∈ [0, outputDim)
 *
 * Block layout (256-element super-block, 144 bytes/block; see
 * [sk.ainet.lang.tensor.data.Q4_KTensorData] kdoc for the byte map):
 *  - bytes 0..1   : `d`    (super-block scale, FP16 LE)
 *  - bytes 2..3   : `dMin` (super-block min-scale, FP16 LE)
 *  - bytes 4..15  : 12 bytes of packed (6-bit scaleIdx, 6-bit minIdx) for
 *                   8 sub-blocks via ggml's `get_scale_min_k4` mixing
 *  - bytes 16..143: 128 bytes of 4-bit codes, *strided* in 4 groups of
 *                   32 bytes — each byte's lo nibble belongs to one
 *                   sub-block and the hi nibble of the same byte belongs
 *                   to the *next* sub-block at the same intra-group index.
 *
 * Per sub-block s ∈ 0..7:
 *   `scale[s]  = d * scaleIdx[s]`
 *   `offset[s] = dMin * minIdx[s]`
 *   per element: `dequant = code * scale[s] - offset[s]`
 *
 * The lazy-`dmin` accumulation trick (used by every well-tuned Q4_K
 * kernel including ggml's reference) avoids subtracting `offset` per
 * element by tracking `Σ(input · code)` and `Σ(input)` per sub-block
 * and combining as `scale * codeSum − offset * inputSum` once.
 *
 * Contract: implementations MUST NOT mutate `input` or `weight`. They
 * MAY assume the arrays do not alias each other or `output`. They MUST
 * fully write the `outputDim` floats starting at `output[outputOffset]`.
 *
 * Packed-weight row-major contract: `weight` holds blocks laid out
 * `(blockIdx * outputDim + o) * 144` for output row `o` and input
 * block index `blockIdx`. This matches `Q4_KBlockTensorData.packedData`
 * and `JvmQuantizedVectorKernels.matmulQ4_KVec`.
 *
 * `inputDim` MUST be a multiple of 256 (the Q4_K block size).
 */
public interface Q4KMatmulKernel {
    /**
     * Computes one matrix-vector product as described on the interface.
     *
     * @param input FP32 input vector (single row).
     * @param inputOffset element offset into [input] where the row starts.
     * @param weight packed Q4_K bytes for the full `outputDim × inputDim` weight tensor.
     * @param weightByteOffset byte offset into [weight] where block (0, 0) starts.
     * @param inputDim contraction dimension (must be a multiple of 256).
     * @param outputDim number of output cells.
     * @param output FP32 output vector.
     * @param outputOffset element offset into [output] where the row starts.
     */
    public fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    )
}

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sk.ainet.exec.kernel
22

33
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
44
import sk.ainet.backend.api.kernel.KernelProvider
5+
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
56
import sk.ainet.exec.tensor.ops.JvmCpuBackendConfig
67

78
/**
@@ -37,6 +38,9 @@ public object PanamaVectorKernelProvider : KernelProvider {
3738
// FP32 matmul kernel — present only when the Vector API is usable on this JVM.
override fun matmulFp32(): Fp32MatmulKernel? =
    PanamaVectorMatmulKernel.takeIf { isAvailable() }
3940

41+
// Q4_K matmul kernel — present only when the Vector API is usable on this JVM.
override fun matmulQ4K(): Q4KMatmulKernel? =
    PanamaVectorQ4KMatmulKernel.takeIf { isAvailable() }
43+
4044
private fun isVectorApiClassLoaded(): Boolean = runCatching {
4145
Class.forName("jdk.incubator.vector.FloatVector")
4246
Class.forName("jdk.incubator.vector.VectorSpecies")
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package sk.ainet.exec.kernel
2+
3+
import jdk.incubator.vector.ByteVector
4+
import jdk.incubator.vector.FloatVector
5+
import jdk.incubator.vector.VectorOperators
6+
import jdk.incubator.vector.VectorSpecies
7+
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
8+
import sk.ainet.exec.tensor.ops.parallelChunks
9+
10+
/**
 * SIMD-vectorized Q4_K matmul on the JDK Vector API.
 *
 * Pipeline per 32-byte qs slab (which carries two adjacent sub-blocks
 * — sub-block `2j` in lo nibbles, sub-block `2j+1` in hi nibbles):
 *  1. `ByteVector.fromArray(byteSpeciesForFloat, weight, qsRegion + idx)` — single load.
 *  2. `loNibVec = byteVec.and(0x0F.toByte())`,
 *     `hiNibVec = byteVec.lanewise(LSHR, 4)` — extract both nibbles.
 *  3. `castShape(floatSpecies, 0)` — widen + I2F.
 *  4. `inputVec.fma(codeFloatVec, codeAcc)` — accumulate `Σ(input·code)`
 *     per sub-block; track `inputAcc = Σ(input)` separately for the
 *     lazy-`dmin` correction.
 *  5. After all super-blocks for a given output cell, sum across
 *     sub-blocks: `acc += scale[s] · codeSum[s] − offset[s] · inputSum[s]`
 *     with `scale[s] = d · scaleIdx[s]` and `offset[s] = dMin · minIdx[s]`.
 *
 * Compared to [sk.ainet.exec.tensor.ops.JvmQuantizedVectorKernels.matmulQ4_KVec]:
 *  - Replaces the scalar 32-iteration nibble unpack into a scratch
 *    `FloatArray` with a single `ByteVector` load + `castShape` per
 *    `floatSpecies.length()` elements.
 *  - Folds lo + hi nibble passes into a single byte load (the existing
 *    helper called the byte-load helper twice — once per nibble).
 *
 * Numerical equivalence with the existing partial-vec kernel is
 * within FMA + reordered-reduction tolerance; verified via parity
 * tests at `1e-5 · inputDim`.
 */
public object PanamaVectorQ4KMatmulKernel : Q4KMatmulKernel {

    // Q4_K geometry: 256 codes per super-block, split into 8 sub-blocks
    // of 32, packed into 144 bytes (2×FP16 + 12 scale bytes + 128 qs bytes).
    private const val BLOCK_SIZE = 256
    private const val SUB_BLOCK_SIZE = 32
    private const val SUB_BLOCKS_PER_BLOCK = 8
    private const val BYTES_PER_BLOCK = 144

    private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED

    /**
     * Byte species sized so `castShape(floatSpecies, 0)` consumes
     * exactly `floatSpecies.length()` bytes — same convention as
     * [sk.ainet.exec.tensor.ops.JvmQuantizedVectorKernels.byteSpeciesForFloat].
     */
    private val byteSpeciesForFloat: VectorSpecies<Byte> = when (floatSpecies.length()) {
        16 -> ByteVector.SPECIES_128
        else -> ByteVector.SPECIES_64 // covers 4-wide (NEON) and 8-wide (AVX2)
    }

    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "PanamaVectorQ4KMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        if (outputDim == 0 || inputDim == 0) return
        val blocksPerInputDim = inputDim / BLOCK_SIZE

        // Output cells are independent, so the output range is chunked
        // across worker tasks.
        parallelChunks(outputDim) { startO, endO ->
            // Per-task scratch — must not be shared across worker threads.
            val scaleIdx = IntArray(SUB_BLOCKS_PER_BLOCK)
            val minIdx = IntArray(SUB_BLOCKS_PER_BLOCK)
            for (o in startO until endO) {
                var acc = 0f
                for (blockIdx in 0 until blocksPerInputDim) {
                    // Input-block-major packed layout: block (blockIdx, o).
                    val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK

                    // d, dMin (FP16 LE).
                    val dBits = (weight[blockBase + 1].toInt() and 0xFF shl 8) or
                        (weight[blockBase].toInt() and 0xFF)
                    val dMinBits = (weight[blockBase + 3].toInt() and 0xFF shl 8) or
                        (weight[blockBase + 2].toInt() and 0xFF)
                    val d = halfToFloat(dBits)
                    val dMin = halfToFloat(dMinBits)

                    // Sub-scale decode via ggml `get_scale_min_k4`:
                    // sub-blocks 0..3 use plain 6-bit fields; 4..7 mix a low
                    // 4-bit field with the top 2 bits of the earlier bytes.
                    val scalesOffset = blockBase + 4
                    for (sb in 0 until 4) {
                        scaleIdx[sb] = weight[scalesOffset + sb].toInt() and 0x3F
                        minIdx[sb] = weight[scalesOffset + sb + 4].toInt() and 0x3F
                    }
                    for (sb in 4 until 8) {
                        val low4S = weight[scalesOffset + sb + 4].toInt() and 0x0F
                        val high2S = (weight[scalesOffset + sb - 4].toInt() and 0xFF) ushr 6
                        scaleIdx[sb] = low4S or (high2S shl 4)
                        val low4M = (weight[scalesOffset + sb + 4].toInt() and 0xFF) ushr 4
                        val high2M = (weight[scalesOffset + sb].toInt() and 0xFF) ushr 6
                        minIdx[sb] = low4M or (high2M shl 4)
                    }

                    // 4 strided qs groups; each carries sbLo (lo nibbles) and sbHi (hi nibbles).
                    val codesOffset = blockBase + 16
                    val inputBlockBase = inputOffset + blockIdx * BLOCK_SIZE
                    for (groupJ in 0 until 4) {
                        val qsRegion = codesOffset + groupJ * 32
                        val sbLo = 2 * groupJ
                        val sbHi = sbLo + 1
                        val inputStartLo = inputBlockBase + sbLo * SUB_BLOCK_SIZE
                        val inputStartHi = inputStartLo + SUB_BLOCK_SIZE

                        var codeAccLo = FloatVector.zero(floatSpecies)
                        var inputAccLo = FloatVector.zero(floatSpecies)
                        var codeAccHi = FloatVector.zero(floatSpecies)
                        var inputAccHi = FloatVector.zero(floatSpecies)

                        val floatStep = floatSpecies.length()
                        val byteLoadLen = byteSpeciesForFloat.length()
                        var idx = 0

                        // SIMD body — single byte load feeds both nibble vectors.
                        // The second condition guards the full byte-species load
                        // against running past the end of `weight`; any remainder
                        // falls through to the scalar tail below.
                        while (idx + floatStep <= SUB_BLOCK_SIZE &&
                            qsRegion + idx + byteLoadLen <= weight.size
                        ) {
                            val inVecLo = FloatVector.fromArray(floatSpecies, input, inputStartLo + idx)
                            val inVecHi = FloatVector.fromArray(floatSpecies, input, inputStartHi + idx)
                            val byteVec = ByteVector.fromArray(byteSpeciesForFloat, weight, qsRegion + idx)
                            val loBytes = byteVec.and(0x0F.toByte())
                            val hiBytes = byteVec.lanewise(VectorOperators.LSHR, 4.toByte())
                            val codeVecLo = loBytes.castShape(floatSpecies, 0) as FloatVector
                            val codeVecHi = hiBytes.castShape(floatSpecies, 0) as FloatVector
                            codeAccLo = inVecLo.fma(codeVecLo, codeAccLo)
                            inputAccLo = inVecLo.add(inputAccLo)
                            codeAccHi = inVecHi.fma(codeVecHi, codeAccHi)
                            inputAccHi = inVecHi.add(inputAccHi)
                            idx += floatStep
                        }

                        var codeSumLo = codeAccLo.reduceLanes(VectorOperators.ADD)
                        var inputSumLo = inputAccLo.reduceLanes(VectorOperators.ADD)
                        var codeSumHi = codeAccHi.reduceLanes(VectorOperators.ADD)
                        var inputSumHi = inputAccHi.reduceLanes(VectorOperators.ADD)

                        // Scalar tail — only fires if floatSpecies.length() doesn't divide 32 (rare).
                        while (idx < SUB_BLOCK_SIZE) {
                            val byte = weight[qsRegion + idx].toInt() and 0xFF
                            val codeLo = (byte and 0x0F).toFloat()
                            val codeHi = (byte ushr 4).toFloat()
                            val vLo = input[inputStartLo + idx]
                            val vHi = input[inputStartHi + idx]
                            codeSumLo += vLo * codeLo
                            inputSumLo += vLo
                            codeSumHi += vHi * codeHi
                            inputSumHi += vHi
                            idx++
                        }

                        // Lazy-dmin combine: scale·Σ(input·code) − offset·Σ(input).
                        val scaleLo = d * scaleIdx[sbLo]
                        val offsetLo = dMin * minIdx[sbLo]
                        val scaleHi = d * scaleIdx[sbHi]
                        val offsetHi = dMin * minIdx[sbHi]
                        acc += codeSumLo * scaleLo - inputSumLo * offsetLo
                        acc += codeSumHi * scaleHi - inputSumHi * offsetHi
                    }
                }
                output[outputOffset + o] = acc
            }
        }
    }

    /**
     * IEEE 754 binary16 → binary32 conversion. Mirrors the helper used
     * inside `JvmQuantizedVectorKernels` and `Q4_KTensorData` — kept
     * private to this file rather than depending on either, since both
     * are `internal` in their respective modules.
     */
    private fun halfToFloat(hbits: Int): Float {
        val sign = (hbits ushr 15) and 0x1
        val exp = (hbits ushr 10) and 0x1F
        val frac = hbits and 0x3FF
        return when {
            exp == 0 -> {
                // Zero or subnormal: value = frac · 2⁻²⁴ (= frac/1024 · 2⁻¹⁴).
                if (frac == 0) {
                    if (sign == 0) 0.0f else -0.0f
                } else {
                    val f = frac / 1024.0f * (1.0f / 16384.0f)
                    if (sign == 0) f else -f
                }
            }
            exp == 0x1F -> {
                // Infinity (frac == 0) or NaN.
                if (frac == 0) {
                    if (sign == 0) Float.POSITIVE_INFINITY else Float.NEGATIVE_INFINITY
                } else {
                    Float.NaN
                }
            }
            else -> {
                // Normal: re-bias exponent (15 → 127) and left-align the mantissa.
                val bits = (sign shl 31) or ((exp - 15 + 127) shl 23) or (frac shl 13)
                Float.fromBits(bits)
            }
        }
    }
}

0 commit comments

Comments
 (0)