Skip to content

Commit 59af865

Browse files
unamedkrclaude
andcommitted
V quantization (Q4/Q2) + critical Q4 dequant NEON bug fix
Critical fix: the Q4 dequantize NEON path had a nibble-interleaving bug. Lo/hi nibbles were written contiguously instead of interleaved, causing an MSE of 0.525 (vs. the correct 0.002). Fixed by falling back to the scalar path. V cache quantization is now working: `-v q4` stores Q4 values (4 bits per element + per-block scale); `-v q2` stores Q2 values (2-bit Lloyd-Max codebook). Gemma 3 4B results (1-bit K + Q4 V): K+V per token is 27.62 KB (was 136 KB with FP16) → 4.92x total compression. Spot checks: "capital of France" → "Paris" ✓; "1+1=" → "2" ✓; planet listing: Mercury, Venus, Earth ✓. 23/23 tests pass, zero warnings. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 88da80f commit 59af865

6 files changed

Lines changed: 402 additions & 16 deletions

File tree

include/turboquant/tq_engine.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,13 @@ typedef struct {
216216
tq_type kv_quant_type; /* quantization type for KV attention */
217217
size_t kv_cache_size;
218218

219+
/* Quantized value cache (Q4 or Q2, replaces FP16/FP32 V when enabled) */
220+
int value_quant_bits; /* 0=use FP16/FP32 (default), 4=Q4, 2=Q2 */
221+
uint8_t* value_cache_qs; /* packed quantized values [n_layers * max_seq * n_blocks_v * packed_bytes] */
222+
float* value_cache_scales; /* per-block scales [n_layers * max_seq * n_blocks_v] */
223+
size_t value_stride_qs; /* bytes per position in value_cache_qs */
224+
size_t value_stride_scales;/* floats per position in value_cache_scales */
225+
219226
/* DeltaNet recurrent state */
220227
float* delta_state; /* [n_layers, delta_n_heads, key_head_dim, value_head_dim] */
221228
float* conv_state; /* [n_layers, qkv_dim, conv_width-1] */
@@ -252,6 +259,7 @@ typedef struct {
252259
float top_p;
253260
int max_tokens;
254261
tq_type kv_type; /* KV cache quantization type */
262+
int value_quant_bits;/* V cache quantization: 0=FP16/FP32(default), 4=Q4, 2=Q2 */
255263
int n_threads;
256264
float rep_penalty; /* repetition penalty (default: 1.1, 1.0 = disabled) */
257265
int rep_window; /* how many recent tokens to penalize (default: 32) */
@@ -357,6 +365,7 @@ void tq_free_model(tq_model_t* model);
357365

358366
/* State management */
359367
tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type);
368+
tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type, int value_quant_bits);
360369
void tq_free_state(tq_state_t* state);
361370

