Commit f4ba28a
committed
Add split-K mid-M SDPA Triton kernel for EAGLE-3 target_verify at long context
EAGLE-3 verify is a target forward over M = chain+1 query rows. On gemma4-31B's
full-attention (global) layers the standard SDPA scans the whole max_seq_len KV
buffer on a (B, H) grid -- one CTA per head looping the key range serially -- so
at long context the verify attention is occupancy-starved and grows ~linearly
with context, dominating the round and turning speculative decoding into a net
loss; the M query rows otherwise ride along for free on the same K/V read.
This adds a length-bounded split-K mid-M SDPA path for that case. The Triton
kernel (backends/cuda/triton/kernels/sdpa_midm.py) bounds the key range to the
valid length and partitions it across CTAs with a split-K online-softmax plus
cross-split reduce (the flash-decoding trick), with sdpa.py-style guards for
tiles a row's causal mask empties. Gemma4_31B gains opt-in dispatch
(set_midm_sdpa): full-attention layers route verify windows with M in
[2, MIDM_MAX_M] through the kernel, while sliding-window, prefill, decode, and
other models stay on F.sdpa. The valid KV length reaches the kernel as the
length of a new target_verify kv_window input (a backed SymInt); export wires it
up behind --no-midm-sdpa and the runner feeds it each round. Verify global
attention then stays ~flat with context instead of growing.
Because kv_window's shape changes every round, target_verify can no longer be
captured as a CUDA graph, so the runner's --cuda_graph now defaults off.
Lossless: byte-identical to baseline greedy except rare near-tie argmax flips
(M=chain+1 verify vs M=1 decode FP non-associativity; the same prompts flip
without this kernel). Unit coverage in backends/cuda/tests/test_sdpa_midm.py.
Benchmarks need the 31B checkpoints + A100 + a long-context export, so they run
out of CI and are not kept in this message.
Authored with assistance from Claude Code.
ghstack-source-id: 9118464
ghstack-comment-id: 4734204816
Pull-Request: #203441 parent 1833677 commit f4ba28a
8 files changed
Lines changed: 774 additions & 81 deletions
File tree
- backends/cuda
- tests
- triton/kernels
- examples/models
- eagle3
- gemma4_31b
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
| 20 | + | |
20 | 21 | | |
21 | 22 | | |
22 | 23 | | |
| |||
29 | 30 | | |
30 | 31 | | |
31 | 32 | | |
| 33 | + | |
32 | 34 | | |
33 | 35 | | |
34 | 36 | | |
| |||
0 commit comments