@@ -80,6 +80,10 @@ static void f32_to_fp16_vec(const float* src, uint16_t* dst, int n) {
8080 * ============================================================ */
8181
8282tq_state_t * tq_create_state (const tq_model_config_t * config , tq_type kv_type ) {
83+ return tq_create_state_ex (config , kv_type , 0 );
84+ }
85+
86+ tq_state_t * tq_create_state_ex (const tq_model_config_t * config , tq_type kv_type , int value_quant_bits ) {
8387 if (!config ) return NULL ;
8488
8589 int dim = config -> hidden_dim ;
@@ -123,17 +127,40 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
123127 size_t kv_layer_size = (size_t )max_seq * kv_dim ;
124128 s -> key_cache = (float * )calloc ((size_t )n_layers * kv_layer_size , sizeof (float ));
125129
126- /* Use FP16 value cache when KV key quantization is enabled (saves 2x V memory).
127- * FP16 has sufficient precision for value vectors (used in weighted sum, not scoring). */
128- if (kv_type < TQ_TYPE_COUNT ) {
130+ /* Value cache quantization: Q4 or Q2 for aggressive V compression.
131+ * When value_quant_bits > 0, V is stored quantized instead of FP16/FP32.
132+ * Q4: 16 packed bytes + 1 float scale per block of 32 = 20 bytes/32 values
133+ * Q2: 8 packed bytes + 1 float scale per block of 32 = 12 bytes/32 values */
134+ s -> value_quant_bits = value_quant_bits ;
135+ if (value_quant_bits == 4 || value_quant_bits == 2 ) {
136+ /* Quantized V cache */
137+ int n_blocks_per_pos = (kv_dim + 31 ) / 32 ; /* blocks per position (all heads) */
138+ size_t packed_per_block = (value_quant_bits == 4 ) ? 16 : 8 ;
139+ s -> value_stride_qs = (size_t )n_blocks_per_pos * packed_per_block ;
140+ s -> value_stride_scales = (size_t )n_blocks_per_pos ;
141+ size_t total_qs = (size_t )n_layers * max_seq * s -> value_stride_qs ;
142+ size_t total_scales = (size_t )n_layers * max_seq * s -> value_stride_scales ;
143+ s -> value_cache_qs = (uint8_t * )calloc (total_qs , 1 );
144+ s -> value_cache_scales = (float * )calloc (total_scales , sizeof (float ));
145+ s -> use_fp16_values = 0 ;
146+ s -> value_cache_fp16 = NULL ;
147+ s -> value_cache = NULL ;
148+ s -> kv_cache_size = total_qs + total_scales * sizeof (float );
149+ } else if (kv_type < TQ_TYPE_COUNT ) {
150+ /* Use FP16 value cache when KV key quantization is enabled (saves 2x V memory).
151+ * FP16 has sufficient precision for value vectors (used in weighted sum, not scoring). */
129152 s -> use_fp16_values = 1 ;
130153 s -> value_cache_fp16 = (uint16_t * )calloc ((size_t )n_layers * kv_layer_size , sizeof (uint16_t ));
131154 s -> value_cache = NULL ;
155+ s -> value_cache_qs = NULL ;
156+ s -> value_cache_scales = NULL ;
132157 s -> kv_cache_size = (size_t )n_layers * kv_layer_size * sizeof (uint16_t );
133158 } else {
134159 s -> use_fp16_values = 0 ;
135160 s -> value_cache_fp16 = NULL ;
136161 s -> value_cache = (float * )calloc ((size_t )n_layers * kv_layer_size , sizeof (float ));
162+ s -> value_cache_qs = NULL ;
163+ s -> value_cache_scales = NULL ;
137164 s -> kv_cache_size = (size_t )n_layers * kv_layer_size * sizeof (float );
138165 }
139166
@@ -198,7 +225,14 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
198225 }
199226
200227 /* Verify critical allocations */
201- int value_cache_ok = s -> use_fp16_values ? (s -> value_cache_fp16 != NULL ) : (s -> value_cache != NULL );
228+ int value_cache_ok ;
229+ if (s -> value_quant_bits == 4 || s -> value_quant_bits == 2 ) {
230+ value_cache_ok = (s -> value_cache_qs != NULL && s -> value_cache_scales != NULL );
231+ } else if (s -> use_fp16_values ) {
232+ value_cache_ok = (s -> value_cache_fp16 != NULL );
233+ } else {
234+ value_cache_ok = (s -> value_cache != NULL );
235+ }
202236 if (!s -> x || !s -> xb || !s -> xb2 || !s -> q || !s -> k || !s -> v ||
203237 !s -> att || !s -> hb || !s -> hb2 || !s -> logits ||
204238 !s -> key_cache || !value_cache_ok ||
@@ -225,6 +259,8 @@ void tq_free_state(tq_state_t* state) {
225259 free (state -> key_cache );
226260 free (state -> value_cache );
227261 free (state -> value_cache_fp16 );
262+ free (state -> value_cache_qs );
263+ free (state -> value_cache_scales );
228264 free (state -> delta_state );
229265 free (state -> conv_state );
230266 free (state -> delta_qkv );
@@ -854,8 +890,21 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
854890 float * key_cache_layer = s -> key_cache + l * kv_layer_stride ;
855891 memcpy (key_cache_layer + (size_t )pos * kv_dim , s -> k , kv_dim * sizeof (float ));
856892
857- /* Store V: FP16 if enabled, otherwise FP32 */
858- if (s -> use_fp16_values ) {
893+ /* Store V: Q4/Q2 if enabled, FP16 if KV quant enabled, otherwise FP32 */
894+ int max_seq = c -> max_seq_len ;
895+ if (s -> value_quant_bits == 4 ) {
896+ size_t layer_off_qs = (size_t )l * max_seq * s -> value_stride_qs ;
897+ size_t layer_off_sc = (size_t )l * max_seq * s -> value_stride_scales ;
898+ uint8_t * vqs = s -> value_cache_qs + layer_off_qs + (size_t )pos * s -> value_stride_qs ;
899+ float * vsc = s -> value_cache_scales + layer_off_sc + (size_t )pos * s -> value_stride_scales ;
900+ tq_quantize_row_q4 (s -> v , vqs , vsc , kv_dim );
901+ } else if (s -> value_quant_bits == 2 ) {
902+ size_t layer_off_qs = (size_t )l * max_seq * s -> value_stride_qs ;
903+ size_t layer_off_sc = (size_t )l * max_seq * s -> value_stride_scales ;
904+ uint8_t * vqs = s -> value_cache_qs + layer_off_qs + (size_t )pos * s -> value_stride_qs ;
905+ float * vsc = s -> value_cache_scales + layer_off_sc + (size_t )pos * s -> value_stride_scales ;
906+ tq_quantize_row_q2 (s -> v , vqs , vsc , kv_dim );
907+ } else if (s -> use_fp16_values ) {
859908 uint16_t * val_fp16_layer = s -> value_cache_fp16 + l * kv_layer_stride ;
860909 f32_to_fp16_vec (s -> v , val_fp16_layer + (size_t )pos * kv_dim , kv_dim );
861910 } else {
@@ -967,7 +1016,48 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
9671016 /* Weighted sum of values */
9681017 float * xbh = s -> xb + h * head_dim ;
9691018 memset (xbh , 0 , head_dim * sizeof (float ));
970- if (s -> use_fp16_values ) {
1019+ if (s -> value_quant_bits == 4 || s -> value_quant_bits == 2 ) {
1020+ /* Quantized value path: dequantize V per position on the fly.
1021+ * We dequantize all kv_dim values for the position, then index into it.
1022+ * Use a stack buffer for the head_dim portion. */
1023+ float v_tmp [512 ]; /* max head_dim is 256, safe with margin */
1024+ size_t layer_off_qs = (size_t )l * max_seq * s -> value_stride_qs ;
1025+ size_t layer_off_sc = (size_t )l * max_seq * s -> value_stride_scales ;
1026+ int n_blocks_per_head = (head_dim + 31 ) / 32 ;
1027+ size_t packed_per_block = (s -> value_quant_bits == 4 ) ? 16 : 8 ;
1028+ /* Offset within a position's quantized data to reach kv_h's head */
1029+ size_t head_qs_off = (size_t )kv_h * n_blocks_per_head * packed_per_block ;
1030+ size_t head_sc_off = (size_t )kv_h * n_blocks_per_head ;
1031+ for (int t = 0 ; t < seq_len ; t ++ ) {
1032+ float a = atth [t ];
1033+ if (a == 0.0f ) continue ;
1034+ const uint8_t * vqs = s -> value_cache_qs + layer_off_qs
1035+ + (size_t )t * s -> value_stride_qs + head_qs_off ;
1036+ const float * vsc = s -> value_cache_scales + layer_off_sc
1037+ + (size_t )t * s -> value_stride_scales + head_sc_off ;
1038+ if (s -> value_quant_bits == 4 ) {
1039+ tq_dequantize_row_q4 (vqs , vsc , v_tmp , head_dim );
1040+ } else {
1041+ tq_dequantize_row_q2 (vqs , vsc , v_tmp , head_dim );
1042+ }
1043+ #ifdef __ARM_NEON
1044+ float32x4_t va = vdupq_n_f32 (a );
1045+ int d = 0 ;
1046+ for (; d + 3 < head_dim ; d += 4 ) {
1047+ float32x4_t vv = vld1q_f32 (v_tmp + d );
1048+ float32x4_t vx = vld1q_f32 (xbh + d );
1049+ vst1q_f32 (xbh + d , vfmaq_f32 (vx , va , vv ));
1050+ }
1051+ for (; d < head_dim ; d ++ ) {
1052+ xbh [d ] += a * v_tmp [d ];
1053+ }
1054+ #else
1055+ for (int d = 0 ; d < head_dim ; d ++ ) {
1056+ xbh [d ] += a * v_tmp [d ];
1057+ }
1058+ #endif
1059+ }
1060+ } else if (s -> use_fp16_values ) {
9711061 /* FP16 value path: convert on the fly during weighted sum */
9721062 const uint16_t * vfp16_layer = s -> value_cache_fp16 + l * kv_layer_stride ;
9731063 for (int t = 0 ; t < seq_len ; t ++ ) {
0 commit comments