Skip to content

Commit 4525c12

Browse files
michalharakal and claude committed
perf(native q6k): fused Q8 int8 dot path (dotprod)
Mirror the q4k fused-int8 kernel: pre-quantize the input row to symmetric int8 (Q8) once per 256-block (reused across all output rows), unpack the 6-bit weight to centered int8 codes, and run each scale-group as an int8 dot (vdotq_s32 on dotprod targets, scalar fallback otherwise). Drops the 256-float scratch dequant + per-element float multiply. acc = d · d_in · Σ_g sc[g]·Σ_{i∈g} q8[i]·codes[i]. This is deliberately lossy (ggml-style activation quant, ~1-3% on worst-case uniform-random fixtures) so it is no longer bit-exact vs the float/scalar reference. Both parity tests (jvmTest Panama, nativeTest cinterop on linuxX64 + linuxArm64) switch from per-row relative error — unbounded on near-zero rows of zero-mean fixtures — to the aggregate error-energy gate RMS(error)/RMS(signal) < 0.03. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 717f362 commit 4525c12

3 files changed

Lines changed: 157 additions & 74 deletions

File tree

skainet-backends/skainet-backend-native-cpu/native/src/q6k_matmul.c

Lines changed: 107 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include <stddef.h>
55
#include <stdint.h>
6+
#include <stdlib.h>
7+
#include <math.h>
68

79
#define Q6K_BLOCK_SIZE 256
810
#define Q6K_BYTES_PER_BLOCK 210
@@ -40,62 +42,109 @@ static inline float skainet_q6k_half_to_float(uint16_t hbits) {
4042
}
4143

