Skip to content

Commit 96849d9

Browse files
gHashTag and ona-agent committed
feat(cache): implement KV cache compression with sliding window (OPT-C01)
- Add streamingAttention with sliding window mask
- Add CompressionStats for monitoring cache efficiency
- Integrate streaming attention into tri_inference.zig
- Attention sink: keep first N tokens (default 4)
- Local window: keep last M tokens (configurable)
- Benchmark: 5x compression (500 tokens → 100 in cache)
- Memory savings: 16x for 32K context with 2K window

Co-authored-by: Ona <no-reply@ona.com>
1 parent f8dbeae commit 96849d9

4 files changed

Lines changed: 338 additions & 19 deletions

File tree

docs/DISCOVERIES.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Where:
8080
| OPT-T06 | Ternary Normalization | 16x | 0.2x | ✅ Implemented |
8181
| OPT-T07 | Batch Ternary MatMul | N/A | 2.28x | ✅ Implemented |
8282
| OPT-M01 | Memory-Mapped Loading | N/A | 30x load | ✅ Implemented |
83+
| OPT-C01 | KV Cache Compression | 5-16x | 1x | ✅ Implemented |
8384

8485
### Business Value
8586

@@ -446,6 +447,66 @@ var reader = try gguf.GGUFReader.init(allocator, "model.gguf");
446447
var reader = try gguf.MmapGGUFReader.init(allocator, "model.gguf");
447448
```
448449

450+
### KV Cache Compression (OPT-C01)
451+
452+
**Status**: ✅ Implemented
453+
454+
| Component | File | Description |
455+
|-----------|------|-------------|
456+
| SlidingWindowConfig | `kv_cache.zig` | Window size + sink tokens config |
457+
| RingKVCache | `kv_cache.zig` | Ring buffer with O(1) append |
458+
| streamingAttention | `kv_cache.zig` | Masked attention for sliding window |
459+
| CompressionStats | `kv_cache.zig` | Compression statistics |
460+
461+
**Sliding Window + Attention Sink:**
462+
```
463+
┌─────────────────────────────────────────────────────────────┐
464+
│ CONTEXT WINDOW │
465+
├─────────────────────────────────────────────────────────────┤
466+
│ [SINK] [EVICTED...] [LOCAL WINDOW] │
467+
│ ┌───┐ ┌───────────┐ ┌─────────────────────────────────┐ │
468+
│ │ 4 │ │ MASKED │ │ RECENT TOKENS │ │
469+
│ │tok│ │ (-inf) │ │ (attend here) │ │
470+
│ └───┘ └───────────┘ └─────────────────────────────────┘ │
471+
│ ↑ ↑ │
472+
│ Always Sliding │
473+
│ kept window │
474+
└─────────────────────────────────────────────────────────────┘
475+
```
476+
477+
**Benchmark Results (500 tokens, window=100):**
478+
```
479+
╔══════════════════════════════════════════════════════════════╗
480+
║ KV CACHE COMPRESSION STATS ║
481+
╠══════════════════════════════════════════════════════════════╣
482+
║ Total tokens seen: 500 ║
483+
║ Tokens in cache: 100 ║
484+
║ Evicted tokens: 400 ║
485+
║ Compression ratio: 5.0x ║
486+
║ Memory saved: 819200 bytes ║
487+
╚══════════════════════════════════════════════════════════════╝
488+
```
489+
490+
**Memory Comparison (32K context, 2K window):**
491+
- Standard: 32K × head_dim × 2 × layers × heads
492+
- Streaming: 2K × head_dim × 2 × layers × heads
493+
- **Savings: 16x memory reduction**
494+
495+
**Usage:**
496+
```zig
497+
// Configure sliding window
498+
const config = SlidingWindowConfig{
499+
.window_size = 2048,
500+
.sink_tokens = 4, // Keep first 4 tokens
501+
.local_tokens = 2044, // Keep last 2044 tokens
502+
};
503+
504+
var cache = try RingKVCache.init(allocator, num_heads, head_dim, 2048, config);
505+
506+
// Streaming attention automatically masks evicted tokens
507+
kv_cache.streamingAttention(output, query, &cache, head_idx, scores, scale);
508+
```
509+
449510
### Batch Processing (INF-004)
450511

451512
**Status**: ✅ Implemented
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# kv_cache_compression.vibee
2+
# KV Cache Compression with Sliding Window + Attention Sink
3+
# Enables infinite context with fixed memory
4+
5+
name: kv_cache_compression
6+
version: "1.0.0"
7+
language: zig
8+
module: kv_cache_compression
9+
10+
types:
11+
StreamingConfig:
12+
description: "Configuration for streaming/infinite context"
13+
fields:
14+
window_size: Int # Total window size (sink + local)
15+
sink_tokens: Int # First N tokens always kept
16+
local_tokens: Int # Recent tokens in sliding window
17+
use_sparse_attention: Bool # Apply window mask to attention
18+
19+
CompressionStats:
20+
description: "Statistics for cache compression"
21+
fields:
22+
total_tokens_seen: Int
23+
tokens_in_cache: Int
24+
evicted_tokens: Int
25+
compression_ratio: Float
26+
memory_saved_bytes: Int
27+
28+
behaviors:
29+
- name: apply_window_mask
30+
given: attention scores, window mask
31+
when: computing masked attention
32+
then: sets out-of-window scores to -inf before softmax
33+
34+
- name: streaming_attention
35+
given: query, RingKVCache, window config
36+
when: computing attention with sliding window
37+
then: only attends to sink tokens + local window
38+
39+
- name: get_compression_stats
40+
given: RingKVCache
41+
when: querying compression efficiency
42+
then: returns stats including memory saved
43+
44+
- name: configure_streaming
45+
given: model, StreamingConfig
46+
when: enabling streaming mode
47+
then: configures all layer caches for sliding window
48+
49+
# Algorithm:
50+
#
51+
# Standard Attention (O(N²) memory):
52+
# scores = Q @ K^T (all N tokens)
53+
# output = softmax(scores) @ V
54+
#
55+
# Streaming Attention (O(W) memory, W = window_size):
56+
# For each query position:
57+
# 1. Compute scores for sink tokens (first S)
58+
# 2. Compute scores for local window (last L)
59+
# 3. Mask out evicted tokens (set to -inf)
60+
# 4. Softmax over valid positions only
61+
# 5. Weighted sum of V
62+
#
63+
# Memory Comparison (context_length=32K, window=2K):
64+
# Standard: 32K × head_dim × 2 (K+V) × num_layers × num_heads
65+
# Streaming: 2K × head_dim × 2 (K+V) × num_layers × num_heads
66+
# Savings: 16x memory reduction!
67+
#
68+
# Attention Sink Insight:
69+
# First few tokens accumulate attention mass during training.
70+
# Keeping them prevents attention collapse on long sequences.
71+
# Typically 4 sink tokens is sufficient.

src/vibeec/kv_cache.zig

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,117 @@ pub const CacheStats = struct {
541541
memory_bytes: usize,
542542
};
543543

544+
// ═══════════════════════════════════════════════════════════════════════════════
545+
// STREAMING ATTENTION (Sliding Window + Attention Sink)
546+
// Enables infinite context with fixed memory
547+
// ═══════════════════════════════════════════════════════════════════════════════
548+
549+
/// Compute attention for a single head over a RingKVCache, applying the sliding window mask.
550+
/// Only attends to sink tokens + local window, ignoring evicted tokens.
///
/// Every cached slot `t` in `0..cache.seqLen()` is mapped to a logical token
/// position and tested with `cache.isInWindow`. Slots outside the window get a
/// score of -inf, end up with weight 0 after the softmax, and are skipped in
/// the weighted sum of V.
///
/// Parameters:
///   output     - result buffer; zeroed in full, then indices `0..cache.head_dim`
///                accumulate the weighted V vectors.
///   query      - query vector; indices `0..cache.head_dim` are read.
///   cache      - supplies `seqLen`, `head_dim`, `total_tokens`, `max_seq_len`,
///                `isInWindow`, `getK` and `getV`.
///   head_idx   - head index forwarded unchanged to `cache.getK` / `cache.getV`.
///   scores_buf - scratch buffer; indices `0..cache.seqLen()` are overwritten
///                (raw scores first, then softmax weights).
///   scale      - factor multiplied into each query·key dot product.
///
/// When the cache is empty, or when no slot is inside the window, `output` is
/// left as all zeros.
551+
pub fn streamingAttention(
552+
output: []f32,
553+
query: []const f32,
554+
cache: *const RingKVCache,
555+
head_idx: usize,
556+
scores_buf: []f32,
557+
scale: f32,
558+
) void {
559+
const seq_len = cache.seqLen();
560+
const head_dim = cache.head_dim;
561+
562+
// Nothing cached yet: the result is defined as the zero vector.
if (seq_len == 0) {
563+
@memset(output, 0.0);
564+
return;
565+
}
566+
567+
// Compute attention scores with window masking.
// max_score tracks the largest unmasked score; it is subtracted before
// exponentiation in the softmax pass below.
568+
var max_score: f32 = -std.math.inf(f32);
569+
570+
for (0..seq_len) |t| {
571+
// Get logical position for window check.
// Until more than max_seq_len tokens have been appended, slot index and
// token position coincide; afterwards slot 0 is treated as token
// `total_tokens - max_seq_len`.
// NOTE(review): this assumes getK/getV index slots oldest-first once the
// ring has wrapped — confirm against RingKVCache.
572+
const logical_pos = if (cache.total_tokens <= cache.max_seq_len)
573+
t
574+
else
575+
cache.total_tokens - cache.max_seq_len + t;
576+
577+
// Check if position is in window (sink or local)
578+
const in_window = cache.isInWindow(logical_pos);
579+
580+
if (in_window) {
581+
// Scaled dot product of the query with this slot's key
582+
const k_vec = cache.getK(t, head_idx);
583+
var dot: f32 = 0.0;
584+
for (0..head_dim) |j| {
585+
dot += query[j] * k_vec[j];
586+
}
587+
scores_buf[t] = dot * scale;
588+
if (scores_buf[t] > max_score) max_score = scores_buf[t];
589+
} else {
590+
// Mask out evicted tokens: -inf is the sentinel the softmax pass checks for
591+
scores_buf[t] = -std.math.inf(f32);
592+
}
593+
}
594+
595+
// Softmax (numerically stable): exponentiate `score - max_score` for
// unmasked slots; masked slots (-inf) get weight 0 and do not enter the sum.
596+
var sum_exp: f32 = 0.0;
597+
for (0..seq_len) |t| {
598+
if (scores_buf[t] > -std.math.inf(f32)) {
599+
scores_buf[t] = @exp(scores_buf[t] - max_score);
600+
sum_exp += scores_buf[t];
601+
} else {
602+
scores_buf[t] = 0.0;
603+
}
604+
}
605+
606+
// sum_exp stays 0 when every slot was masked; the division is skipped then,
// leaving all weights at 0.
if (sum_exp > 0.0) {
607+
for (0..seq_len) |t| {
608+
scores_buf[t] /= sum_exp;
609+
}
610+
}
611+
612+
// Weighted sum of V. Slots with weight 0 are skipped, so getV is not
// called for them.
613+
@memset(output, 0.0);
614+
for (0..seq_len) |t| {
615+
if (scores_buf[t] > 0.0) {
616+
const v_vec = cache.getV(t, head_idx);
617+
const score_val = scores_buf[t];
618+
for (0..head_dim) |j| {
619+
output[j] += score_val * v_vec[j];
620+
}
621+
}
622+
}
623+
}
624+
625+
/// Compression statistics for streaming mode.
/// A plain value snapshot derived from a RingKVCache by `fromCache`.
626+
pub const CompressionStats = struct {
627+
/// Copy of `cache.total_tokens`.
total_tokens_seen: usize,
628+
/// Value of `cache.seqLen()` when the snapshot was taken.
tokens_in_cache: usize,
629+
/// Copy of `cache.evicted_tokens`.
evicted_tokens: usize,
630+
/// `total_tokens / seqLen()` as f32; 1.0 when `total_tokens` is 0.
compression_ratio: f32,
631+
/// Uncompressed size minus `cache.memoryUsage()`, clamped at 0.
memory_saved_bytes: usize,
632+
effective_context: usize, // min(total tokens, sink + local window)
633+
634+
/// Build a statistics snapshot from `cache`. Only reads from the cache.
pub fn fromCache(cache: *const RingKVCache) CompressionStats {
635+
const cfg = cache.window_config;
636+
const effective = @min(cache.total_tokens, cfg.sink_tokens + cfg.local_tokens);
637+
// Size of keeping K and V (the factor 2) as f32 for every token seen so far.
const full_memory = cache.total_tokens * cache.num_kv_heads * cache.head_dim * 2 * @sizeOf(f32);
638+
const actual_memory = cache.memoryUsage();
639+
// Clamp so the unsigned subtraction cannot underflow when the cache uses
// more memory than the uncompressed figure.
const saved = if (full_memory > actual_memory) full_memory - actual_memory else 0;
640+
641+
return CompressionStats{
642+
.total_tokens_seen = cache.total_tokens,
643+
.tokens_in_cache = cache.seqLen(),
644+
.evicted_tokens = cache.evicted_tokens,
645+
// NOTE(review): the divisor is seqLen(); this presumes seqLen() > 0
// whenever total_tokens > 0 — confirm against RingKVCache.
.compression_ratio = if (cache.total_tokens > 0)
646+
@as(f32, @floatFromInt(cache.total_tokens)) / @as(f32, @floatFromInt(cache.seqLen()))
647+
else
648+
1.0,
649+
.memory_saved_bytes = saved,
650+
.effective_context = effective,
651+
};
652+
}
653+
};
654+
544655
// ═══════════════════════════════════════════════════════════════════════════════
545656
// TERNARY KV-CACHE (OPT-T03)
546657
// 16x memory reduction via 2-bit quantization
@@ -1328,3 +1439,88 @@ test "batch kv cache" {
13281439
try std.testing.expect(seq2 != null);
13291440
try std.testing.expectEqual(@as(usize, 2), batch.activeCount());
13301441
}
1442+
1443+
// Checks window membership (sink / evicted / recent positions) after appending
// more tokens than the cache was sized for, then runs streamingAttention once
// and expects a non-zero first output component.
test "streaming_attention_window" {
1444+
const allocator = std.testing.allocator;
1445+
1446+
// Create cache with small window for testing
1447+
const window_config = SlidingWindowConfig{
1448+
.window_size = 16,
1449+
.sink_tokens = 2, // Keep first 2 tokens
1450+
.local_tokens = 6, // Keep last 6 tokens
1451+
};
1452+
1453+
// presumably (allocator, num_heads=1, head_dim=4, max_seq_len=16, config) —
// verify against RingKVCache.init
var cache = try RingKVCache.init(allocator, 1, 4, 16, window_config);
1454+
defer cache.deinit();
1455+
1456+
// Add 20 tokens (exceeds window).
// Key for token i is {i, 0, 0, 0}; every value is {1, 0, 0, 0}.
1457+
for (0..20) |i| {
1458+
var k = [_]f32{ @floatFromInt(i), 0, 0, 0 };
1459+
var v = [_]f32{ 1, 0, 0, 0 };
1460+
cache.append(&k, &v);
1461+
}
1462+
1463+
// Check window membership
1464+
// Sink tokens (0, 1) should be in window
1465+
try std.testing.expect(cache.isInWindow(0));
1466+
try std.testing.expect(cache.isInWindow(1));
1467+
1468+
// Middle tokens should be evicted
1469+
try std.testing.expect(!cache.isInWindow(5));
1470+
try std.testing.expect(!cache.isInWindow(10));
1471+
1472+
// Recent tokens (14-19) should be in window
1473+
try std.testing.expect(cache.isInWindow(14));
1474+
try std.testing.expect(cache.isInWindow(19));
1475+
1476+
// Test streaming attention.
// scores has 16 entries; assumes cache.seqLen() <= 16 here — TODO confirm.
1477+
const query = [_]f32{ 1, 0, 0, 0 };
1478+
var output: [4]f32 = undefined;
1479+
var scores: [16]f32 = undefined;
1480+
1481+
streamingAttention(&output, &query, &cache, 0, &scores, 1.0);
1482+
1483+
// Output should be non-zero (attention computed): every V has 1 in
// component 0, so any attended slot contributes to output[0].
1484+
try std.testing.expect(output[0] != 0.0);
1485+
}
1486+
1487+
// Appends 500 tokens to a cache created with size 100, checks the counters
// reported by CompressionStats.fromCache, and prints the stats as a table.
test "compression_stats" {
1488+
const allocator = std.testing.allocator;
1489+
1490+
const window_config = SlidingWindowConfig{
1491+
.window_size = 100,
1492+
.sink_tokens = 4,
1493+
.local_tokens = 96,
1494+
};
1495+
1496+
// presumably (allocator, num_heads=4, head_dim=64, max_seq_len=100, config) —
// verify against RingKVCache.init
var cache = try RingKVCache.init(allocator, 4, 64, 100, window_config);
1497+
defer cache.deinit();
1498+
1499+
// Simulate long sequence.
// Each buffer holds 256 floats (= 4 * 64, matching the init arguments above).
1500+
var k_buf: [256]f32 = undefined;
1501+
var v_buf: [256]f32 = undefined;
1502+
@memset(&k_buf, 0.1);
1503+
@memset(&v_buf, 0.2);
1504+
1505+
for (0..500) |_| {
1506+
cache.append(&k_buf, &v_buf);
1507+
}
1508+
1509+
const stats = CompressionStats.fromCache(&cache);
1510+
1511+
try std.testing.expectEqual(@as(usize, 500), stats.total_tokens_seen);
1512+
try std.testing.expectEqual(@as(usize, 100), stats.tokens_in_cache);
1513+
try std.testing.expectEqual(@as(usize, 400), stats.evicted_tokens);
1514+
try std.testing.expect(stats.compression_ratio >= 4.9); // 500/100 = 5x
1515+
1516+
// Informational output only; nothing below is asserted.
std.debug.print("\n╔══════════════════════════════════════════════════════════════╗\n", .{});
1517+
std.debug.print("║ KV CACHE COMPRESSION STATS ║\n", .{});
1518+
std.debug.print("╠══════════════════════════════════════════════════════════════╣\n", .{});
1519+
std.debug.print("║ Total tokens seen: {d:>10} ║\n", .{stats.total_tokens_seen});
1520+
std.debug.print("║ Tokens in cache: {d:>10} ║\n", .{stats.tokens_in_cache});
1521+
std.debug.print("║ Evicted tokens: {d:>10} ║\n", .{stats.evicted_tokens});
1522+
std.debug.print("║ Compression ratio: {d:>10.1}x ║\n", .{stats.compression_ratio});
1523+
std.debug.print("║ Effective context: {d:>10} ║\n", .{stats.effective_context});
1524+
std.debug.print("║ Memory saved: {d:>10} bytes ║\n", .{stats.memory_saved_bytes});
1525+
std.debug.print("╚══════════════════════════════════════════════════════════════╝\n", .{});
1526+
}

src/vibeec/tri_inference.zig

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -757,27 +757,18 @@ pub const BatchTriModel = struct {
757757
for (0..num_heads) |h| {
758758
const kv_h = h / kv_group_size;
759759
const q_head = model.buf_q[h * head_dim ..][0..head_dim];
760-
761-
// Compute attention scores using RingKVCache
762-
for (0..seq_len) |t| {
763-
const k_vec = cache.getK(t, kv_h);
764-
model.buf_scores[t] = flash.simdDot(q_head, k_vec) * scale;
765-
}
766-
767-
// Softmax
768-
inference.softmax(model.buf_scores[0..seq_len], model.buf_scores[0..seq_len]);
769-
770-
// Weighted sum
771760
const out_head = model.buf_attn_out[h * head_dim ..][0..head_dim];
772-
@memset(out_head, 0.0);
773761

774-
for (0..seq_len) |t| {
775-
const v_vec = cache.getV(t, kv_h);
776-
const score_val = model.buf_scores[t];
777-
for (0..head_dim) |j| {
778-
out_head[j] += score_val * v_vec[j];
779-
}
780-
}
762+
// Use streaming attention with sliding window mask
763+
// This enables infinite context with fixed memory
764+
kv_cache.streamingAttention(
765+
out_head,
766+
q_head,
767+
cache,
768+
kv_h,
769+
model.buf_scores[0..seq_len],
770+
scale,
771+
);
781772
}
782773

783774
// Output projection

0 commit comments

Comments
 (0)