2929#include <arm_neon.h>
3030#endif
3131
32+ /* ============================================================
33+ * FP16 helpers (IEEE 754 half-precision, storage only)
34+ * ============================================================ */
35+
36+ static uint16_t f32_to_fp16 (float v ) {
37+ union { float f ; uint32_t u ; } bits ;
38+ bits .f = v ;
39+ uint32_t sign = (bits .u >> 16 ) & 0x8000 ;
40+ int32_t exp = ((bits .u >> 23 ) & 0xFF ) - 127 + 15 ;
41+ uint32_t mant = (bits .u >> 13 ) & 0x03FF ;
42+ if (exp <= 0 ) return (uint16_t )sign ;
43+ if (exp >= 31 ) return (uint16_t )(sign | 0x7C00 );
44+ return (uint16_t )(sign | ((uint32_t )exp << 10 ) | mant );
45+ }
46+
47+ static float fp16_to_f32 (uint16_t h ) {
48+ union { float f ; uint32_t u ; } bits ;
49+ uint32_t sign = (h & 0x8000 ) << 16 ;
50+ uint32_t exp = (h >> 10 ) & 0x1F ;
51+ uint32_t mant = h & 0x03FF ;
52+ if (exp == 0 ) { bits .u = sign ; return bits .f ; }
53+ if (exp == 31 ) { bits .u = sign | 0x7F800000 | (mant << 13 ); return bits .f ; }
54+ exp = exp - 15 + 127 ;
55+ bits .u = sign | (exp << 23 ) | (mant << 13 );
56+ return bits .f ;
57+ }
58+
59+ /* Convert n floats to FP16 (NEON-optimized where available) */
60+ static void f32_to_fp16_vec (const float * src , uint16_t * dst , int n ) {
61+ #ifdef __ARM_NEON
62+ int i = 0 ;
63+ for (; i + 3 < n ; i += 4 ) {
64+ float32x4_t vf = vld1q_f32 (src + i );
65+ float16x4_t vh = vcvt_f16_f32 (vf );
66+ vst1_u16 (dst + i , vreinterpret_u16_f16 (vh ));
67+ }
68+ for (; i < n ; i ++ ) {
69+ dst [i ] = f32_to_fp16 (src [i ]);
70+ }
71+ #else
72+ for (int i = 0 ; i < n ; i ++ ) {
73+ dst [i ] = f32_to_fp16 (src [i ]);
74+ }
75+ #endif
76+ }
77+
3278/* ============================================================
3379 * State management
3480 * ============================================================ */
@@ -76,8 +122,20 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
76122 /* KV cache for self_attn layers */
77123 size_t kv_layer_size = (size_t )max_seq * kv_dim ;
78124 s -> key_cache = (float * )calloc ((size_t )n_layers * kv_layer_size , sizeof (float ));
79- s -> value_cache = (float * )calloc ((size_t )n_layers * kv_layer_size , sizeof (float ));
80- s -> kv_cache_size = (size_t )n_layers * kv_layer_size * sizeof (float );
125+
126+ /* Use FP16 value cache when KV key quantization is enabled (saves 2x V memory).
127+ * FP16 has sufficient precision for value vectors (used in weighted sum, not scoring). */
128+ if (kv_type < TQ_TYPE_COUNT ) {
129+ s -> use_fp16_values = 1 ;
130+ s -> value_cache_fp16 = (uint16_t * )calloc ((size_t )n_layers * kv_layer_size , sizeof (uint16_t ));
131+ s -> value_cache = NULL ;
132+ s -> kv_cache_size = (size_t )n_layers * kv_layer_size * sizeof (uint16_t );
133+ } else {
134+ s -> use_fp16_values = 0 ;
135+ s -> value_cache_fp16 = NULL ;
136+ s -> value_cache = (float * )calloc ((size_t )n_layers * kv_layer_size , sizeof (float ));
137+ s -> kv_cache_size = (size_t )n_layers * kv_layer_size * sizeof (float );
138+ }
81139
82140 /* Dynamic workspace buffers (replacing fixed-size stack arrays).
83141 * xb_q8/xb_q8s are used in deltanet_forward, self_attn_forward, and FFN
@@ -140,9 +198,10 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
140198 }
141199
142200 /* Verify critical allocations */
201+ int value_cache_ok = s -> use_fp16_values ? (s -> value_cache_fp16 != NULL ) : (s -> value_cache != NULL );
143202 if (!s -> x || !s -> xb || !s -> xb2 || !s -> q || !s -> k || !s -> v ||
144203 !s -> att || !s -> hb || !s -> hb2 || !s -> logits ||
145- !s -> key_cache || !s -> value_cache ||
204+ !s -> key_cache || !value_cache_ok ||
146205 !s -> xb_q8 || !s -> xb_q8s ) {
147206 tq_free_state (s );
148207 return NULL ;
@@ -165,6 +224,7 @@ void tq_free_state(tq_state_t* state) {
165224 free (state -> logits );
166225 free (state -> key_cache );
167226 free (state -> value_cache );
227+ free (state -> value_cache_fp16 );
168228 free (state -> delta_state );
169229 free (state -> conv_state );
170230 free (state -> delta_qkv );
@@ -792,9 +852,16 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
792852
793853 /* Store K,V in cache */
794854 float * key_cache_layer = s -> key_cache + l * kv_layer_stride ;
795- float * val_cache_layer = s -> value_cache + l * kv_layer_stride ;
796855 memcpy (key_cache_layer + (size_t )pos * kv_dim , s -> k , kv_dim * sizeof (float ));
797- memcpy (val_cache_layer + (size_t )pos * kv_dim , s -> v , kv_dim * sizeof (float ));
856+
857+ /* Store V: FP16 if enabled, otherwise FP32 */
858+ if (s -> use_fp16_values ) {
859+ uint16_t * val_fp16_layer = s -> value_cache_fp16 + l * kv_layer_stride ;
860+ f32_to_fp16_vec (s -> v , val_fp16_layer + (size_t )pos * kv_dim , kv_dim );
861+ } else {
862+ float * val_cache_layer = s -> value_cache + l * kv_layer_stride ;
863+ memcpy (val_cache_layer + (size_t )pos * kv_dim , s -> v , kv_dim * sizeof (float ));
864+ }
798865
799866 /* Quantize the new key into the quantized cache for integer attention.
800867 * Each KV head's key vector is quantized independently into blocks. */
@@ -900,11 +967,40 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
900967 /* Weighted sum of values */
901968 float * xbh = s -> xb + h * head_dim ;
902969 memset (xbh , 0 , head_dim * sizeof (float ));
903- for (int t = 0 ; t < seq_len ; t ++ ) {
904- const float * vt = val_cache_layer + (size_t )t * kv_dim + kv_h * head_dim ;
905- float a = atth [t ];
906- for (int d = 0 ; d < head_dim ; d ++ ) {
907- xbh [d ] += a * vt [d ];
970+ if (s -> use_fp16_values ) {
971+ /* FP16 value path: convert on the fly during weighted sum */
972+ const uint16_t * vfp16_layer = s -> value_cache_fp16 + l * kv_layer_stride ;
973+ for (int t = 0 ; t < seq_len ; t ++ ) {
974+ const uint16_t * vt16 = vfp16_layer + (size_t )t * kv_dim + kv_h * head_dim ;
975+ float a = atth [t ];
976+ if (a == 0.0f ) continue ; /* skip zero-weight positions */
977+ #ifdef __ARM_NEON
978+ float32x4_t va = vdupq_n_f32 (a );
979+ int d = 0 ;
980+ for (; d + 3 < head_dim ; d += 4 ) {
981+ uint16x4_t vh = vld1_u16 (vt16 + d );
982+ float32x4_t vf = vcvt_f32_f16 (vreinterpret_f16_u16 (vh ));
983+ float32x4_t vx = vld1q_f32 (xbh + d );
984+ vst1q_f32 (xbh + d , vfmaq_f32 (vx , va , vf ));
985+ }
986+ for (; d < head_dim ; d ++ ) {
987+ xbh [d ] += a * fp16_to_f32 (vt16 [d ]);
988+ }
989+ #else
990+ for (int d = 0 ; d < head_dim ; d ++ ) {
991+ xbh [d ] += a * fp16_to_f32 (vt16 [d ]);
992+ }
993+ #endif
994+ }
995+ } else {
996+ /* FP32 value path (original) */
997+ const float * val_cache_layer_fp32 = s -> value_cache + l * kv_layer_stride ;
998+ for (int t = 0 ; t < seq_len ; t ++ ) {
999+ const float * vt = val_cache_layer_fp32 + (size_t )t * kv_dim + kv_h * head_dim ;
1000+ float a = atth [t ];
1001+ for (int d = 0 ; d < head_dim ; d ++ ) {
1002+ xbh [d ] += a * vt [d ];
1003+ }
9081004 }
9091005 }
9101006 }
0 commit comments