Skip to content

Commit 101e833

Browse files
Merge pull request #606 from SKaiNET-developers/feature/matmul-q8-0
Q8_0 matmul: add Q8_0MatmulKernel + scalar/Panama/native implementations
2 parents 65cb964 + e088933 commit 101e833

14 files changed

Lines changed: 862 additions & 0 deletions

File tree

skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,10 @@ public interface KernelProvider {
6060
* to the next provider when this one returns `null`.
6161
*/
6262
public fun matmulBf16(): Bf16MatmulKernel? = null
63+
64+
/**
 * Returns the F32 × Q8_0 matmul kernel this provider specializes, if any.
 *
 * The default implementation returns `null`, which signals that this
 * provider has no Q8_0 specialization; a `null` result is meant to be
 * handled with the same fall-through used by the other kernel accessors
 * on this interface.
 */
public fun matmulQ8_0(): Q8_0MatmulKernel? = null
6369
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package sk.ainet.backend.api.kernel
2+
3+
/**
 * F32 input × Q8_0-packed weights matrix-vector multiply.
 *
 *     output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j])
 *         for j ∈ [0, inputDim), o ∈ [0, outputDim)
 *
 * ## Block format
 *
 * Weights are stored in 32-element Q8_0 blocks of 34 bytes each (see the
 * KDoc of [sk.ainet.lang.tensor.data.Q8_0BlockTensorData]):
 *  - bytes 0..1  : `d`, the block scale, FP16 little-endian
 *  - bytes 2..33 : 32 signed int8 codes
 *
 * Per element: `dequant = code * d`. Q8_0 has no per-block min / offset, so
 * accumulation is a plain multiply-add chain over the 32 codes of a block
 * with one scale shared by the whole block.
 *
 * ## Block ordering inside `weight`
 *
 * The block for output row `o` and input block index `blockIdx`
 * (`blockIdx = j / 32`) starts at byte
 *
 *     weightByteOffset + (blockIdx * outputDim + o) * 34
 *
 * In other words the blocks are ordered by input block index first: all
 * `outputDim` blocks covering input elements 0..31 come first, then all
 * blocks covering input elements 32..63, and so on. The blocks of a single
 * output row are therefore `outputDim * 34` bytes apart, not contiguous.
 * This ordering is documented as matching `Q8_0BlockTensorData.packedData`
 * (NOTE(review): confirm against that class when changing either side).
 *
 * `weight` must hold at least
 * `weightByteOffset + (inputDim / 32) * outputDim * 34` bytes.
 *
 * ## Contract for implementations
 *
 *  - MUST NOT mutate `input` or `weight`.
 *  - MAY assume `input`, `weight` and `output` do not alias each other.
 *  - MUST fully write the `outputDim` floats starting at
 *    `output[outputOffset]`.
 *  - `inputDim` MUST be a multiple of 32 (the Q8_0 block size).
 */
public interface Q8_0MatmulKernel {
    /**
     * Computes `outputDim` dot products of the input row against the packed
     * Q8_0 weights and stores them in [output].
     *
     * @param input FP32 input vector (single row).
     * @param inputOffset element offset into [input] where the row starts.
     * @param weight packed Q8_0 bytes for the full `outputDim × inputDim` weight tensor.
     * @param weightByteOffset byte offset into [weight] where block (0, 0) starts.
     * @param inputDim contraction dimension (must be a multiple of 32).
     * @param outputDim number of output cells.
     * @param output FP32 output vector.
     * @param outputOffset element offset into [output] where the row starts.
     */
    public fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    )
}

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarKernelProvider.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sk.ainet.exec.kernel
33
import sk.ainet.backend.api.kernel.Bf16MatmulKernel
44
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
55
import sk.ainet.backend.api.kernel.KernelProvider
6+
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
67

