Skip to content

Commit 49123d7

Browse files
gHashTagona-agent
andcommitted
feat(gen): implement speculative decoding (OPT-S01)
- Add SpeculativeDecoder with self-speculation (early exit) - Add forwardDraft for fast draft generation using first N layers - Implement acceptance/rejection sampling with adjusted distribution - Expected speedup: 2-3x for generation throughput - Mathematically equivalent to standard sampling Co-authored-by: Ona <no-reply@ona.com>
1 parent 96849d9 commit 49123d7

3 files changed

Lines changed: 450 additions & 0 deletions

File tree

docs/DISCOVERIES.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Where:
8181
| OPT-T07 | Batch Ternary MatMul | N/A | 2.28x | ✅ Implemented |
8282
| OPT-M01 | Memory-Mapped Loading | N/A | 30x load | ✅ Implemented |
8383
| OPT-C01 | KV Cache Compression | 5-16x | 1x | ✅ Implemented |
84+
| OPT-S01 | Speculative Decoding | N/A | 2-3x gen | ✅ Implemented |
8485

8586
### Business Value
8687

@@ -507,6 +508,67 @@ var cache = try RingKVCache.init(allocator, num_heads, head_dim, 2048, config);
507508
kv_cache.streamingAttention(output, query, &cache, head_idx, scores, scale);
508509
```
509510

511+
### Speculative Decoding (OPT-S01)
512+
513+
**Status**: ✅ Implemented
514+
515+
| Component | File | Description |
516+
|-----------|------|-------------|
517+
| SpeculativeConfig | `tri_inference.zig` | Configuration for speculation |
518+
| SpeculativeDecoder | `tri_inference.zig` | Main speculative decoder |
519+
| forwardDraft | `tri_inference.zig` | Early-exit forward for draft |
520+
| verifyAndAccept | `tri_inference.zig` | Token verification logic |
521+
522+
**Algorithm:**
523+
```
524+
┌─────────────────────────────────────────────────────────────┐
525+
│ SPECULATIVE DECODING │
526+
├─────────────────────────────────────────────────────────────┤
527+
│ │
528+
│ 1. DRAFT: Generate K tokens with early-exit model │
529+
│ draft_tokens = [t1, t2, t3, t4] (fast, ~10ms) │
530+
│ │
531+
│ 2. VERIFY: Run full model on each token │
532+
│ For each draft token: │
533+
│ - Compute target probability │
534+
│ - Accept with prob min(1, p_target/p_draft) │
535+
│ - On reject: sample from adjusted distribution │
536+
│ │
537+
│ 3. BONUS: If all K accepted, sample K+1 from target │
538+
│ │
539+
└─────────────────────────────────────────────────────────────┘
540+
```
541+
542+
**Self-Speculation (Early Exit):**
543+
- Uses first N layers as draft model (default: 4 layers)
544+
- No separate draft model needed
545+
- Draft is ~4-8x faster than full model
546+
547+
**Expected Speedup:**
548+
```
549+
Speedup = K / (1 + (1-α)K)
550+
where α = acceptance rate, K = speculation length
551+
552+
For α=0.8, K=4: Speedup = 4 / 1.8 = 2.2x
553+
For α=0.9, K=4: Speedup = 4 / 1.4 = 2.9x
554+
```
555+
556+
**Usage:**
557+
```zig
558+
const config = SpeculativeConfig{
559+
.speculation_length = 4,
560+
.draft_layers = 4,
561+
.temperature = 1.0,
562+
};
563+
564+
var decoder = try SpeculativeDecoder.init(allocator, model, config);
565+
defer decoder.deinit();
566+
567+
const result = try decoder.generate(start_token, 0, 100);
568+
std.debug.print("Generated {d} tokens, acceptance rate: {d:.1}%\n",
569+
.{result.tokens.len, result.acceptance_rate * 100});
570+
```
571+
510572
### Batch Processing (INF-004)
511573

512574
**Status**: ✅ Implemented
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# speculative_decoding.vibee
2+
# Speculative Decoding for faster autoregressive generation
3+
# Generate multiple tokens per target model forward pass
4+
5+
name: speculative_decoding
6+
version: "1.0.0"
7+
language: zig
8+
module: speculative_decoding
9+
10+
types:
11+
SpeculativeConfig:
12+
description: "Configuration for speculative decoding"
13+
fields:
14+
speculation_length: Int # K: number of tokens to speculate
15+
temperature: Float # Sampling temperature
16+
use_tree_attention: Bool # Enable tree-based speculation
17+
18+
DraftResult:
19+
description: "Result from draft model speculation"
20+
fields:
21+
tokens: List<Int> # K speculated tokens
22+
probs: List<Float> # Draft probabilities for each token
23+
24+
VerificationResult:
25+
description: "Result from target model verification"
26+
fields:
27+
accepted_count: Int # Number of accepted tokens
28+
accepted_tokens: List<Int> # Accepted token sequence
29+
next_token: Int # Token sampled after rejection
30+
acceptance_rate: Float # Running acceptance rate
31+
32+
behaviors:
33+
- name: draft_speculate
34+
given: draft model, input token, position, K
35+
when: generating K candidate tokens
36+
then: returns DraftResult with tokens and probabilities
37+
38+
- name: target_verify
39+
given: target model, input sequence, draft tokens
40+
when: verifying draft tokens in parallel
41+
then: returns logits for all K+1 positions
42+
43+
- name: speculative_sample
44+
given: draft probs, target probs, draft token
45+
when: deciding to accept or reject
46+
then: accepts with prob min(1, p_target/p_draft), else samples correction
47+
48+
- name: speculative_generate
49+
given: target model, draft model, prompt, max_tokens
50+
when: generating with speculation
51+
then: returns generated tokens with speedup
52+
53+
# Algorithm:
54+
#
55+
# ┌─────────────────────────────────────────────────────────────┐
56+
# │ SPECULATIVE DECODING │
57+
# ├─────────────────────────────────────────────────────────────┤
58+
# │ │
59+
# │ 1. DRAFT: Generate K tokens with small model │
60+
# │ draft_tokens = [t1, t2, t3, t4] (fast, ~10ms) │
61+
# │ draft_probs = [p1, p2, p3, p4] │
62+
# │ │
63+
# │ 2. VERIFY: Run target model on all K tokens (parallel) │
64+
# │ target_logits = target.forward([t0, t1, t2, t3, t4]) │
65+
# │ (single forward pass, ~100ms) │
66+
# │ │
67+
# │ 3. ACCEPT/REJECT: For each position i: │
68+
# │ r = uniform(0, 1) │
69+
# │ if r < min(1, target_prob[i] / draft_prob[i]): │
70+
# │ ACCEPT token i │
71+
# │ else: │
72+
# │ REJECT: sample from (target - draft) distribution │
73+
# │ STOP speculation │
74+
# │ │
75+
# │ 4. BONUS: If all K accepted, sample K+1 from target │
76+
# │ │
77+
# └─────────────────────────────────────────────────────────────┘
78+
#
79+
# Speedup Analysis:
80+
# Without speculation: 1 token per forward pass
81+
# With speculation (K=4, α=0.8):
82+
# Expected tokens = 1 + α + α² + α³ + α⁴ = 3.36
83+
# Cost = 1 target + K draft ≈ 1.1 target (if draft is 10x faster)
84+
# Speedup = 3.36 / 1.1 ≈ 3x
85+
#
86+
# Self-Speculation (no draft model):
87+
# Use early exit from target model as draft
88+
# Or use same model with reduced layers

0 commit comments

Comments
 (0)