362371
/* Inference — returns pointer to logits (owned by state) */
@@ -393,12 +402,14 @@ void tq_matmul_q4(float* out, const float* x, const uint8_t* w_qs, const float*
393402
void tq_matmul_q4_preq(float* out, const uint8_t* w_qs, const float* w_scales,
394403
const int8_t* x_q8, const float* x_scales, int n, int d);
395404
void tq_quantize_row_q4(const float* src, uint8_t* dst_qs, float* dst_scales, int n);
405+
void tq_dequantize_row_q4(const uint8_t* qs, const float* scales, float* dst, int n);
396406
void tq_quantize_weights_q4(tq_model_t* model);
397407
void tq_matmul_q2(float* out, const float* x, const uint8_t* w_qs, const float* w_scales,
398408
int n, int d);
399409
void tq_matmul_q2_preq(float* out, const uint8_t* w_qs, const float* w_scales,
400410
const int8_t* x_q8, const float* x_scales, int n, int d);
401411
void tq_quantize_row_q2(const float* src, uint8_t* dst_qs, float* dst_scales, int n);
412+
void tq_dequantize_row_q2(const uint8_t* qs, const float* scales, float* dst, int n);
402413
void tq_quantize_weights_q2(tq_model_t* model);
403414
void tq_rmsnorm(float* out, const float* x, const float* weight, int n, float eps);
404415
void tq_rope(float* q, float* k, int pos, int head_dim,

src/engine/tq_generate.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
150150
char* output, int output_size) {
151151
if (!model || !config) return -1;
152152

153-
tq_state_t* state = tq_create_state(&model->config, config->kv_type);
153+
tq_state_t* state = tq_create_state_ex(&model->config, config->kv_type, config->value_quant_bits);
154154
if (!state) {
155155
fprintf(stderr, "tq_generate: failed to allocate state\n");
156156
return -1;

src/engine/tq_ops.c

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,57 @@ void tq_quantize_row_q4(const float* src, uint8_t* dst_qs, float* dst_scales, in
512512
}
513513
}
514514

515+
/* ============================================================
 * Q4 dequantize: packed 4-bit + per-block scale -> float
 *
 * Inverse of tq_quantize_row_q4. For each block of 32 values:
 *   x_i = (q_i - 8) * scale
 * where q_i is a 4-bit unsigned value [0,15].
 *
 * Packing: byte j of a block holds two consecutive elements,
 * element 2*j in the low nibble and element 2*j + 1 in the high
 * nibble, so a full block of 32 values occupies 16 bytes.
 *
 * Parameters:
 *   qs     - packed nibbles; 16 bytes per full block, plus
 *            ceil(remainder / 2) bytes for a trailing partial block
 *   scales - one float per block (including a trailing partial block)
 *   dst    - output buffer, receives exactly n floats
 *   n      - number of elements to dequantize
 *
 * NULL pointers or n <= 0 make the call a no-op.
 *
 * The same scalar loop is used on every target: the output must be
 * lo/hi interleaved, and an earlier NEON variant wrote the nibbles
 * contiguously, which produced wrong values.
 * ============================================================ */

/* Expand n_pairs packed bytes into 2 * n_pairs floats (low nibble first). */
static void q4_dequant_pairs(const uint8_t* packed, float scale, float* out, int n_pairs) {
    for (int j = 0; j < n_pairs; j++) {
        int lo = packed[j] & 0x0F;
        int hi = packed[j] >> 4;
        out[2 * j]     = (float)(lo - 8) * scale;
        out[2 * j + 1] = (float)(hi - 8) * scale;
    }
}

void tq_dequantize_row_q4(const uint8_t* qs, const float* scales, float* dst, int n) {
    if (!qs || !scales || !dst || n <= 0) return;

    /* Full blocks: 32 values <- 16 packed bytes + 1 scale each. */
    int n_blocks = n / 32;
    for (int b = 0; b < n_blocks; b++) {
        q4_dequant_pairs(qs + b * 16, scales[b], dst + b * 32, 16);
    }

    /* Trailing partial block (fewer than 32 values), if any. */
    int remainder = n - n_blocks * 32;
    if (remainder > 0) {
        const uint8_t* qb = qs + n_blocks * 16;
        float d = scales[n_blocks];
        float* out = dst + n_blocks * 32;
        int n_pairs = remainder / 2;

        q4_dequant_pairs(qb, d, out, n_pairs);

        /* Odd count: the last element lives in the low nibble of the
         * next byte; its high nibble is padding and is ignored. */
        if (remainder & 1) {
            int q0 = qb[n_pairs] & 0x0F;
            out[remainder - 1] = (float)(q0 - 8) * d;
        }
    }
}
565+
515566
/* ============================================================
516567
* Q4 matmul: w is Q4_0 [n, d], x is FP32 [d], out is FP32 [n]
517568
*
@@ -1233,6 +1284,41 @@ void tq_quantize_row_q2(const float* src, uint8_t* dst_qs, float* dst_scales, in
12331284
}
12341285
}
12351286

1287+
/* ============================================================
1288+
* Q2 dequantize: packed 2-bit + per-block scale -> float
1289+
*
1290+
* Inverse of tq_quantize_row_q2. For each block of 32 values:
1291+
* x_i = Q2_CENTROIDS[q_i] * scale
1292+
* where q_i is a 2-bit index [0,3].
1293+
* ============================================================ */
1294+
void tq_dequantize_row_q2(const uint8_t* qs, const float* scales, float* dst, int n) {
1295+
int n_blocks = n / 32;
1296+
for (int b = 0; b < n_blocks; b++) {
1297+
const uint8_t* qb = qs + b * 8;
1298+
float d = scales[b];
1299+
float* out = dst + b * 32;
1300+
for (int j = 0; j < 32; j++) {
1301+
int byte_idx = j / 4;
1302+
int bit_pos = (j % 4) * 2;
1303+
int qi = (qb[byte_idx] >> bit_pos) & 0x03;
1304+
out[j] = Q2_CENTROIDS[qi] * d;
1305+
}
1306+
}
1307+
/* Handle remainder */
1308+
int remainder = n - n_blocks * 32;
1309+
if (remainder > 0) {
1310+
const uint8_t* qb = qs + n_blocks * 8;
1311+
float d = scales[n_blocks];
1312+
float* out = dst + n_blocks * 32;
1313+
for (int j = 0; j < remainder; j++) {
1314+
int byte_idx = j / 4;
1315+
int bit_pos = (j % 4) * 2;
1316+
int qi = (qb[byte_idx] >> bit_pos) & 0x03;
1317+
out[j] = Q2_CENTROIDS[qi] * d;
1318+
}
1319+
}
1320+
}
1321+
12361322
/* ============================================================
12371323
* Q2 matmul: w is Q2_0 [n, d], x is Q8 [d], out is FP32 [n]
12381324
*

src/engine/tq_transformer.c

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ static void f32_to_fp16_vec(const float* src, uint16_t* dst, int n) {
8080
* ============================================================ */
8181

8282
/* Backward-compatible entry point: forwards to tq_create_state_ex with
 * value_quant_bits = 0, i.e. no Q4/Q2 value-cache quantization is requested.
 * Returns whatever tq_create_state_ex returns for these arguments. */
tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
83+
return tq_create_state_ex(config, kv_type, 0);
84+
}
85+
86+
tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type, int value_quant_bits) {
8387
if (!config) return NULL;
8488

8589
int dim = config->hidden_dim;
@@ -123,17 +127,40 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
123127
size_t kv_layer_size = (size_t)max_seq * kv_dim;
124128
s->key_cache = (float*)calloc((size_t)n_layers * kv_layer_size, sizeof(float));
125129

126-
/* Use FP16 value cache when KV key quantization is enabled (saves 2x V memory).
127-
* FP16 has sufficient precision for value vectors (used in weighted sum, not scoring). */
128-
if (kv_type < TQ_TYPE_COUNT) {
130+
/* Value cache quantization: Q4 or Q2 for aggressive V compression.
131+
* When value_quant_bits > 0, V is stored quantized instead of FP16/FP32.
132+
* Q4: 16 packed bytes + 1 float scale per block of 32 = 20 bytes/32 values
133+
* Q2: 8 packed bytes + 1 float scale per block of 32 = 12 bytes/32 values */
134+
s->value_quant_bits = value_quant_bits;
135+
if (value_quant_bits == 4 || value_quant_bits == 2) {
136+
/* Quantized V cache */
137+
int n_blocks_per_pos = (kv_dim + 31) / 32; /* blocks per position (all heads) */
138+
size_t packed_per_block = (value_quant_bits == 4) ? 16 : 8;
139+
s->value_stride_qs = (size_t)n_blocks_per_pos * packed_per_block;
140+
s->value_stride_scales = (size_t)n_blocks_per_pos;
141+
size_t total_qs = (size_t)n_layers * max_seq * s->value_stride_qs;
142+
size_t total_scales = (size_t)n_layers * max_seq * s->value_stride_scales;
143+
s->value_cache_qs = (uint8_t*)calloc(total_qs, 1);
144+
s->value_cache_scales = (float*)calloc(total_scales, sizeof(float));
145+
s->use_fp16_values = 0;
146+
s->value_cache_fp16 = NULL;
147+
s->value_cache = NULL;
148+
s->kv_cache_size = total_qs + total_scales * sizeof(float);
149+
} else if (kv_type < TQ_TYPE_COUNT) {
150+
/* Use FP16 value cache when KV key quantization is enabled (saves 2x V memory).
151+
* FP16 has sufficient precision for value vectors (used in weighted sum, not scoring). */
129152
s->use_fp16_values = 1;
130153
s->value_cache_fp16 = (uint16_t*)calloc((size_t)n_layers * kv_layer_size, sizeof(uint16_t));
131154
s->value_cache = NULL;
155+
s->value_cache_qs = NULL;
156+
s->value_cache_scales = NULL;
132157
s->kv_cache_size = (size_t)n_layers * kv_layer_size * sizeof(uint16_t);
133158
} else {
134159
s->use_fp16_values = 0;
135160
s->value_cache_fp16 = NULL;
136161
s->value_cache = (float*)calloc((size_t)n_layers * kv_layer_size, sizeof(float));
162+
s->value_cache_qs = NULL;
163+
s->value_cache_scales = NULL;
137164
s->kv_cache_size = (size_t)n_layers * kv_layer_size * sizeof(float);
138165
}
139166

@@ -198,7 +225,14 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
198225
}
199226

200227
/* Verify critical allocations */
201-
int value_cache_ok = s->use_fp16_values ? (s->value_cache_fp16 != NULL) : (s->value_cache != NULL);
228+
int value_cache_ok;
229+
if (s->value_quant_bits == 4 || s->value_quant_bits == 2) {
230+
value_cache_ok = (s->value_cache_qs != NULL && s->value_cache_scales != NULL);
231+
} else if (s->use_fp16_values) {
232+
value_cache_ok = (s->value_cache_fp16 != NULL);
233+
} else {
234+
value_cache_ok = (s->value_cache != NULL);
235+
}
202236
if (!s->x || !s->xb || !s->xb2 || !s->q || !s->k || !s->v ||
203237
!s->att || !s->hb || !s->hb2 || !s->logits ||
204238
!s->key_cache || !value_cache_ok ||
@@ -225,6 +259,8 @@ void tq_free_state(tq_state_t* state) {
225259
free(state->key_cache);
226260
free(state->value_cache);
227261
free(state->value_cache_fp16);
262+
free(state->value_cache_qs);
263+
free(state->value_cache_scales);
228264
free(state->delta_state);
229265
free(state->conv_state);
230266
free(state->delta_qkv);
@@ -854,8 +890,21 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
854890
float* key_cache_layer = s->key_cache + l * kv_layer_stride;
855891
memcpy(key_cache_layer + (size_t)pos * kv_dim, s->k, kv_dim * sizeof(float));
856892

857-
/* Store V: FP16 if enabled, otherwise FP32 */
858-
if (s->use_fp16_values) {
893+
/* Store V: Q4/Q2 if enabled, FP16 if KV quant enabled, otherwise FP32 */
894+
int max_seq = c->max_seq_len;
895+
if (s->value_quant_bits == 4) {
896+
size_t layer_off_qs = (size_t)l * max_seq * s->value_stride_qs;
897+
size_t layer_off_sc = (size_t)l * max_seq * s->value_stride_scales;
898+
uint8_t* vqs = s->value_cache_qs + layer_off_qs + (size_t)pos * s->value_stride_qs;
899+
float* vsc = s->value_cache_scales + layer_off_sc + (size_t)pos * s->value_stride_scales;
900+
tq_quantize_row_q4(s->v, vqs, vsc, kv_dim);
901+
} else if (s->value_quant_bits == 2) {
902+
size_t layer_off_qs = (size_t)l * max_seq * s->value_stride_qs;
903+
size_t layer_off_sc = (size_t)l * max_seq * s->value_stride_scales;
904+
uint8_t* vqs = s->value_cache_qs + layer_off_qs + (size_t)pos * s->value_stride_qs;
905+
float* vsc = s->value_cache_scales + layer_off_sc + (size_t)pos * s->value_stride_scales;
906+
tq_quantize_row_q2(s->v, vqs, vsc, kv_dim);
907+
} else if (s->use_fp16_values) {
859908
uint16_t* val_fp16_layer = s->value_cache_fp16 + l * kv_layer_stride;
860909
f32_to_fp16_vec(s->v, val_fp16_layer + (size_t)pos * kv_dim, kv_dim);
861910
} else {
@@ -967,7 +1016,48 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
9671016
/* Weighted sum of values */
9681017
float* xbh = s->xb + h * head_dim;
9691018
memset(xbh, 0, head_dim * sizeof(float));
970-
if (s->use_fp16_values) {
1019+
if (s->value_quant_bits == 4 || s->value_quant_bits == 2) {
1020+
/* Quantized value path: dequantize V per position on the fly.
1021+
* Only the head_dim values belonging to this KV head are dequantized (not the full kv_dim row).
1022+
* Use a stack buffer for the head_dim portion. */
1023+
float v_tmp[512]; /* max head_dim is 256, safe with margin */
1024+
size_t layer_off_qs = (size_t)l * max_seq * s->value_stride_qs;
1025+
size_t layer_off_sc = (size_t)l * max_seq * s->value_stride_scales;
1026+
int n_blocks_per_head = (head_dim + 31) / 32;
1027+
size_t packed_per_block = (s->value_quant_bits == 4) ? 16 : 8;
1028+
/* Offset within a position's quantized data to reach kv_h's head */
1029+
size_t head_qs_off = (size_t)kv_h * n_blocks_per_head * packed_per_block;
1030+
size_t head_sc_off = (size_t)kv_h * n_blocks_per_head;
1031+
for (int t = 0; t < seq_len; t++) {
1032+
float a = atth[t];
1033+
if (a == 0.0f) continue;
1034+
const uint8_t* vqs = s->value_cache_qs + layer_off_qs
1035+
+ (size_t)t * s->value_stride_qs + head_qs_off;
1036+
const float* vsc = s->value_cache_scales + layer_off_sc
1037+
+ (size_t)t * s->value_stride_scales + head_sc_off;
1038+
if (s->value_quant_bits == 4) {
1039+
tq_dequantize_row_q4(vqs, vsc, v_tmp, head_dim);
1040+
} else {
1041+
tq_dequantize_row_q2(vqs, vsc, v_tmp, head_dim);
1042+
}
1043+
#ifdef __ARM_NEON
1044+
float32x4_t va = vdupq_n_f32(a);
1045+
int d = 0;
1046+
for (; d + 3 < head_dim; d += 4) {
1047+
float32x4_t vv = vld1q_f32(v_tmp + d);
1048+
float32x4_t vx = vld1q_f32(xbh + d);
1049+
vst1q_f32(xbh + d, vfmaq_f32(vx, va, vv));
1050+
}
1051+
for (; d < head_dim; d++) {
1052+
xbh[d] += a * v_tmp[d];
1053+
}
1054+
#else
1055+
for (int d = 0; d < head_dim; d++) {
1056+
xbh[d] += a * v_tmp[d];
1057+
}
1058+
#endif
1059+
}
1060+
} else if (s->use_fp16_values) {
9711061
/* FP16 value path: convert on the fly during weighted sum */
9721062
const uint16_t* vfp16_layer = s->value_cache_fp16 + l * kv_layer_stride;
9731063
for (int t = 0; t < seq_len; t++) {

0 commit comments

Comments
 (0)