78
/**
89
* Scalar (non-SIMD) [KernelProvider] — always available, lowest
@@ -23,4 +24,5 @@ public object ScalarKernelProvider : KernelProvider {
2324
/** Always reports this provider as available (unconditionally `true`). */
override fun isAvailable(): Boolean = true
2425
/** Returns [ScalarMatmulKernel]; the return type is narrowed to non-null for this provider. */
override fun matmulFp32(): Fp32MatmulKernel = ScalarMatmulKernel
2526
/** Returns [ScalarBf16MatmulKernel]; the return type is narrowed to non-null for this provider. */
override fun matmulBf16(): Bf16MatmulKernel = ScalarBf16MatmulKernel
27+
/** Returns [ScalarQ8_0MatmulKernel]; the return type is narrowed to non-null for this provider. */
override fun matmulQ8_0(): Q8_0MatmulKernel = ScalarQ8_0MatmulKernel
2628
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package sk.ainet.exec.kernel
2+
3+
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
4+
5+
/**
 * Plain-Kotlin (no SIMD) implementation of [Q8_0MatmulKernel].
 *
 * Every block is dequantised element by element and accumulated into a
 * single FP32 running sum per output row. It serves two roles:
 *
 *  - the correctness reference that accelerated kernels (Panama Vector,
 *    native FFM) must match within FP ordering tolerance;
 *  - the guaranteed fallback when no accelerated provider is registered.
 *
 * Block format (32 elements, 34 bytes):
 *  - bytes 0..1  : FP16 little-endian scale `d`
 *  - bytes 2..33 : 32 signed int8 codes
 *
 * Each element dequantises to `code * d`; there is no min / offset term.
 *
 * Speed is deliberately not a goal here; production paths are expected to
 * select the Panama Vector or native variant through the kernel registry.
 */
public object ScalarQ8_0MatmulKernel : Q8_0MatmulKernel {

    private const val BLOCK_SIZE = 32
    private const val BYTES_PER_BLOCK = 34

    /**
     * Computes `output[outputOffset + o]` for every `o` in `[0, outputDim)`.
     *
     * The block for output row `o` and input block `b` is read from byte
     * `weightByteOffset + (b * outputDim + o) * 34` of [weight].
     *
     * @throws IllegalArgumentException if [inputDim] is not a multiple of 32.
     */
    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "ScalarQ8_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        // Nothing to produce.
        if (outputDim == 0) return
        // Empty contraction: every requested output cell is an empty sum.
        if (inputDim == 0) {
            repeat(outputDim) { row -> output[outputOffset + row] = 0f }
            return
        }

        val blockCount = inputDim / BLOCK_SIZE
        for (row in 0 until outputDim) {
            var sum = 0f
            for (block in 0 until blockCount) {
                val blockStart = weightByteOffset + (block * outputDim + row) * BYTES_PER_BLOCK
                val scale = readScale(weight, blockStart)
                val firstCode = blockStart + 2
                val firstInput = inputOffset + block * BLOCK_SIZE
                for (lane in 0 until BLOCK_SIZE) {
                    // Byte.toInt() sign-extends, which is exactly the int8 code value.
                    val code = weight[firstCode + lane].toInt()
                    // Keep (input * code) * scale and the sequential accumulation order:
                    // this kernel is the numeric reference for the accelerated variants.
                    sum += input[firstInput + lane] * code * scale
                }
            }
            output[outputOffset + row] = sum
        }
    }

    /** Reads the little-endian FP16 scale stored in the first two bytes of a block. */
    private fun readScale(weight: ByteArray, blockStart: Int): Float {
        val low = weight[blockStart].toInt() and 0xFF
        val high = weight[blockStart + 1].toInt() and 0xFF
        return halfToFloat(low or (high shl 8))
    }

    /**
     * Widens the IEEE-754 half-precision value held in the low 16 bits of
     * [hbits] to FP32.
     *
     * The same conversion exists as
     * `sk.ainet.lang.tensor.data.Q4_KBlockTensorData.halfToFloat`, but that
     * helper is internal to skainet-lang-core and cannot be imported from
     * this module, so it is duplicated here.
     */
    private fun halfToFloat(hbits: Int): Float {
        val signBits = (hbits and 0x8000) shl 16
        val exponent = (hbits and 0x7C00) shr 10
        val fraction = hbits and 0x03FF

        // Infinity / NaN: all-ones exponent, fraction carried over.
        if (exponent == 31) {
            return Float.fromBits(signBits or (0xFF shl 23) or (fraction shl 13))
        }
        // Normal number: re-bias the exponent from 15 to 127.
        if (exponent != 0) {
            return Float.fromBits(signBits or ((exponent - 15 + 127) shl 23) or (fraction shl 13))
        }
        // Signed zero.
        if (fraction == 0) {
            return Float.fromBits(signBits)
        }

        // Subnormal half: shift the fraction up until the implicit leading bit
        // (bit 10) appears, lowering the exponent once per shift.
        var normalized = fraction
        var unbiased = -14
        while ((normalized and 0x400) == 0) {
            normalized = normalized shl 1
            unbiased--
        }
        normalized = normalized and 0x3FF
        return Float.fromBits(signBits or ((unbiased + 127) shl 23) or (normalized shl 13))
    }
}

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import sk.ainet.backend.api.kernel.Bf16MatmulKernel
44
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
55
import sk.ainet.backend.api.kernel.KernelProvider
66
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
7+
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
78
import sk.ainet.exec.tensor.ops.JvmCpuBackendConfig
89

