Skip to content

Commit 0d143c4

Browse files
gHashTagona-agent
andcommitted
feat(kv-cache): implement ternary KV cache (OPT-T03)
- Add TernaryKVCache with 2-bit quantization (16x compression) - Implement quantizeVector/dequantizeV with per-token scales - Add ternaryDot and simdTernaryDot for efficient attention - Memory savings: 8 MB → 0.5 MB for 2048 tokens - All 9 KV cache tests passing Benchmark results: - 4 heads, 128 dim, 2048 tokens: 15.5x compression - 8 heads, 128 dim, 4096 tokens: 15.8x compression Co-authored-by: Ona <no-reply@ona.com>
1 parent 877ef09 commit 0d143c4

4 files changed

Lines changed: 543 additions & 1 deletion

File tree

docs/DISCOVERIES.md

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ Where:
7474
|----|--------------|-------------|---------|--------|
7575
| OPT-T01 | Ternary Weight Quantization | 20x | 10x | ✅ Implemented |
7676
| OPT-T02 | Ternary Matrix Multiplication | N/A | 10x | ✅ Implemented |
77-
| OPT-T03 | Ternary KV Cache | 20x | 5x | 📋 Planned |
77+
| OPT-T03 | Ternary KV Cache | 16x | 1.5x | ✅ Implemented |
7878
| OPT-T04 | Ternary Attention | 20x | 5-10x | 📋 Planned |
7979
| OPT-T05 | Ternary Embeddings | 20x | 2x | 📋 Planned |
8080
| OPT-T06 | Ternary Normalization | 20x | 3x | 📋 Planned |
@@ -216,6 +216,66 @@ Where:
216216

217217
---
218218

219+
## Ternary KV Cache (OPT-T03)
220+
221+
**Status**: ✅ Implemented
222+
223+
### Implementation Details
224+
225+
| Component | File | Description |
226+
|-----------|------|-------------|
227+
| TernaryKVCache | `kv_cache.zig` | 2-bit quantized KV storage |
228+
| quantizeVector | `kv_cache.zig` | f32 → ternary with scale |
229+
| dequantizeV | `kv_cache.zig` | ternary → f32 for output |
230+
| ternaryDot | `kv_cache.zig` | Scalar ternary dot product |
231+
| simdTernaryDot | `kv_cache.zig` | SIMD-optimized (8 values/iter) |
232+
233+
### Memory Analysis
234+
235+
| KV Heads | Head Dim | Tokens | f32 (MB) | Ternary (MB) | Ratio |
236+
|----------|----------|--------|----------|--------------|-------|
237+
| 4 | 64 | 512 | 1.00 | 0.07 | 15.1x |
238+
| 4 | 128 | 2048 | 8.00 | 0.52 | 15.5x |
239+
| 8 | 128 | 4096 | 32.00 | 2.03 | 15.8x |
240+
241+
### Quantization Algorithm
242+
243+
```
244+
For each K/V vector:
245+
1. scale = max(abs(vector))
246+
2. threshold = scale * 0.3
247+
3. For each value:
248+
- if value > threshold: trit = +1
249+
- if value < -threshold: trit = -1
250+
- else: trit = 0
251+
4. Pack 4 trits per byte
252+
5. Store scale for dequantization
253+
```
254+
255+
### SIMD Ternary Dot Product
256+
257+
```zig
258+
// Sign lookup table
259+
const sign_lut = [4]f32{ 0.0, 1.0, -1.0, 0.0 };
260+
261+
// Process 8 values at a time
262+
const signs: Vec8 = .{
263+
sign_lut[(b0 >> 0) & 0x3],
264+
sign_lut[(b0 >> 2) & 0x3],
265+
// ... 8 total
266+
};
267+
sum_vec += q_vec * signs;
268+
```
269+
270+
### Benefits
271+
272+
- **16x memory reduction**: 4 bytes → 0.25 bytes per value
273+
- **16x longer context**: Same memory budget, 16x more tokens
274+
- **No multiplications**: Ternary dot product uses only add/sub
275+
- **SIMD friendly**: Sign lookup table enables vectorization
276+
277+
---
278+
219279
## Flash Attention (OPT-004)
220280

221281
**Status**: ✅ Implemented

