Skip to content

Commit e4b16f8

Browse files
Merge pull request #649 from SKaiNET-developers/feature/q4_0-panama
feat(q4_0): Panama SIMD kernel + reconcile MemSeg to split layout
2 parents 5ff5a36 + 7c5a1c9 commit e4b16f8

13 files changed

Lines changed: 609 additions & 16 deletions

File tree

skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ public final class sk/ainet/exec/kernel/PanamaVectorQ4KMatmulKernel : sk/ainet/b
8181
public fun matmul ([FI[BIII[FI)V
8282
}
8383

84+
public final class sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel : sk/ainet/backend/api/kernel/Q4_0MatmulKernel {
85+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel;
86+
public fun matmul ([FI[BIII[FI)V
87+
}
88+
8489
public final class sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel {
8590
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel;
8691
public fun matmul ([FI[BIII[FI)V

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import sk.ainet.backend.api.kernel.Bf16MatmulKernel
44
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
55
import sk.ainet.backend.api.kernel.KernelProvider
66
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
7+
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
78
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
89
import sk.ainet.exec.tensor.ops.JvmCpuBackendConfig
910

@@ -49,6 +50,9 @@ public object PanamaVectorKernelProvider : KernelProvider {
4950
override fun matmulQ8_0(): Q8_0MatmulKernel? =
5051
if (isAvailable()) PanamaVectorQ8_0MatmulKernel else null
5152

53+
override fun matmulQ4_0(): Q4_0MatmulKernel? =
54+
if (isAvailable()) PanamaVectorQ4_0MatmulKernel else null
55+
5256
private fun isVectorApiClassLoaded(): Boolean = runCatching {
5357
Class.forName("jdk.incubator.vector.FloatVector")
5458
Class.forName("jdk.incubator.vector.VectorSpecies")
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package sk.ainet.exec.kernel
2+
3+
import jdk.incubator.vector.FloatVector
4+
import jdk.incubator.vector.VectorOperators
5+
import jdk.incubator.vector.VectorSpecies
6+
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
7+
8+
/**
9+
* SIMD-vectorized FP32 × Q4_0 matmul on the JDK Vector API.
10+
*
11+
* Pipeline per 32-element block:
12+
* 1. Decode the 2-byte FP16 scale `d` once.
13+
* 2. Unpack the 16 code bytes into 32 sign-corrected floats (`nibble - 8`)
14+
* in a reusable scratch buffer, using the canonical ggml **split**
15+
* layout (low nibbles → elements 0..15, high nibbles → 16..31). The
16+
* nibble-pair-per-byte packing makes a fully-fused `ByteVector`
17+
* pipeline awkward, so this kernel keeps the scratch-then-FMA shape
18+
* (same approach as the legacy `JvmQuantizedVectorKernels` Q4_0 path).
19+
* 3. SIMD-FMA the scratch against the matching input window into a
20+
* lane-wise block accumulator, reduce across lanes, and fold `* d`
21+
* exactly once per block.
22+
*
23+
* Numerical equivalence with [ScalarQ4_0MatmulKernel] is within FMA +
24+
* reordered-reduction tolerance — the same bar the Q8_0 / Q4_K Panama
25+
* kernels use.
26+
*/
27+
public object PanamaVectorQ4_0MatmulKernel : Q4_0MatmulKernel {
28+
29+
private const val BLOCK_SIZE = 32
30+
private const val BYTES_PER_BLOCK = 18
31+
32+
private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED
33+
34+
override fun matmul(
35+
input: FloatArray, inputOffset: Int,
36+
weight: ByteArray, weightByteOffset: Int,
37+
inputDim: Int, outputDim: Int,
38+
output: FloatArray, outputOffset: Int,
39+
) {
40+
require(inputDim % BLOCK_SIZE == 0) {
41+
"PanamaVectorQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
42+
}
43+
if (outputDim == 0) return
44+
if (inputDim == 0) {
45+
for (o in 0 until outputDim) output[outputOffset + o] = 0f
46+
return
47+
}
48+
val blocksPerInputDim = inputDim / BLOCK_SIZE
49+
val step = floatSpecies.length()
50+
val loopBound = floatSpecies.loopBound(BLOCK_SIZE)
51+
val codeBuf = FloatArray(BLOCK_SIZE)
52+
53+
for (o in 0 until outputDim) {
54+
var acc = 0f
55+
for (blockIdx in 0 until blocksPerInputDim) {
56+
val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK
57+
// FP16 scale — two LE bytes.
58+
val dBits = (weight[blockBase].toInt() and 0xFF) or
59+
((weight[blockBase + 1].toInt() and 0xFF) shl 8)
60+
val d = halfToFloat(dBits)
61+
62+
// Split-layout unpack: low nibbles → 0..15, high → 16..31.
63+
val codesBase = blockBase + 2
64+
for (j in 0 until 16) {
65+
val b = weight[codesBase + j].toInt() and 0xFF
66+
codeBuf[j] = ((b and 0x0F) - 8).toFloat()
67+
codeBuf[16 + j] = ((b ushr 4) - 8).toFloat()
68+
}
69+
70+
val inputBase = inputOffset + blockIdx * BLOCK_SIZE
71+
var blockAccVec = FloatVector.zero(floatSpecies)
72+
var k = 0
73+
while (k < loopBound) {
74+
val inV = FloatVector.fromArray(floatSpecies, input, inputBase + k)
75+
val cV = FloatVector.fromArray(floatSpecies, codeBuf, k)
76+
blockAccVec = inV.fma(cV, blockAccVec)
77+
k += step
78+
}
79+
var blockAcc = blockAccVec.reduceLanes(VectorOperators.ADD)
80+
// Scalar tail (only if floatSpecies.length() doesn't divide 32 — rare).
81+
while (k < BLOCK_SIZE) {
82+
blockAcc += input[inputBase + k] * codeBuf[k]
83+
k++
84+
}
85+
acc += blockAcc * d
86+
}
87+
output[outputOffset + o] = acc
88+
}
89+
}
90+
91+
/** Same FP16 → FP32 conversion as [ScalarQ4_0MatmulKernel]. */
92+
private fun halfToFloat(hbits: Int): Float {
93+
val sign = (hbits and 0x8000) shl 16
94+
val exp = (hbits and 0x7C00) shr 10
95+
val mant = hbits and 0x03FF
96+
return when (exp) {
97+
0 -> {
98+
if (mant == 0) Float.fromBits(sign)
99+
else {
100+
var m = mant
101+
var e = -14
102+
while ((m and 0x400) == 0) {
103+
m = m shl 1
104+
e--
105+
}
106+
m = m and 0x3FF
107+
Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13))
108+
}
109+
}
110+
31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13))
111+
else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13))
112+
}
113+
}
114+
}

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,13 +549,14 @@ internal object JvmQuantizedVectorKernels {
549549
// Read f16 scale
550550
val scale = halfToFloat(read2BytesLE(weightSeg, blockByteOffset))
551551

552-
// Unpack 16 packed bytes → 32 sign-corrected nibbles. Two
553-
// nibbles per byte load means half the byte traffic of the
554-
// straight scalar dot product.
552+
// Unpack 16 packed bytes → 32 sign-corrected nibbles in the
553+
// canonical ggml *split* layout: low nibbles decode elements
554+
// 0..15, high nibbles decode elements 16..31. (Matches
555+
// DequantOps.dequantQ4_0FromBytes and Q4_0BlockTensorData.)
555556
for (k in 0 until 16) {
556557
val b = weightSeg.get(JAVA_BYTE_LE, codesOffset + k.toLong()).toInt() and 0xFF
557-
codeBuf[2 * k] = (b and 0x0F).toFloat() - 8f
558-
codeBuf[2 * k + 1] = (b ushr 4).toFloat() - 8f
558+
codeBuf[k] = (b and 0x0F).toFloat() - 8f
559+
codeBuf[16 + k] = (b ushr 4).toFloat() - 8f
559560
}
560561

561562
// SIMD FMA dot product.
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package sk.ainet.exec.kernel
2+
3+
import kotlin.math.abs
4+
import kotlin.random.Random
5+
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertFailsWith
8+
import kotlin.test.assertTrue
9+
10+
/**
11+
* Numerical parity tests for [PanamaVectorQ4_0MatmulKernel] against
12+
* [ScalarQ4_0MatmulKernel]. Both kernels apply the same FP16-scale
13+
* decode + `(nibble - 8)` dequant in the canonical ggml split layout;
14+
* differences come from FMA + reordered-reduction order only.
15+
*
16+
* Tolerance scales with the number of Q4_0 blocks processed: `1e-2 *
17+
* blocksPerInputDim`, clamped to a `1e-2` floor — mirrors the Q8_0
18+
* parity test convention.
19+
*/
20+
class PanamaVectorQ4_0MatmulKernelParityTest {
21+
22+
private val blockSize = 32
23+
private val bytesPerBlock = 18
24+
25+
/** Random Q4_0 packed bytes; scales clamped to a small positive FP16. */
26+
private fun randomQ4_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray {
27+
val rng = Random(seed)
28+
val numBlocks = blocksPerInputDim * outputDim
29+
val bytes = ByteArray(numBlocks * bytesPerBlock)
30+
rng.nextBytes(bytes)
31+
for (block in 0 until numBlocks) {
32+
val base = block * bytesPerBlock
33+
bytes[base + 0] = 0x00.toByte()
34+
bytes[base + 1] = 0x22.toByte() // FP16 0x2200 ≈ 7.6e-3
35+
}
36+
return bytes
37+
}
38+
39+
private fun assertParity(
40+
inputDim: Int,
41+
outputDim: Int,
42+
seed: Int,
43+
tolPerBlock: Float = 1e-2f,
44+
) {
45+
val blocksPerInputDim = inputDim / blockSize
46+
val rng = Random(seed)
47+
val input = FloatArray(inputDim) { rng.nextFloat() - 0.5f }
48+
val weight = randomQ4_0Bytes(blocksPerInputDim, outputDim, seed)
49+
val outScalar = FloatArray(outputDim)
50+
val outPanama = FloatArray(outputDim)
51+
52+
ScalarQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outScalar, 0)
53+
PanamaVectorQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outPanama, 0)
54+
55+
val tol = (tolPerBlock * blocksPerInputDim.coerceAtLeast(1)).coerceAtLeast(tolPerBlock)
56+
for (i in outScalar.indices) {
57+
val diff = abs(outScalar[i] - outPanama[i])
58+
assertTrue(
59+
diff <= tol,
60+
"mismatch at $i: scalar=${outScalar[i]} panama=${outPanama[i]} diff=$diff tol=$tol",
61+
)
62+
}
63+
}
64+
65+
@Test fun single_block_single_output_matches_scalar() =
66+
assertParity(inputDim = 32, outputDim = 1, seed = 1)
67+
68+
@Test fun single_block_multiple_outputs_matches_scalar() =
69+
assertParity(inputDim = 32, outputDim = 7, seed = 2)
70+
71+
@Test fun multiple_blocks_single_output_matches_scalar() =
72+
assertParity(inputDim = 256, outputDim = 1, seed = 3)
73+
74+
@Test fun llm_typical_attention_proj_matches_scalar() =
75+
assertParity(inputDim = 512, outputDim = 512, seed = 4)
76+
77+
@Test fun llm_typical_ffn_proj_matches_scalar() =
78+
assertParity(inputDim = 256, outputDim = 1024, seed = 5)
79+
80+
@Test fun rejects_non_block_aligned_input_dim() {
81+
assertFailsWith<IllegalArgumentException> {
82+
PanamaVectorQ4_0MatmulKernel.matmul(
83+
FloatArray(31), 0,
84+
ByteArray(bytesPerBlock), 0,
85+
31, 1,
86+
FloatArray(1), 0,
87+
)
88+
}
89+
}
90+
91+
@Test fun zero_input_dim_zeros_output() {
92+
val out = FloatArray(5) { 9f }
93+
PanamaVectorQ4_0MatmulKernel.matmul(
94+
FloatArray(0), 0,
95+
ByteArray(0), 0,
96+
0, 5,
97+
out, 0,
98+
)
99+
for (v in out) assertEquals(0f, v, "output should be zeroed for inputDim=0")
100+
}
101+
102+
@Test fun provider_returns_panama_q4_0_when_available() {
103+
val kernel = PanamaVectorKernelProvider.matmulQ4_0()
104+
if (PanamaVectorKernelProvider.isAvailable()) {
105+
assertTrue(
106+
kernel === PanamaVectorQ4_0MatmulKernel,
107+
"Provider must hand out the Panama Q4_0 kernel when available",
108+
)
109+
} else {
110+
assertEquals(null, kernel, "Provider must return null when Vector API unavailable")
111+
}
112+
}
113+
}

skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/QuantizedMemSegMatmulTest.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class QuantizedMemSegMatmulTest {
4747

4848
/**
4949
* Encode a single Q4_0 block: 32 float values -> 18 bytes (2 scale + 16 packed nibbles).
50+
* Uses the canonical ggml *split* layout: code[j] is the low nibble of
51+
* byte j, code[j+16] is the high nibble of byte j.
5052
*/
5153
private fun encodeQ4_0Block(values: FloatArray): ByteArray {
5254
require(values.size == 32)
@@ -62,8 +64,8 @@ class QuantizedMemSegMatmulTest {
6264
val out = ByteArray(18)
6365
out[0] = (scaleHalf and 0xFF).toByte()
6466
out[1] = ((scaleHalf shr 8) and 0xFF).toByte()
65-
for (i in 0 until 16) {
66-
out[2 + i] = ((codes[2 * i + 1] shl 4) or codes[2 * i]).toByte()
67+
for (j in 0 until 16) {
68+
out[2 + j] = ((codes[j + 16] shl 4) or codes[j]).toByte()
6769
}
6870
return out
6971
}

skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_library(skainet_kernels SHARED
1515
src/fp32_matmul.c
1616
src/bf16_matmul.c
1717
src/q8_0_matmul.c
18+
src/q4_0_matmul.c
1819
)
1920

2021
target_include_directories(skainet_kernels PUBLIC

skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,32 @@ SKAINET_API void skainet_q8_0_matmul(
119119
int32_t output_offset
120120
);
121121

122+
/*
123+
* Q4_0 matrix-vector multiply.
124+
*
125+
* output[output_offset + o] = sum_j input[input_offset + j] *
126+
* dequant(weight[block, o, j])
127+
*
128+
* Block layout: canonical ggml Q4_0, 32 elements per block, 18 bytes
129+
* per block (2 B FP16 scale + 16 B packed 4-bit codes in split layout —
130+
* low nibbles → elements 0..15, high nibbles → 16..31), with packed
131+
* weights laid out as
132+
* weight + weight_byte_offset + (block_idx * output_dim + o) * 18
133+
*
134+
* Dequant per element: `(code - 8) * d`. input_dim must be a multiple
135+
* of 32.
136+
*/
137+
SKAINET_API void skainet_q4_0_matmul(
138+
const float* input,
139+
int32_t input_offset,
140+
const uint8_t* weight,
141+
int32_t weight_byte_offset,
142+
int32_t input_dim,
143+
int32_t output_dim,
144+
float* output,
145+
int32_t output_offset
146+
);
147+
122148
#ifdef __cplusplus
123149
}
124150
#endif

0 commit comments

Comments
 (0)