Commit 877ef09

gHashTag and ona-agent committed
docs: add Flash Attention (OPT-004) documentation
- Document OnlineSoftmaxState algorithm
- Add memory analysis (O(n) vs O(n²))
- Include benchmark results (1.1-1.2x speedup on CPU)
- Note: main benefit is memory reduction, not speed on CPU

Co-authored-by: Ona <no-reply@ona.com>
1 parent 6f08e84 commit 877ef09

1 file changed

Lines changed: 60 additions & 1 deletion

docs/DISCOVERIES.md

@@ -206,7 +206,7 @@ Where:
 - [x] INF-003: KV Cache Optimization (+50% speed) ✅ Implemented
 - [ ] INF-004: Batch Processing (+300% throughput)
 - [ ] OPT-001: SIMD Vectorization (+400% matrix ops)
-- [ ] OPT-004: Flash Attention (+200% attention)
+- [x] OPT-004: Flash Attention (+10-20% attention, O(n) memory) ✅ Implemented

 ### Locked (Future)

@@ -216,6 +216,65 @@ Where:

---

## Flash Attention (OPT-004)

**Status**: ✅ Implemented

### Implementation Details

| Component | File | Description |
|-----------|------|-------------|
| OnlineSoftmaxState | `flash_attention.zig` | Incremental softmax without full matrix |
| simdDot | `flash_attention.zig` | SIMD-accelerated dot product |
| flashAttentionHead | `flash_attention.zig` | Single head with tiling |
| flashAttentionGQA | `flash_attention.zig` | Multi-head with GQA support |
| standardAttention | `flash_attention.zig` | Baseline for comparison |
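
The table mentions GQA (grouped-query attention) support without describing how query heads map to key/value heads. The sketch below shows the conventional mapping, in Python rather than Zig, and is an assumption about the general technique, not a description of `flashAttentionGQA` itself. The function name `kv_head_for` and the parameter `num_kv_heads` are hypothetical.

```python
def kv_head_for(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the KV head shared by `query_head` under grouped-query attention.

    Hypothetical helper: assumes consecutive query heads form one group and
    that num_heads is an exact multiple of num_kv_heads.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return query_head // group_size


if __name__ == "__main__":
    # 8 query heads sharing 2 KV heads: heads 0-3 use KV head 0, heads 4-7 use KV head 1.
    print([kv_head_for(h, num_heads=8, num_kv_heads=2) for h in range(8)])
```

With `num_kv_heads == num_heads` this reduces to ordinary multi-head attention, and with `num_kv_heads == 1` to multi-query attention.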

### Algorithm: Online Softmax

```
Key insight: softmax(x)_i = exp(x_i - max) / sum_j(exp(x_j - max))

Initialize: global_max = -inf, sum_exp = 0, output = 0

For each KV tile:
  1. Find block_max (largest score in the tile)
  2. new_max = max(global_max, block_max)
  3. If new_max > global_max:
     - Rescale: sum_exp *= exp(global_max - new_max)
     - Rescale: output *= exp(global_max - new_max)
  4. Accumulate: sum_exp += exp(score - new_max)
  5. Accumulate: output += exp(score - new_max) * V
  6. Update: global_max = new_max

Finalize: output /= sum_exp
```
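
The steps above can be sketched in a few lines of Python for a single query row. This is a minimal illustration of the online-softmax technique, not the Zig code in `flash_attention.zig`; the function names, the list-of-lists layout for `values`, and the default `tile_size` of 64 are assumptions made for the example. Both functions assume a non-empty list of finite scores.

```python
import math


def standard_attention(scores, values):
    """Baseline: compute every softmax weight first, then mix the value rows."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def online_softmax_attention(scores, values, tile_size=64):
    """Tile-by-tile version that keeps only a running max, sum, and output."""
    dim = len(values[0])
    global_max = -math.inf
    sum_exp = 0.0
    output = [0.0] * dim
    for start in range(0, len(scores), tile_size):
        tile_scores = scores[start:start + tile_size]
        tile_values = values[start:start + tile_size]
        new_max = max(global_max, max(tile_scores))
        if new_max > global_max:
            # Rescale earlier accumulations to the new reference maximum.
            # On the first tile this is exp(-inf) == 0.0, which is harmless
            # because sum_exp and output are still zero.
            scale = math.exp(global_max - new_max)
            sum_exp *= scale
            output = [o * scale for o in output]
        for s, v in zip(tile_scores, tile_values):
            w = math.exp(s - new_max)
            sum_exp += w
            for d in range(dim):
                output[d] += w * v[d]
        global_max = new_max
    return [o / sum_exp for o in output]


if __name__ == "__main__":
    scores = [0.5, -1.0, 3.0, 2.0, 0.0]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, 0.5]]
    print(standard_attention(scores, values))
    print(online_softmax_attention(scores, values, tile_size=2))
```

The two results should agree up to floating-point rounding, whatever tile size is chosen. Only `tile_size` scores are live at any time in the tiled version.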

### Memory Analysis

| Method | Scores Memory (per head) | Total |
|--------|---------------|-------|
| Standard | O(seq_len) | O(num_heads * seq_len) |
| Flash | O(TILE_SIZE_KV), constant in seq_len | O(num_heads * head_dim) |
| Savings | seq_len / 64 reduction | ~16x for 1024 tokens |
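
The savings row follows from simple arithmetic, which the short Python sketch below reproduces. It assumes `TILE_SIZE_KV` is 64, as the "seq_len / 64" entry implies, and counts only the score buffer for one query row of one head. The function name `score_buffer_floats` is hypothetical.

```python
def score_buffer_floats(seq_len: int, tile_size_kv: int = 64):
    """Return (standard, flash) score-buffer sizes, in floats, for one head.

    Assumption: standard attention holds one score per cached key, while the
    tiled version holds only the scores of the current KV tile.
    """
    standard = seq_len
    flash = min(seq_len, tile_size_kv)
    return standard, flash


if __name__ == "__main__":
    for n in (64, 256, 1024, 4096):
        std, fl = score_buffer_floats(n)
        print(f"seq_len={n:5d}  standard={std:5d}  flash={fl:3d}  reduction={std // fl}x")
```

For 1024 tokens this gives 1024 / 64 = 16, matching the "~16x" entry. The reduction keeps growing with sequence length because the tile buffer stays fixed.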

### Benchmark Results (32 heads, head_dim = 64)

| Seq Len | Standard (ms) | Flash (ms) | Speedup |
|---------|---------------|------------|---------|
| 32 | 0.040 | 0.035 | 1.13x |
| 64 | 0.074 | 0.068 | 1.09x |
| 128 | 0.152 | 0.138 | 1.10x |
| 256 | 0.300 | 0.278 | 1.08x |
| 512 | 0.605 | 0.544 | 1.11x |
| 1024 | 1.384 | 1.184 | 1.17x |

**Note**: On CPU the main benefit is reduced memory, not speed. GPU implementations typically see a 2-4x speedup because attention there is limited by memory bandwidth.
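
The speedup column is the ratio of the two timing columns. The Python sketch below recomputes it from the rounded timings in the table; the dictionary `timings_ms` simply restates those values, and the helper name `speedup` is hypothetical. Because the published timings are rounded to three decimals, a recomputed ratio can differ from the table in the last digit.

```python
# Timings copied from the benchmark table: seq_len -> (standard_ms, flash_ms).
timings_ms = {
    32: (0.040, 0.035),
    64: (0.074, 0.068),
    128: (0.152, 0.138),
    256: (0.300, 0.278),
    512: (0.605, 0.544),
    1024: (1.384, 1.184),
}


def speedup(standard_ms: float, flash_ms: float) -> float:
    """Speedup of the flash path relative to the standard path."""
    return standard_ms / flash_ms


if __name__ == "__main__":
    for seq_len, (std, fl) in timings_ms.items():
        print(f"seq_len={seq_len:5d}  speedup={speedup(std, fl):.2f}x")
```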

### Integration

- `tri_inference.zig`: Uses `flash.simdDot` for attention scores
- Full `flashAttentionGQA` available but not yet integrated (requires refactoring)

---

## KV Cache Optimization (INF-003)

**Status**: ✅ Implemented