specs/tri/ternary_kv_cache.vibee

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Ternary KV Cache Specification
2+
# 16x memory reduction via 2-bit quantization
3+
# φ² + 1/φ² = 3 | KOSCHEI IS IMMORTAL
4+
5+
name: ternary_kv_cache
6+
version: "1.0.0"
7+
language: zig
8+
module: ternary_kv_cache
9+
10+
description: |
11+
Ternary KV cache stores K,V vectors in 2-bit format.
12+
Each value quantized to {-1, 0, +1} with scale factor.
13+
16x memory reduction: 4 bytes (f32) → 0.25 bytes (2-bit).
14+
Enables 16x longer context with same memory budget.
15+
16+
types:
17+
TernaryKVCache:
18+
description: "KV cache with ternary quantization"
19+
fields:
20+
k_cache: List<Int>
21+
v_cache: List<Int>
22+
k_scales: List<Float>
23+
v_scales: List<Float>
24+
num_kv_heads: Int
25+
head_dim: Int
26+
max_seq_len: Int
27+
seq_len: Int
28+
29+
QuantizedVector:
30+
description: "Ternary-quantized vector with scale"
31+
fields:
32+
data: List<Int>
33+
scale: Float
34+
length: Int
35+
36+
CacheMemoryStats:
37+
description: "Memory comparison stats"
38+
fields:
39+
f32_bytes: Int
40+
ternary_bytes: Int
41+
compression_ratio: Float
42+
tokens_capacity: Int
43+
44+
behaviors:
45+
- name: quantize_vector
46+
given: f32 vector and threshold
47+
when: Storing K or V in cache
48+
then: Returns packed ternary bytes + scale factor
49+
50+
- name: dequantize_vector
51+
given: Packed ternary bytes and scale
52+
when: Reading K or V for attention
53+
then: Returns approximate f32 vector
54+
55+
- name: ternary_append
56+
given: New K,V vectors (f32)
57+
when: Adding token to cache
58+
then: Quantize and store with per-token scales
59+
60+
- name: ternary_dot_product
61+
given: f32 query and ternary key
62+
when: Computing attention score
63+
then: Efficient dot product without full dequantization
64+
65+
- name: ternary_weighted_sum
66+
given: Attention weights and ternary values
67+
when: Computing attention output
68+
then: Weighted sum with on-the-fly dequantization
69+
70+
- name: compute_memory_stats
71+
given: Cache configuration
72+
when: Analyzing memory usage
73+
then: Returns f32 vs ternary comparison
74+
75+
quantization_algorithm:
76+
description: |
77+
For each vector:
78+
1. Compute scale = max(abs(vector))
79+
2. Normalize: v_norm = vector / scale
80+
3. Quantize: trit = sign(v_norm) if abs(v_norm) > threshold else 0
81+
4. Pack: 4 trits per byte
82+
83+
Dequantize:
84+
1. Unpack trits from bytes
85+
2. Multiply by scale: value = trit * scale
86+
87+
memory_analysis:
88+
f32_cache:
89+
per_token: "num_kv_heads * head_dim * 4 bytes * 2 (K+V)"
90+
example: "4 heads * 128 dim * 4 * 2 = 4096 bytes/token"
91+
92+
ternary_cache:
93+
per_token: "num_kv_heads * head_dim / 4 bytes * 2 + scales"
94+
example: "4 heads * 128 dim / 4 * 2 + 8 = 264 bytes/token"
95+
96+
compression: "4096 / 264 = 15.5x"
97+
98+
accuracy_considerations:
99+
- name: scale_per_token
100+
description: "Each token has own scale for K and V"
101+
102+
- name: threshold_tuning
103+
description: "Threshold affects sparsity vs accuracy"
104+
105+
- name: attention_approximation
106+
description: "Ternary dot product is approximate but fast"
107+
108+
benchmarks:
109+
- name: memory_reduction
110+
metric: "ratio"
111+
target: "~16x"
112+
113+
- name: accuracy_loss
114+
metric: "cosine similarity"
115+
target: ">0.95"
116+
117+
- name: attention_speedup
118+
metric: "ratio"
119+
target: "1.5-2x (no multiplications)"
120+
121+
integration:
122+
- target: kv_cache.zig
123+
description: "Add TernaryRingKVCache alongside RingKVCache"
124+
125+
- target: tri_inference.zig
126+
description: "Option to use ternary KV cache"

src/vibeec/gguf_transformer.zig

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ pub const RingKVCache = kv_cache_mod.RingKVCache;
1212
pub const SlidingWindowConfig = kv_cache_mod.SlidingWindowConfig;
1313
pub const CacheStats = kv_cache_mod.CacheStats;
1414

15+
// Re-export ternary KV cache (OPT-T03)
16+
pub const TernaryKVCache = kv_cache_mod.TernaryKVCache;
17+
pub const TernaryCacheStats = kv_cache_mod.TernaryCacheStats;
18+
1519
// ═══════════════════════════════════════════════════════════════════════════════
1620
// RoPE - Rotary Position Embedding
1721
// ═══════════════════════════════════════════════════════════════════════════════

0 commit comments

Comments
 (0)