Skip to content

Commit 449ac42

Browse files
gHashTagona-agent
andcommitted
feat(norm): implement ternary normalization (OPT-T06)
- Add TernaryNormWeights struct with 2-bit packed format - Implement quantizeToTernary for f32 → ternary conversion - Add ternaryRmsNorm and simdTernaryRmsNorm functions - Integrate into tri_inference.zig with enableTernaryNorm() - Memory savings: 16x (f32 → 2-bit per weight) - Speed: 0.2x (trades speed for memory) - Accuracy: <10% max relative error Co-authored-by: Ona <no-reply@ona.com>
1 parent 8928ae1 commit 449ac42

4 files changed

Lines changed: 496 additions & 13 deletions

File tree

docs/DISCOVERIES.md

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Where:
7777
| OPT-T03 | Ternary KV Cache | 16x | 1.5x | ✅ Implemented |
7878
| OPT-T04 | Ternary Attention | 16x | 1.5x | ✅ Implemented |
7979
| OPT-T05 | Ternary Embeddings | 12.8x | 1x | ✅ Implemented |
80-
| OPT-T06 | Ternary Normalization | 20x | 3x | 📋 Planned |
80+
| OPT-T06 | Ternary Normalization | 16x | 0.2x | ✅ Implemented |
8181

8282
### Business Value
8383

@@ -304,6 +304,52 @@ Compression: 12.8x
304304
- Ternary KV cache: 12.8x compression
305305
- Combined similarity: 0.88 (vs 0.93 with only KV cache)
306306

307+
### Ternary Normalization (OPT-T06)
308+
309+
**Status**: ✅ Implemented
310+
311+
| Component | File | Description |
312+
|-----------|------|-------------|
313+
| TernaryNormWeights | `simd_matmul.zig` | Packed ternary norm weights |
314+
| quantizeToTernary | `simd_matmul.zig` | Convert f32 → ternary |
315+
| ternaryRmsNorm | `simd_matmul.zig` | Scalar ternary RMSNorm |
316+
| simdTernaryRmsNorm | `simd_matmul.zig` | SIMD-optimized version |
317+
| enableTernaryNorm | `tri_inference.zig` | Enable for all layers |
318+
319+
**Memory Savings:**
320+
```
321+
f32 norm weights: hidden_size × 4 bytes
322+
Ternary norm weights: hidden_size / 4 bytes (2 bits per weight)
323+
Compression: 16x
324+
```
325+
326+
**Benchmark Results (hidden_size=2048, 10K iterations):**
327+
```
328+
╔══════════════════════════════════════════════════════════════╗
329+
║ TERNARY NORM BENCHMARK ║
330+
╠══════════════════════════════════════════════════════════════╣
331+
║ f32 RMSNorm: 617.6 ns/iter ║
332+
║ Ternary RMSNorm: 3040.3 ns/iter ║
333+
║ Speedup: 0.20x (slower) ║
334+
║ Memory savings: 16x ║
335+
╚══════════════════════════════════════════════════════════════╝
336+
```
337+
338+
**Key Insight:** Ternary normalization trades speed for memory. The unpacking overhead makes it ~5x slower than f32, but provides 16x memory reduction. This is useful for:
339+
- Memory-constrained devices (mobile, edge)
340+
- Large models where norm weights are significant
341+
- Scenarios where memory bandwidth is the bottleneck
342+
343+
**Accuracy:**
344+
- Max relative error: <10% (acceptable for inference)
345+
- Similar to INT8 quantization error margins
346+
347+
**Usage:**
348+
```zig
349+
var model = try TriModel.load(allocator, "model.tri");
350+
try model.enableTernaryNorm(); // 16x memory reduction for norm weights
351+
```
352+
307353
### Batch Processing (INF-004)
308354

309355
**Status**: ✅ Implemented
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# ternary_normalization.vibee
2+
# TernaryNorm: RMSNorm with ternary-quantized weights
3+
# Reduces weight memory by 16x (f32 -> 2-bit)
4+
5+
name: ternary_normalization
6+
version: "1.0.0"
7+
language: zig
8+
module: ternary_normalization
9+
10+
types:
11+
TernaryNormWeights:
12+
description: "Ternary-quantized normalization weights"
13+
fields:
14+
packed_ternary: List<u8> # 4 ternary values per byte
15+
scale: Float # Scale factor for reconstruction
16+
size: Int # Original weight count
17+
18+
NormConfig:
19+
description: "Normalization configuration"
20+
fields:
21+
eps: Float # Epsilon for numerical stability (default 1e-5)
22+
use_simd: Bool # Enable SIMD optimization
23+
24+
behaviors:
25+
- name: quantize_norm_weights
26+
given: f32 normalization weights array
27+
when: quantizing to ternary format
28+
then: returns TernaryNormWeights with packed ternary values and scale
29+
30+
- name: ternary_rms_norm
31+
given: input tensor, TernaryNormWeights, epsilon
32+
when: applying RMS normalization with ternary weights
33+
then: returns normalized output with ternary weight multiplication
34+
35+
- name: simd_ternary_rms_norm
36+
given: input tensor, TernaryNormWeights, epsilon
37+
when: applying SIMD-optimized RMS normalization
38+
then: returns normalized output using SIMD for sum-of-squares and ternary multiply
39+
40+
- name: unpack_ternary_weight
41+
given: packed byte, position (0-3)
42+
when: extracting single ternary value
43+
then: returns -1, 0, or +1
44+
45+
- name: ternary_multiply_add
46+
given: input value, ternary value (-1/0/+1), scale
47+
when: multiplying by ternary weight
48+
then: returns input * (ternary * scale) without actual multiplication
49+
50+
# Algorithm:
51+
# 1. RMS = sqrt(mean(x^2) + eps)
52+
# 2. x_norm = x / RMS
53+
# 3. output = x_norm * (ternary_weight * scale)
54+
#
55+
# Ternary multiply optimization:
56+
# - ternary = +1: output = x_norm * scale
57+
# - ternary = 0: output = 0
58+
# - ternary = -1: output = -x_norm * scale
59+
#
60+
# Memory savings:
61+
# - f32 weights: 4 bytes per weight
62+
# - ternary: 2 bits per weight = 0.25 bytes
63+
# - Compression: 16x
64+
65+
# Packing format:
66+
# Each byte stores 4 ternary values:
67+
# bits [1:0] = value 0 (00=-1, 01=0, 10=+1)
68+
# bits [3:2] = value 1
69+
# bits [5:4] = value 2
70+
# bits [7:6] = value 3

0 commit comments

Comments
 (0)