Skip to content

Commit d1394dd

Browse files
gHashTagona-agent
andcommitted
feat(embeddings): implement ternary embeddings (OPT-T05)
- Add TernaryEmbedding struct with per-token scales - Implement initFromF32() for f32 → ternary conversion - Add lookup() and lookupSIMD() for dequantization - Integrate into TriModel with enableTernaryEmbeddings() - Update validate_ternary.zig to test embeddings Results: - Embedding compression: 12.8x (8192 → 640 bytes) - Combined similarity: 0.88 (embeddings + KV cache) - All 6 ternary_weights tests passing Co-authored-by: Ona <no-reply@ona.com>
1 parent e72d1e8 commit d1394dd

5 files changed

Lines changed: 465 additions & 8 deletions

File tree

docs/DISCOVERIES.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ Where:
7676
| OPT-T02 | Ternary Matrix Multiplication | N/A | 10x | ✅ Implemented |
7777
| OPT-T03 | Ternary KV Cache | 16x | 1.5x | ✅ Implemented |
7878
| OPT-T04 | Ternary Attention | 16x | 1.5x | ✅ Implemented |
79-
| OPT-T05 | Ternary Embeddings | 20x | 2x | 📋 Planned |
79+
| OPT-T05 | Ternary Embeddings | 12.8x | 1x | ✅ Implemented |
8080
| OPT-T06 | Ternary Normalization | 20x | 3x | 📋 Planned |
8181

8282
### Business Value
@@ -281,6 +281,29 @@ const logits = try model.forward(token_id, position);
281281

282282
**Key insight:** Using RMS (root mean square) for scale instead of max preserves more information about value distribution. The threshold is set to 0.5 * RMS, which better separates signal from noise.
283283

284+
### Ternary Embeddings (OPT-T05)
285+
286+
**Status**: ✅ Implemented
287+
288+
| Component | File | Description |
289+
|-----------|------|-------------|
290+
| TernaryEmbedding | `ternary_weights.zig` | Ternary embedding table |
291+
| initFromF32 | `ternary_weights.zig` | Convert f32 → ternary |
292+
| lookup | `ternary_weights.zig` | Scalar dequantization |
293+
| lookupSIMD | `ternary_weights.zig` | SIMD-optimized lookup |
294+
295+
**Memory Savings:**
296+
```
297+
f32 embeddings: 8,192 bytes (32 vocab × 64 hidden × 4)
298+
Ternary embeddings: 640 bytes (32 vocab × (64/4 + 4))
299+
Compression: 12.8x
300+
```
301+
302+
**Combined Ternary Pipeline:**
303+
- Ternary embeddings: 12.8x compression
304+
- Ternary KV cache: 12.8x compression
305+
- Combined similarity: 0.88 (vs 0.93 with only KV cache)
306+
284307
### Test Results
285308

286309
```

specs/tri/ternary_embeddings.vibee

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Ternary Embeddings Specification
2+
# 16x memory reduction for token embeddings
3+
# φ² + 1/φ² = 3 | KOSCHEI IS IMMORTAL
4+
5+
name: ternary_embeddings
6+
version: "1.0.0"
7+
language: zig
8+
module: ternary_embeddings
9+
10+
description: |
11+
Ternary token embeddings using 2-bit quantization.
12+
Each embedding vector quantized to {-1, 0, +1} with per-token scale.
13+
16x memory reduction: vocab_size * hidden_size * 4 bytes → / 16.
14+
15+
types:
16+
TernaryEmbedding:
17+
description: "Ternary-quantized embedding table"
18+
fields:
19+
data: List<Int>
20+
scales: List<Float>
21+
vocab_size: Int
22+
hidden_size: Int
23+
24+
EmbeddingStats:
25+
description: "Memory usage statistics"
26+
fields:
27+
f32_bytes: Int
28+
ternary_bytes: Int
29+
compression_ratio: Float
30+
31+
behaviors:
32+
- name: init_from_f32
33+
given: f32 embedding table
34+
when: Converting to ternary
35+
then: Quantize each row with per-token scale
36+
37+
- name: lookup
38+
given: Token ID
39+
when: Getting embedding vector
40+
then: Dequantize on-the-fly and return f32 vector
41+
42+
- name: lookup_batch
43+
given: Array of token IDs
44+
when: Getting multiple embeddings
45+
then: Batch dequantization for efficiency
46+
47+
- name: compute_stats
48+
given: Embedding dimensions
49+
when: Analyzing memory usage
50+
then: Return compression ratio
51+
52+
quantization:
53+
method: rms_scale
54+
description: |
55+
For each embedding row:
56+
1. Compute RMS = sqrt(sum(x^2) / n)
57+
2. Scale = RMS * 1.5
58+
3. Threshold = RMS * 0.5
59+
4. Quantize: +1 if x > threshold, -1 if x < -threshold, else 0
60+
5. Pack 4 trits per byte
61+
62+
memory_analysis:
63+
f32_embedding:
64+
formula: "vocab_size * hidden_size * 4 bytes"
65+
example: "32000 * 4096 * 4 = 512 MB"
66+
67+
ternary_embedding:
68+
formula: "vocab_size * (hidden_size / 4 + 4) bytes"
69+
example: "32000 * (4096 / 4 + 4) = 32.5 MB"
70+
71+
compression: "~16x"
72+
73+
integration:
74+
- target: tri_inference.zig
75+
description: "Optional ternary embedding mode"
76+
77+
- target: TriModel
78+
description: "Add enableTernaryEmbeddings() method"

0 commit comments

Comments
 (0)