@@ -541,6 +541,117 @@ pub const CacheStats = struct {
541541 memory_bytes : usize ,
542542};
543543
544+ // ═══════════════════════════════════════════════════════════════════════════════
545+ // STREAMING ATTENTION (Sliding Window + Attention Sink)
546+ // Enables infinite context with fixed memory
547+ // ═══════════════════════════════════════════════════════════════════════════════
548+
549+ /// Compute streaming attention with sliding window mask
550+ /// Only attends to sink tokens + local window, ignoring evicted tokens
551+ pub fn streamingAttention (
552+ output : []f32 ,
553+ query : []const f32 ,
554+ cache : * const RingKVCache ,
555+ head_idx : usize ,
556+ scores_buf : []f32 ,
557+ scale : f32 ,
558+ ) void {
559+ const seq_len = cache .seqLen ();
560+ const head_dim = cache .head_dim ;
561+
562+ if (seq_len == 0 ) {
563+ @memset (output , 0.0 );
564+ return ;
565+ }
566+
567+ // Compute attention scores with window masking
568+ var max_score : f32 = - std .math .inf (f32 );
569+
570+ for (0.. seq_len ) | t | {
571+ // Get logical position for window check
572+ const logical_pos = if (cache .total_tokens <= cache .max_seq_len )
573+ t
574+ else
575+ cache .total_tokens - cache .max_seq_len + t ;
576+
577+ // Check if position is in window (sink or local)
578+ const in_window = cache .isInWindow (logical_pos );
579+
580+ if (in_window ) {
581+ // Compute dot product
582+ const k_vec = cache .getK (t , head_idx );
583+ var dot : f32 = 0.0 ;
584+ for (0.. head_dim ) | j | {
585+ dot += query [j ] * k_vec [j ];
586+ }
587+ scores_buf [t ] = dot * scale ;
588+ if (scores_buf [t ] > max_score ) max_score = scores_buf [t ];
589+ } else {
590+ // Mask out evicted tokens
591+ scores_buf [t ] = - std .math .inf (f32 );
592+ }
593+ }
594+
595+ // Softmax (numerically stable)
596+ var sum_exp : f32 = 0.0 ;
597+ for (0.. seq_len ) | t | {
598+ if (scores_buf [t ] > - std .math .inf (f32 )) {
599+ scores_buf [t ] = @exp (scores_buf [t ] - max_score );
600+ sum_exp += scores_buf [t ];
601+ } else {
602+ scores_buf [t ] = 0.0 ;
603+ }
604+ }
605+
606+ if (sum_exp > 0.0 ) {
607+ for (0.. seq_len ) | t | {
608+ scores_buf [t ] /= sum_exp ;
609+ }
610+ }
611+
612+ // Weighted sum of V
613+ @memset (output , 0.0 );
614+ for (0.. seq_len ) | t | {
615+ if (scores_buf [t ] > 0.0 ) {
616+ const v_vec = cache .getV (t , head_idx );
617+ const score_val = scores_buf [t ];
618+ for (0.. head_dim ) | j | {
619+ output [j ] += score_val * v_vec [j ];
620+ }
621+ }
622+ }
623+ }
624+
625+ /// Compression statistics for streaming mode
626+ pub const CompressionStats = struct {
627+ total_tokens_seen : usize ,
628+ tokens_in_cache : usize ,
629+ evicted_tokens : usize ,
630+ compression_ratio : f32 ,
631+ memory_saved_bytes : usize ,
632+ effective_context : usize , // sink + local window
633+
634+ pub fn fromCache (cache : * const RingKVCache ) CompressionStats {
635+ const cfg = cache .window_config ;
636+ const effective = @min (cache .total_tokens , cfg .sink_tokens + cfg .local_tokens );
637+ const full_memory = cache .total_tokens * cache .num_kv_heads * cache .head_dim * 2 * @sizeOf (f32 );
638+ const actual_memory = cache .memoryUsage ();
639+ const saved = if (full_memory > actual_memory ) full_memory - actual_memory else 0 ;
640+
641+ return CompressionStats {
642+ .total_tokens_seen = cache .total_tokens ,
643+ .tokens_in_cache = cache .seqLen (),
644+ .evicted_tokens = cache .evicted_tokens ,
645+ .compression_ratio = if (cache .total_tokens > 0 )
646+ @as (f32 , @floatFromInt (cache .total_tokens )) / @as (f32 , @floatFromInt (cache .seqLen ()))
647+ else
648+ 1.0 ,
649+ .memory_saved_bytes = saved ,
650+ .effective_context = effective ,
651+ };
652+ }
653+ };
654+
544655// ═══════════════════════════════════════════════════════════════════════════════
545656// TERNARY KV-CACHE (OPT-T03)
546657// 16x memory reduction via 2-bit quantization
@@ -1328,3 +1439,88 @@ test "batch kv cache" {
13281439 try std .testing .expect (seq2 != null );
13291440 try std .testing .expectEqual (@as (usize , 2 ), batch .activeCount ());
13301441}
1442+
1443+ test "streaming_attention_window" {
1444+ const allocator = std .testing .allocator ;
1445+
1446+ // Create cache with small window for testing
1447+ const window_config = SlidingWindowConfig {
1448+ .window_size = 16 ,
1449+ .sink_tokens = 2 , // Keep first 2 tokens
1450+ .local_tokens = 6 , // Keep last 6 tokens
1451+ };
1452+
1453+ var cache = try RingKVCache .init (allocator , 1 , 4 , 16 , window_config );
1454+ defer cache .deinit ();
1455+
1456+ // Add 20 tokens (exceeds window)
1457+ for (0.. 20) | i | {
1458+ var k = [_ ]f32 { @floatFromInt (i ), 0 , 0 , 0 };
1459+ var v = [_ ]f32 { 1 , 0 , 0 , 0 };
1460+ cache .append (& k , & v );
1461+ }
1462+
1463+ // Check window membership
1464+ // Sink tokens (0, 1) should be in window
1465+ try std .testing .expect (cache .isInWindow (0 ));
1466+ try std .testing .expect (cache .isInWindow (1 ));
1467+
1468+ // Middle tokens should be evicted
1469+ try std .testing .expect (! cache .isInWindow (5 ));
1470+ try std .testing .expect (! cache .isInWindow (10 ));
1471+
1472+ // Recent tokens (14-19) should be in window
1473+ try std .testing .expect (cache .isInWindow (14 ));
1474+ try std .testing .expect (cache .isInWindow (19 ));
1475+
1476+ // Test streaming attention
1477+ const query = [_ ]f32 { 1 , 0 , 0 , 0 };
1478+ var output : [4 ]f32 = undefined ;
1479+ var scores : [16 ]f32 = undefined ;
1480+
1481+ streamingAttention (& output , & query , & cache , 0 , & scores , 1.0 );
1482+
1483+ // Output should be non-zero (attention computed)
1484+ try std .testing .expect (output [0 ] != 0.0 );
1485+ }
1486+
1487+ test "compression_stats" {
1488+ const allocator = std .testing .allocator ;
1489+
1490+ const window_config = SlidingWindowConfig {
1491+ .window_size = 100 ,
1492+ .sink_tokens = 4 ,
1493+ .local_tokens = 96 ,
1494+ };
1495+
1496+ var cache = try RingKVCache .init (allocator , 4 , 64 , 100 , window_config );
1497+ defer cache .deinit ();
1498+
1499+ // Simulate long sequence
1500+ var k_buf : [256 ]f32 = undefined ;
1501+ var v_buf : [256 ]f32 = undefined ;
1502+ @memset (& k_buf , 0.1 );
1503+ @memset (& v_buf , 0.2 );
1504+
1505+ for (0.. 500) | _ | {
1506+ cache .append (& k_buf , & v_buf );
1507+ }
1508+
1509+ const stats = CompressionStats .fromCache (& cache );
1510+
1511+ try std .testing .expectEqual (@as (usize , 500 ), stats .total_tokens_seen );
1512+ try std .testing .expectEqual (@as (usize , 100 ), stats .tokens_in_cache );
1513+ try std .testing .expectEqual (@as (usize , 400 ), stats .evicted_tokens );
1514+ try std .testing .expect (stats .compression_ratio >= 4.9 ); // 500/100 = 5x
1515+
1516+ std .debug .print ("\n ╔══════════════════════════════════════════════════════════════╗\n " , .{});
1517+ std .debug .print ("║ KV CACHE COMPRESSION STATS ║\n " , .{});
1518+ std .debug .print ("╠══════════════════════════════════════════════════════════════╣\n " , .{});
1519+ std .debug .print ("║ Total tokens seen: {d:>10} ║\n " , .{stats .total_tokens_seen });
1520+ std .debug .print ("║ Tokens in cache: {d:>10} ║\n " , .{stats .tokens_in_cache });
1521+ std .debug .print ("║ Evicted tokens: {d:>10} ║\n " , .{stats .evicted_tokens });
1522+ std .debug .print ("║ Compression ratio: {d:>10.1}x ║\n " , .{stats .compression_ratio });
1523+ std .debug .print ("║ Effective context: {d:>10} ║\n " , .{stats .effective_context });
1524+ std .debug .print ("║ Memory saved: {d:>10} bytes ║\n " , .{stats .memory_saved_bytes });
1525+ std .debug .print ("╚══════════════════════════════════════════════════════════════╝\n " , .{});
1526+ }
0 commit comments