910
/**
@@ -45,6 +46,9 @@ public object PanamaVectorKernelProvider : KernelProvider {
4546
/**
 * Returns the Panama Vector BF16 kernel when [isAvailable] reports `true`,
 * otherwise `null`. The kernel object is only referenced on the available
 * path.
 */
override fun matmulBf16(): Bf16MatmulKernel? = if (isAvailable()) PanamaVectorBf16MatmulKernel else null
4748

49+
/**
 * Returns the Panama Vector Q8_0 kernel when [isAvailable] reports `true`,
 * otherwise `null`. The kernel object is only referenced on the available
 * path, so keep the `if` form rather than e.g. `takeIf` on the object.
 */
override fun matmulQ8_0(): Q8_0MatmulKernel? = if (isAvailable()) PanamaVectorQ8_0MatmulKernel else null
51+
4852
private fun isVectorApiClassLoaded(): Boolean = runCatching {
4953
Class.forName("jdk.incubator.vector.FloatVector")
5054
Class.forName("jdk.incubator.vector.VectorSpecies")
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package sk.ainet.exec.kernel
2+
3+
import jdk.incubator.vector.ByteVector
4+
import jdk.incubator.vector.FloatVector
5+
import jdk.incubator.vector.VectorOperators
6+
import jdk.incubator.vector.VectorSpecies
7+
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
8+
9+
/**
 * SIMD-vectorized FP32 × Q8_0 matmul on the JDK Vector API.
 *
 * Pipeline per 32-element block:
 *  1. Decode the 2-byte FP16 scale `d` once.
 *  2. Walk the 32 signed int8 codes in `floatSpecies.length()`-sized
 *     chunks. Each chunk: one ByteVector load, one `castShape` to
 *     FloatVector (signed widening — int8 codes become small floats in
 *     [-128, 127]), one `FloatVector.fma(input, codes, blockAcc)` into a
 *     lane-wise block accumulator.
 *  3. Reduce the block accumulator across lanes (`reduceLanes(ADD)`) and
 *     multiply by `d` exactly once before adding to the running output
 *     cell, instead of scaling each of the 32 elements.
 *
 * Results agree with [ScalarQ8_0MatmulKernel] only up to FMA and
 * reordered-reduction rounding differences.
 */
public object PanamaVectorQ8_0MatmulKernel : Q8_0MatmulKernel {

    private const val BLOCK_SIZE = 32
    private const val BYTES_PER_BLOCK = 34

    /**
     * Float species used for the input loads and the accumulator.
     *
     * The preferred species is used whenever its lane count is at most 16
     * and divides [BLOCK_SIZE]. Wider preferred shapes would make a single
     * chunk step past the 32-element block (and no byte species below
     * lines up with them), so they fall back to the 16-lane 512-bit
     * species.
     */
    private val floatSpecies: VectorSpecies<Float> =
        FloatVector.SPECIES_PREFERRED
            .takeIf { it.length() <= 16 && BLOCK_SIZE % it.length() == 0 }
            ?: FloatVector.SPECIES_512

    /**
     * Byte species used to load the int8 codes that feed
     * `castShape(floatSpecies, 0)`.
     *
     * For 16 and 8 float lanes the byte vector holds exactly one chunk of
     * codes. For narrower float species (e.g. 4 lanes) the smallest byte
     * species still has 8 lanes; only the first `floatSpecies.length()`
     * of them survive the cast, the rest are loaded but discarded.
     */
    private val byteSpeciesForFloat: VectorSpecies<Byte> =
        if (floatSpecies.length() == 16) ByteVector.SPECIES_128 else ByteVector.SPECIES_64

    /**
     * Mask selecting the first `floatSpecies.length()` byte lanes, or `null`
     * when the byte vector is no wider than one chunk of codes.
     *
     * Needed because an over-wide byte load on the final chunk of the final
     * block would reach past the end of an exactly-sized weight array.
     */
    private val narrowLoadMask =
        if (byteSpeciesForFloat.length() > floatSpecies.length()) {
            byteSpeciesForFloat.indexInRange(0, floatSpecies.length())
        } else {
            null
        }

    /**
     * Computes `output[outputOffset + o]` for every `o` in `[0, outputDim)`.
     *
     * The block for output row `o` and input block `b` is read from byte
     * `weightByteOffset + (b * outputDim + o) * 34` of [weight].
     *
     * @throws IllegalArgumentException if [inputDim] is not a multiple of 32.
     */
    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "PanamaVectorQ8_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        if (outputDim == 0) return
        if (inputDim == 0) {
            for (o in 0 until outputDim) output[outputOffset + o] = 0f
            return
        }
        val blocksPerInputDim = inputDim / BLOCK_SIZE
        val laneCount = floatSpecies.length()
        val byteLanes = byteSpeciesForFloat.length()
        val tailMask = narrowLoadMask

        for (o in 0 until outputDim) {
            var acc = 0f
            for (blockIdx in 0 until blocksPerInputDim) {
                val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK
                // FP16 scale — two LE bytes.
                val dBits = (weight[blockBase].toInt() and 0xFF) or
                    ((weight[blockBase + 1].toInt() and 0xFF) shl 8)
                val d = halfToFloat(dBits)

                val codesBase = blockBase + 2
                val inputBase = inputOffset + blockIdx * BLOCK_SIZE

                var blockAccVec = FloatVector.zero(floatSpecies)
                var k = 0
                while (k < BLOCK_SIZE) {
                    val codeOffset = codesBase + k
                    // Fast path: a full-width load that stays inside the array. Extra
                    // lanes (if any) belong to the next block and are dropped by the
                    // part-0 cast below. Only when the full-width load would cross the
                    // end of `weight` do we switch to the masked load, which bounds-checks
                    // just the lanes we actually use.
                    val byteVec = if (tailMask == null || codeOffset + byteLanes <= weight.size) {
                        ByteVector.fromArray(byteSpeciesForFloat, weight, codeOffset)
                    } else {
                        ByteVector.fromArray(byteSpeciesForFloat, weight, codeOffset, tailMask)
                    }
                    @Suppress("UNCHECKED_CAST")
                    val codesVec = byteVec.castShape(floatSpecies, 0) as FloatVector
                    val inputVec = FloatVector.fromArray(floatSpecies, input, inputBase + k)
                    blockAccVec = inputVec.fma(codesVec, blockAccVec)
                    k += laneCount
                }
                // Apply the block scale once, after the horizontal reduction.
                acc += blockAccVec.reduceLanes(VectorOperators.ADD) * d
            }
            output[outputOffset + o] = acc
        }
    }

    /**
     * Widens the IEEE-754 half-precision value held in the low 16 bits of
     * [hbits] to FP32. Same conversion as the private helper of the same
     * name in [ScalarQ8_0MatmulKernel].
     */
    private fun halfToFloat(hbits: Int): Float {
        val sign = (hbits and 0x8000) shl 16
        val exp = (hbits and 0x7C00) shr 10
        val mant = hbits and 0x03FF
        return when (exp) {
            0 -> {
                if (mant == 0) {
                    // Signed zero.
                    Float.fromBits(sign)
                } else {
                    // Subnormal half: normalise until the implicit bit (bit 10) is set.
                    var m = mant
                    var e = -14
                    while ((m and 0x400) == 0) {
                        m = m shl 1
                        e--
                    }
                    m = m and 0x3FF
                    Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13))
                }
            }
            // Infinity / NaN.
            31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13))
            // Normal number: re-bias the exponent from 15 to 127.
            else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13))
        }
    }
}

0 commit comments

Comments
 (0)