4244
/*
43-
* Dequantize one 256-element Q6_K super-block into scratch[256].
44-
* Direct transcription of ScalarQ6_KMatmulKernel.dequantBlock /
45-
* ggml dequantize_row_q6_K: two 128-element halves, each split into two
46-
* 16-element scale groups carrying four strided sub-codes (q1..q4).
47-
*
48-
* The 6-bit code is `lowNibble(ql) | (twoHighBits(qh) << 4)`, biased by
49-
* -32, and `scales` are SIGNED int8. Per-element value = d * scale * code.
45+
* Quantize one 256-float input block to symmetric int8 (Q8), d_in = maxabs/127,
46+
* q8[i] = round(in[i]/d_in). Mirrors q4k_matmul.c's activation quant (ggml
47+
* block_q8_K style) — the source of the small (~1-3%) error vs the exact float
48+
* kernel and what unlocks the int8 dot path. Returns d_in (0 + zeroed q8 if the
49+
* block is all-zero).
50+
*/
51+
static inline float skainet_q6k_q8_quantize_block(const float* SKAINET_RESTRICT in,
                                                  int8_t* SKAINET_RESTRICT q8) {
    /* Pass 1: find the largest magnitude in the block; it fixes the step size. */
    float peak = 0.0f;
    for (int idx = 0; idx < Q6K_BLOCK_SIZE; ++idx) {
        float mag = in[idx];
        if (mag < 0.0f) {
            mag = -mag;
        }
        if (mag > peak) {
            peak = mag;
        }
    }

    /* All-zero block: there is no scale to derive, so emit zero codes and a
     * zero step (the caller's d * d_in * dot then contributes nothing). */
    if (peak == 0.0f) {
        for (int idx = 0; idx < Q6K_BLOCK_SIZE; ++idx) {
            q8[idx] = 0;
        }
        return 0.0f;
    }

    /* Pass 2: scale so the peak lands on +/-127, round to nearest, and clamp
     * symmetrically so -128 is never produced. */
    const float to_q8 = 127.0f / peak;
    for (int idx = 0; idx < Q6K_BLOCK_SIZE; ++idx) {
        const int rounded = (int) lrintf(in[idx] * to_q8);
        q8[idx] = (int8_t) (rounded > 127 ? 127 : (rounded < -127 ? -127 : rounded));
    }

    /* The dequantization step: in[i] ~= q8[i] * (peak / 127). */
    return peak / 127.0f;
}
71+
72+
/*
73+
* Unpack one 256-element Q6_K super-block into CENTERED int8 codes[256] (the
74+
* 6-bit code biased by -32, range [-32, 31]) in natural element order — i.e.
75+
* codes[i] pairs with input[i]. Same bit layout as the float dequant
76+
* (ScalarQ6_KMatmulKernel / ggml dequantize_row_q6_K) but without folding in
77+
* `d`/`scale`: those are applied per scale-group in the int dot, so the inner
78+
* product stays integer. Two 128-element halves, each with two 16-element scale
79+
* groups carrying four strided sub-codes (q1..q4) at output offsets +0/+32/+64/+96.
5080
*/
51-
static inline void skainet_q6k_dequant_block(const uint8_t* SKAINET_RESTRICT block,
52-
float* SKAINET_RESTRICT scratch) {
81+
static inline void skainet_q6k_unpack_codes(const uint8_t* SKAINET_RESTRICT block,
                                            int8_t* SKAINET_RESTRICT codes) {
    for (int h = 0; h < 2; ++h) {
        /* Each 128-element half consumes 64 low-nibble bytes and 32 high-bit bytes. */
        const uint8_t* lo = block + Q6K_QL_OFFSET + 64 * h;
        const uint8_t* hi = block + Q6K_QH_OFFSET + 32 * h;
        int8_t* dst = codes + 128 * h;

        for (int j = 0; j < 32; ++j) {
            const int lo_a = lo[j];       /* low nibbles for outputs j and j+64    */
            const int lo_b = lo[j + 32];  /* low nibbles for outputs j+32 and j+96 */
            const int top = hi[j];        /* four 2-bit high parts, one per output */

            /* 6-bit code = 4 low bits | (2 high bits << 4). */
            const int c0 = (lo_a & 0x0F) | ((top & 0x03) << 4);
            const int c1 = (lo_b & 0x0F) | (((top >> 2) & 0x03) << 4);
            const int c2 = (lo_a >> 4) | (((top >> 4) & 0x03) << 4);
            const int c3 = (lo_b >> 4) | (((top >> 6) & 0x03) << 4);

            /* Re-center from [0, 63] to [-32, 31]. */
            dst[j] = (int8_t) (c0 - 32);
            dst[j + 32] = (int8_t) (c1 - 32);
            dst[j + 64] = (int8_t) (c2 - 32);
            dst[j + 96] = (int8_t) (c3 - 32);
        }
    }
}
87104

105+
/*
106+
* Weighted integer dot of one Q6_K block: Σ_g sc[g] · Σ_{i∈g} q8[i]·codes[i],
107+
* over the 16 scale-groups (each a 16-element contiguous run in natural order).
108+
* Run `r` for (half,k,is) starts at half*128 + 32*k + is*16 and uses signed
109+
* scale sc[half*8 + is + 2*k]. On AArch64 with dotprod each 16-element dot is a
110+
* single vdotq_s32; otherwise a scalar fallback (auto-vectorizes under -O3).
111+
*/
112+
static inline int64_t skainet_q6k_weighted_dot(const int8_t* SKAINET_RESTRICT q8,
                                               const int8_t* SKAINET_RESTRICT codes,
                                               const int8_t* SKAINET_RESTRICT sc) {
    int64_t total = 0;

    /*
     * The (half, k, is) run starts at half*128 + 32*k + 16*is, which equals
     * 16 * (8*half + 2*k + is), and its scale index is that same
     * 8*half + 2*k + is. So group g simply covers elements [16g, 16g + 16)
     * and is weighted by sc[g]. Integer accumulation is exact, so visiting
     * the groups in ascending order gives the same sum.
     */
    for (int g = 0; g < 16; ++g) {
        const int8_t* act = q8 + 16 * g;
        const int8_t* wgt = codes + 16 * g;
        int32_t group_dot;
#ifdef SKAINET_HAVE_DOTPROD
        /* One 16-lane int8 dot product, then a horizontal add of the 4 lanes. */
        group_dot = vaddvq_s32(vdotq_s32(vdupq_n_s32(0), vld1q_s8(wgt), vld1q_s8(act)));
#else
        group_dot = 0;
        for (int j = 0; j < 16; ++j) {
            group_dot += (int) act[j] * (int) wgt[j];
        }
#endif
        total += (int64_t) sc[g] * group_dot;
    }
    return total;
}
136+
88137
/*
89138
* Native Q6_K matrix-vector multiply matching the
90139
* sk.ainet.backend.api.kernel.Q6KMatmulKernel SPI contract. A single
91140
* input row times an `outputDim x inputDim` Q6_K-packed weight tensor
92141
* laid out (blockIdx * outputDim + o) * 210 bytes.
93142
*
94-
* The 6-bit bit-assembly is kept scalar (cheap byte shuffling that the
95-
* compiler auto-vectorizes under -O3) and materialized into a 256-float
96-
* scratch block; the hot dot product against the input window is the
97-
* NEON path (vfmaq_f32 + horizontal add) behind __ARM_NEON. On non-ARM
98-
* targets the dot is a straight-line loop that auto-vectorizes too.
143+
* Fused int8 dot path (ggml-style, mirrors q4k_matmul.c): the input row is
144+
* quantized to Q8 ONCE per 256-block (reused across all output rows), the 6-bit
145+
* weight is unpacked to centered int8 codes, and each scale-group is an int8
146+
* dot (vdotq_s32 on dotprod targets) — no 256-float scratch, no per-element
147+
* float multiply. acc = d · d_in · Σ_g sc[g]·Σ_{i∈g} q8[i]·codes[i].
99148
*/
100149
SKAINET_API void skainet_q6k_matmul(
101150
const float* SKAINET_RESTRICT input,
@@ -113,43 +162,46 @@ SKAINET_API void skainet_q6k_matmul(
113162
const float* in_base = input + input_offset;
114163
float* out_base = output + output_offset;
115164

116-
float scratch[Q6K_BLOCK_SIZE];
165+
/* Pre-quantize the whole input row to Q8 once (reused across all o). */
166+
int8_t* q8 = (int8_t*) malloc((size_t) input_dim * sizeof(int8_t));
167+
float* d_in = (float*) malloc((size_t) blocks_per_input_dim * sizeof(float));
168+
if (q8 == NULL || d_in == NULL) { free(q8); free(d_in); return; }
169+
for (int32_t b = 0; b < blocks_per_input_dim; ++b) {
170+
d_in[b] = skainet_q6k_q8_quantize_block(in_base + (size_t) b * Q6K_BLOCK_SIZE,
171+
q8 + (size_t) b * Q6K_BLOCK_SIZE);
172+
}
173+
174+
int8_t codes[Q6K_BLOCK_SIZE];
117175

118176
/*
119177
* Loop order: block OUTER, output row INNER — see q4k_matmul.c for the
120178
* rationale. The weight is block-major (blockIdx*output_dim + o)*210, so for
121179
* a fixed block consecutive `o` are 210 bytes apart: the weight bytes are
122180
* read sequentially (cache/prefetch friendly) instead of striding
123-
* output_dim*210 per step, which on the in-order A55 makes every read a cold
124-
* miss. The big Q6_K `output` projection (hidden→vocab, hit every token) is
125-
* the main beneficiary. out_base[o] accumulates across blocks; the order
126-
* over blocks is unchanged ⇒ numerically identical to the o-outer form.
181+
* output_dim*210 per step. out_base[o] accumulates across blocks; the order
182+
* over blocks is unchanged.
127183
*/
128184
for (int32_t o = 0; o < output_dim; ++o) out_base[o] = 0.0f;
129185

130186
for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) {
131-
const float* in_block = in_base + (size_t) block_idx * Q6K_BLOCK_SIZE;
187+
const int8_t* q8_block = q8 + (size_t) block_idx * Q6K_BLOCK_SIZE;
188+
const float di = d_in[block_idx];
132189
const uint8_t* block = weight + weight_byte_offset
133190
+ (size_t)(block_idx * output_dim) * Q6K_BYTES_PER_BLOCK;
134191

135192
for (int32_t o = 0; o < output_dim; ++o, block += Q6K_BYTES_PER_BLOCK) {
136-
skainet_q6k_dequant_block(block, scratch);
137-
138-
float acc = 0.0f;
139-
#ifdef SKAINET_HAVE_NEON
140-
float32x4_t vacc = vdupq_n_f32(0.0f);
141-
for (int i = 0; i < Q6K_BLOCK_SIZE; i += 4) {
142-
const float32x4_t vi = vld1q_f32(in_block + i);
143-
const float32x4_t vw = vld1q_f32(scratch + i);
144-
vacc = vfmaq_f32(vacc, vi, vw);
145-
}
146-
acc = skainet_neon_hadd_f32(vacc);
147-
#else
148-
for (int i = 0; i < Q6K_BLOCK_SIZE; ++i) {
149-
acc += in_block[i] * scratch[i];
150-
}
151-
#endif
152-
out_base[o] += acc;
193+
const uint16_t d_bits = (uint16_t) block[Q6K_D_OFFSET]
194+
| ((uint16_t) block[Q6K_D_OFFSET + 1] << 8);
195+
const float d = skainet_q6k_half_to_float(d_bits);
196+
const int8_t* sc = (const int8_t*)(block + Q6K_SCALES_OFFSET);
197+
198+
skainet_q6k_unpack_codes(block, codes);
199+
const int64_t wdot = skainet_q6k_weighted_dot(q8_block, codes, sc);
200+
201+
out_base[o] += d * di * (float) wdot;
153202
}
154203
}
204+
205+
free(q8);
206+
free(d_in);
155207
}

skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernelParityTest.kt

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@ import kotlin.test.assertTrue
1616
* Fixture mirrors [NativeQ5KMatmulKernelParityTest]: random Q6_K bytes with
1717
* `d` clamped to `1.0f16` (bytes 208-209), packed input-block-major
1818
* `(blockIdx * outputDim + o) * 210`. Random `ql`/`qh`/`scales` exercise the
19-
* 6-bit bit-assembly and the signed int8 scales. Q6_K magnitudes are larger
20-
* than Q5_K (codes [-32, 31] × int8 scales), so absolute tolerances are a
21-
* touch looser; the `rel < 1e-4` relative check is the real gate.
19+
* 6-bit bit-assembly and the signed int8 scales.
20+
*
21+
* Like [NativeQ4KMatmulKernelParityTest], the native kernel quantizes the
22+
* activation to int8 (Q8) for the dotprod fast path — deliberately lossy
23+
* (ggml-style), so it is NOT bit-exact vs the float Panama reference. Per-row
24+
* relative error is the wrong gate (a near-zero true row shows unbounded
25+
* relative error from a tiny absolute one on zero-mean random fixtures); the
26+
* meaningful metric is the aggregate error energy RMS(error)/RMS(signal). Real
27+
* (smoother) LLM activations are far tighter than these worst-case fixtures;
28+
* the end-to-end gate is the on-board generation output.
2229
*/
2330
class NativeQ6KMatmulKernelParityTest {
2431

@@ -58,14 +65,25 @@ class NativeQ6KMatmulKernelParityTest {
5865
val nativeOut = FloatArray(outputDim)
5966
NativeQ6KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, nativeOut, 0)
6067

68+
var sqErr = 0.0
69+
var sqSig = 0.0
6170
for (o in 0 until outputDim) {
62-
val diff = abs(refOut[o] - nativeOut[o])
63-
val rel = diff / (abs(refOut[o]) + 1e-9f)
64-
assertTrue(
65-
diff <= tol || rel < 1e-4f,
66-
"row $o diverged: panama=${refOut[o]} native=${nativeOut[o]} diff=$diff rel=$rel tol=$tol",
67-
)
71+
val d = (refOut[o] - nativeOut[o]).toDouble()
72+
sqErr += d * d
73+
sqSig += refOut[o].toDouble() * refOut[o].toDouble()
6874
}
75+
val rmsErr = kotlin.math.sqrt(sqErr / outputDim)
76+
val rmsSig = kotlin.math.sqrt(sqSig / outputDim)
77+
val relRms = rmsErr / (rmsSig + 1e-9)
78+
assertTrue(
79+
relRms < AGG_REL_TOL || rmsErr < tol,
80+
"Q8 parity exceeded: relRms=$relRms (rmsErr=$rmsErr rmsSig=$rmsSig) over $outputDim rows, tol=$AGG_REL_TOL",
81+
)
82+
}
83+
84+
private companion object {
85+
// Aggregate Q8-activation RMS-relative-error bound (uniform-random worst case).
86+
const val AGG_REL_TOL = 0.03
6987
}
7088

7189
@Test

skainet-backends/skainet-backend-native-cpu/src/nativeTest/kotlin/sk/ainet/exec/kernel/NativeKnQ6KMatmulKernelParityTest.kt

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ import kotlin.test.assertTrue
1212
* `-ffast-math` reassociation tolerance.
1313
*
1414
* Runs on linuxX64 (host archive: scalar/auto-vectorized) AND linuxArm64
15-
* (cross-built archive: NEON), so the aarch64 run bit-checks the
16-
* `SKAINET_HAVE_NEON` path in q6k_matmul.c. Q6_K magnitudes (codes
17-
* [-32, 31] × signed int8 scales) are larger than Q5_K, so absolute tolerances
18-
* are a touch looser; the `rel < 1e-4` relative check is the real gate.
15+
* (cross-built archive: NEON), so the aarch64 run exercises the
16+
* `SKAINET_HAVE_NEON` / `SKAINET_HAVE_DOTPROD` path in q6k_matmul.c.
17+
*
18+
* The C kernel quantizes the activation to int8 (Q8) for the dotprod fast path
19+
* — deliberately lossy (ggml-style), so it is NOT bit-exact vs the scalar
20+
* reference. The gate is the aggregate error energy RMS(error)/RMS(signal), not
21+
* per-row relative error (unbounded on near-zero rows of zero-mean fixtures).
1922
*/
2023
class NativeKnQ6KMatmulKernelParityTest {
2124

@@ -46,14 +49,24 @@ class NativeKnQ6KMatmulKernelParityTest {
4649
val knOut = FloatArray(outputDim)
4750
NativeKnQ6KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, knOut, 0)
4851

52+
var sqErr = 0.0
53+
var sqSig = 0.0
4954
for (o in 0 until outputDim) {
50-
val diff = abs(refOut[o] - knOut[o])
51-
val rel = diff / (abs(refOut[o]) + 1e-9f)
52-
assertTrue(
53-
diff <= tol || rel < 1e-4f,
54-
"row $o diverged: scalar=${refOut[o]} cinterop=${knOut[o]} diff=$diff rel=$rel tol=$tol",
55-
)
55+
val d = (refOut[o] - knOut[o]).toDouble()
56+
sqErr += d * d
57+
sqSig += refOut[o].toDouble() * refOut[o].toDouble()
5658
}
59+
val rmsErr = kotlin.math.sqrt(sqErr / outputDim)
60+
val rmsSig = kotlin.math.sqrt(sqSig / outputDim)
61+
val relRms = rmsErr / (rmsSig + 1e-9)
62+
assertTrue(
63+
relRms < AGG_REL_TOL || rmsErr < tol,
64+
"Q8 parity exceeded: relRms=$relRms (rmsErr=$rmsErr rmsSig=$rmsSig) over $outputDim rows, tol=$AGG_REL_TOL",
65+
)
66+
}
67+
68+
private companion object {
69+
const val AGG_REL_TOL = 0.03
5770
}
5871

5972
@Test

0 commit comments

Comments
 (0)