@@ -10,6 +10,7 @@ const inference = @import("gguf_inference.zig");
1010const transformer = @import ("gguf_transformer.zig" );
1111const flash = @import ("flash_attention.zig" );
1212const parallel = @import ("parallel_inference.zig" );
13+ const kv_cache = @import ("kv_cache.zig" );
1314
1415// ═══════════════════════════════════════════════════════════════════════════════
1516// .TRI FILE FORMAT
@@ -62,6 +63,10 @@ pub const TriModel = struct {
6263 rope : transformer.RoPE ,
6364 kv_caches : []transformer.KVCache ,
6465
66+ // Ternary KV cache (OPT-T03/T04) - 16x memory reduction
67+ ternary_kv_caches : ? []kv_cache.TernaryKVCache ,
68+ use_ternary_kv : bool ,
69+
6570 // Pre-allocated buffers
6671 buf_hidden : []f32 ,
6772 buf_temp : []f32 ,
@@ -131,6 +136,8 @@ pub const TriModel = struct {
131136 .layers = undefined ,
132137 .rope = undefined ,
133138 .kv_caches = undefined ,
139+ .ternary_kv_caches = null ,
140+ .use_ternary_kv = false ,
134141 .buf_hidden = undefined ,
135142 .buf_temp = undefined ,
136143 .buf_normed = undefined ,
@@ -276,6 +283,14 @@ pub const TriModel = struct {
276283 }
277284 self .allocator .free (self .kv_caches );
278285
286+ // Free ternary KV caches if enabled
287+ if (self .ternary_kv_caches ) | caches | {
288+ for (caches ) | * cache | {
289+ cache .deinit ();
290+ }
291+ self .allocator .free (caches );
292+ }
293+
279294 self .rope .deinit ();
280295
281296 self .allocator .free (self .buf_hidden );
@@ -296,6 +311,44 @@ pub const TriModel = struct {
296311 for (self .kv_caches ) | * cache | {
297312 cache .reset ();
298313 }
314+ if (self .ternary_kv_caches ) | caches | {
315+ for (caches ) | * cache | {
316+ cache .reset ();
317+ }
318+ }
319+ }
320+
/// Enable the ternary KV cache (intended for a 16x memory reduction).
/// Call after load() but before inference.
///
/// Allocates one `kv_cache.TernaryKVCache` per layer using `self.allocator`.
/// Does nothing if the ternary caches already exist. On failure (allocation
/// or `TernaryKVCache.init` error) everything created by this call is
/// released and the model is left unchanged, so the call can be retried.
pub fn enableTernaryKVCache(self: *TriModel) !void {
    if (self.ternary_kv_caches != null) return; // Already enabled

    const header = self.header;

    // Build the caches in a local slice first; the model fields are only
    // updated once every layer has been initialized successfully.
    const caches = try self.allocator.alloc(kv_cache.TernaryKVCache, header.num_layers);
    errdefer self.allocator.free(caches);

    // Number of leading entries that hold a live cache. errdefer blocks run
    // in reverse order, so the initialized caches are torn down before the
    // backing slice is freed.
    var initialized: usize = 0;
    errdefer {
        for (caches[0..initialized]) |*cache| {
            cache.deinit();
        }
    }

    for (caches) |*cache| {
        cache.* = try kv_cache.TernaryKVCache.init(
            self.allocator,
            header.num_kv_heads,
            header.head_dim,
            header.context_length,
        );
        initialized += 1;
    }

    self.ternary_kv_caches = caches;
    self.use_ternary_kv = true;

    // Print memory savings.
    // Widen to u64 before multiplying so the product cannot overflow if the
    // header fields use a narrower integer type.
    const layer_count: u64 = header.num_layers;
    const context_len: u64 = header.context_length;
    const kv_heads: u64 = header.num_kv_heads;
    const dim_per_head: u64 = header.head_dim;
    // K and V (x2), 4 bytes per f32 element.
    const f32_mem = layer_count * context_len * kv_heads * dim_per_head * 2 * 4;

    // A model with zero layers has no cache to query.
    const per_layer_mem = if (caches.len > 0) caches[0].memoryUsage() else 0;
    const ternary_mem = per_layer_mem * header.num_layers;

    // Avoid printing inf/NaN when there is nothing to compare against.
    const ratio: f32 = if (ternary_mem == 0)
        0.0
    else
        @as(f32, @floatFromInt(f32_mem)) / @as(f32, @floatFromInt(ternary_mem));

    std.debug.print("\n ╔══════════════════════════════════════════════════════════════╗\n ", .{});
    std.debug.print("║ TERNARY KV CACHE ENABLED ║\n ", .{});
    std.debug.print("╠══════════════════════════════════════════════════════════════╣\n ", .{});
    std.debug.print("║ f32 KV cache: {d:>10} bytes ║\n ", .{f32_mem});
    std.debug.print("║ Ternary KV cache: {d:>10} bytes ║\n ", .{ternary_mem});
    std.debug.print("║ Compression: {d:>10.1}x ║\n ", .{ratio});
    std.debug.print("╚══════════════════════════════════════════════════════════════╝\n ", .{});
}
300353
301354 // Forward pass using TERNARY matmul (NO MULTIPLICATIONS!)
@@ -354,48 +407,69 @@ pub const TriModel = struct {
354407 self .rope .apply (self .buf_k [h * head_dim .. ][0.. head_dim ], pos );
355408 }
356409
357- // Update KV cache
358- self .kv_caches [layer_idx ].append (self .buf_k , self .buf_v );
359-
360- // SIMD-OPTIMIZED ATTENTION (no allocations in hot path)
410+ // Update KV cache (f32 or ternary)
361411 const scale = 1.0 / @sqrt (@as (f32 , @floatFromInt (head_dim )));
362- const kv_group_size = num_heads / num_kv_heads ;
363- const seq_len = self .kv_caches [layer_idx ].seq_len ;
364-
365- for (0.. num_heads ) | h | {
366- const kv_h = h / kv_group_size ;
367- const q_head = self .buf_q [h * head_dim .. ][0.. head_dim ];
368-
369- // Compute attention scores with SIMD dot product
370- for (0.. seq_len ) | t | {
371- const k_offset = t * num_kv_heads * head_dim + kv_h * head_dim ;
372- const k_vec = self .kv_caches [layer_idx ].k_cache [k_offset .. ][0.. head_dim ];
373- self .buf_scores [t ] = flash .simdDot (q_head , k_vec ) * scale ;
374- }
375412
376- // Softmax
377- inference .softmax (self .buf_scores [0.. seq_len ], self .buf_scores [0.. seq_len ]);
378-
379- // Weighted sum with SIMD
380- const out_head = self .buf_attn_out [h * head_dim .. ][0.. head_dim ];
381- @memset (out_head , 0.0 );
382-
383- for (0.. seq_len ) | t | {
384- const v_offset = t * num_kv_heads * head_dim + kv_h * head_dim ;
385- const v_vec = self .kv_caches [layer_idx ].v_cache [v_offset .. ][0.. head_dim ];
386- const score = self .buf_scores [t ];
387-
388- // SIMD scale-add
389- const Vec8 = @Vector (8 , f32 );
390- const weight_vec : Vec8 = @splat (score );
391- var j : usize = 0 ;
392- while (j + 8 <= head_dim ) : (j += 8 ) {
393- const out_vec : Vec8 = out_head [j .. ][0.. 8].* ;
394- const v_vec8 : Vec8 = v_vec [j .. ][0.. 8].* ;
395- out_head [j .. ][0.. 8].* = out_vec + v_vec8 * weight_vec ;
413+ if (self .use_ternary_kv and self .ternary_kv_caches != null ) {
414+ // TERNARY KV CACHE PATH (16x memory reduction)
415+ self .ternary_kv_caches .? [layer_idx ].append (self .buf_k , self .buf_v );
416+
417+ const seq_len = self .ternary_kv_caches .? [layer_idx ].seq_len ;
418+
419+ // Use ternary attention (no K dequantization!)
420+ flash .ternaryAttentionGQA (
421+ self .buf_attn_out ,
422+ self .buf_q ,
423+ & self .ternary_kv_caches .? [layer_idx ],
424+ num_heads ,
425+ num_kv_heads ,
426+ head_dim ,
427+ scale ,
428+ self .buf_scores ,
429+ );
430+ _ = seq_len ;
431+ } else {
432+ // F32 KV CACHE PATH (original)
433+ self .kv_caches [layer_idx ].append (self .buf_k , self .buf_v );
434+
435+ const kv_group_size = num_heads / num_kv_heads ;
436+ const seq_len = self .kv_caches [layer_idx ].seq_len ;
437+
438+ for (0.. num_heads ) | h | {
439+ const kv_h = h / kv_group_size ;
440+ const q_head = self .buf_q [h * head_dim .. ][0.. head_dim ];
441+
442+ // Compute attention scores with SIMD dot product
443+ for (0.. seq_len ) | t | {
444+ const k_offset = t * num_kv_heads * head_dim + kv_h * head_dim ;
445+ const k_vec = self .kv_caches [layer_idx ].k_cache [k_offset .. ][0.. head_dim ];
446+ self .buf_scores [t ] = flash .simdDot (q_head , k_vec ) * scale ;
396447 }
397- while (j < head_dim ) : (j += 1 ) {
398- out_head [j ] += score * v_vec [j ];
448+
449+ // Softmax
450+ inference .softmax (self .buf_scores [0.. seq_len ], self .buf_scores [0.. seq_len ]);
451+
452+ // Weighted sum with SIMD
453+ const out_head = self .buf_attn_out [h * head_dim .. ][0.. head_dim ];
454+ @memset (out_head , 0.0 );
455+
456+ for (0.. seq_len ) | t | {
457+ const v_offset = t * num_kv_heads * head_dim + kv_h * head_dim ;
458+ const v_vec = self .kv_caches [layer_idx ].v_cache [v_offset .. ][0.. head_dim ];
459+ const score_val = self .buf_scores [t ];
460+
461+ // SIMD scale-add
462+ const Vec8 = @Vector (8 , f32 );
463+ const weight_vec : Vec8 = @splat (score_val );
464+ var j : usize = 0 ;
465+ while (j + 8 <= head_dim ) : (j += 8 ) {
466+ const out_vec : Vec8 = out_head [j .. ][0.. 8].* ;
467+ const v_vec8 : Vec8 = v_vec [j .. ][0.. 8].* ;
468+ out_head [j .. ][0.. 8].* = out_vec + v_vec8 * weight_vec ;
469+ }
470+ while (j < head_dim ) : (j += 1 ) {
471+ out_head [j ] += score_val * v_vec [j ];
472+ }
399473 }
400474 }
401475 }
0 commit comments