Skip to content

Commit f4ba28a

Browse files
committed
Add split-K mid-M SDPA Triton kernel for EAGLE-3 target_verify at long context
EAGLE-3 verify is a target forward over M = chain+1 query rows. On gemma4-31B's full-attention (global) layers the standard SDPA scans the whole max_seq_len KV buffer on a (B, H) grid -- one CTA per head looping the key range serially -- so at long context the verify attention is occupancy-starved and grows ~linearly with context, dominating the round and turning speculative decoding into a net loss; the M query rows otherwise ride along for free on the same K/V read. This adds a length-bounded split-K mid-M SDPA path for that case. The Triton kernel (backends/cuda/triton/kernels/sdpa_midm.py) bounds the key range to the valid length and partitions it across CTAs with a split-K online-softmax plus cross-split reduce (the flash-decoding trick), with sdpa.py-style guards for tiles that a row's causal mask leaves empty. Gemma4_31B gains opt-in dispatch (set_midm_sdpa): full-attention layers route verify windows with M in [2, MIDM_MAX_M] through the kernel, while sliding-window, prefill, decode, and other models stay on F.sdpa. The valid KV length reaches the kernel as the length of a new target_verify kv_window input (a backed SymInt); export wires it up behind --no-midm-sdpa and the runner feeds it each round. Verify global attention then stays ~flat with context instead of growing. Because kv_window's shape changes every round, target_verify can no longer be captured as a CUDA graph, so the runner's --cuda_graph now defaults off. Lossless: byte-identical to baseline greedy except rare near-tie argmax flips (M=chain+1 verify vs M=1 decode FP non-associativity; the same prompts flip without this kernel). Unit coverage in backends/cuda/tests/test_sdpa_midm.py. Benchmarks need the 31B checkpoints + A100 + a long-context export, so they run outside CI and their results are not included in this message. Authored with assistance from Claude Code. ghstack-source-id: 9118464 ghstack-comment-id: 4734204816 Pull-Request: #20344
1 parent 1833677 commit f4ba28a

8 files changed

Lines changed: 774 additions & 81 deletions

File tree

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Correctness tests (vs F.sdpa) for the mid-M flash SDPA kernel.

CUDA + Triton only. Validates the length-bounded mid-M kernel against the exact
attention the gemma4 full-attention layers compute (causal, enable_gqa, scale=1).
The kernel exists to beat a full-buffer F.sdpa when the valid length is much
smaller than max_seq_len; that speedup is not measured by the tests in this file.
"""
13+
14+
import unittest
15+
16+
import torch
17+
18+
from executorch.backends.cuda.triton.kernels.sdpa_midm import (
19+
midm_sdpa,
20+
sdpa_midm,
21+
sdpa_midm_reference,
22+
)
23+
24+
25+
def _require_cuda(tc):
    """Skip the running test case ``tc`` unless a CUDA device is available."""
    if torch.cuda.is_available():
        return
    tc.skipTest("CUDA required")
28+
29+
30+
def _rand(B, Hkv, H, M, D, S, anchor, device="cuda", dtype=torch.bfloat16):
    """Create random attention inputs for a window of ``M`` query rows.

    Returns ``(q, k, v, input_pos)`` where ``q`` is ``(B, H, M, D)``, ``k`` and
    ``v`` are ``(B, Hkv, S, D)``, and ``input_pos`` is the int64 range
    ``[anchor, anchor + M)``.
    """
    tensor_opts = {"device": device, "dtype": dtype}
    kv_shape = (B, Hkv, S, D)
    # Draw order (q, then k, then v) is kept so a fixed seed yields the same data.
    q = torch.randn(B, H, M, D, **tensor_opts)
    k = torch.randn(*kv_shape, **tensor_opts)
    v = torch.randn(*kv_shape, **tensor_opts)
    input_pos = torch.arange(anchor, anchor + M, device=device, dtype=torch.long)
    return q, k, v, input_pos
36+
37+
38+
def _rel_err(a, b):
    """Return mean ``|a - b|`` divided by mean ``|b|`` as a Python float.

    Both inputs are promoted to float32 first, and the denominator is clamped
    to at least 1e-6 so an all-zero reference does not divide by zero.
    """
    got = a.float()
    want = b.float()
    mean_abs_diff = (got - want).abs().mean()
    mean_abs_ref = want.abs().mean().clamp_min(1e-6)
    return (mean_abs_diff / mean_abs_ref).item()
42+
43+
44+
class TestMidMSDPA(unittest.TestCase):
    """Numerical checks for the mid-M SDPA kernel and its dispatch wrapper.

    Every test is skipped when CUDA is unavailable (see ``setUp``). Outputs are
    compared by mean relative error (``_rel_err``). Each parameter combination
    runs inside its own ``subTest`` so that a failure reports the exact
    parameters and does not prevent the remaining combinations from running.
    """

    def setUp(self):
        """Skip without CUDA and seed the RNG so the random inputs repeat."""
        _require_cuda(self)
        torch.manual_seed(0)

    @staticmethod
    def _causal_mask(pos, num_keys, dtype):
        """Return an ``(M, num_keys)`` additive mask for query positions ``pos``.

        Entries are 0 where ``key index <= pos`` and ``-inf`` elsewhere.
        """
        key_idx = torch.arange(num_keys, device=pos.device)
        keep = key_idx[None, :] <= pos[:, None]
        return torch.where(keep, 0.0, float("-inf")).to(dtype)

    @staticmethod
    def _masked_fsdpa(q, k, v, attn_mask):
        """Run F.sdpa over the whole K/V buffer with an explicit additive mask."""
        import torch.nn.functional as F

        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=True, scale=1.0
        )

    def _check(self, B, Hkv, H, M, D, S, anchor, tol=0.02):
        """Assert ``sdpa_midm`` matches ``sdpa_midm_reference`` within ``tol``.

        Inputs come from ``_rand`` with the given shape and ``anchor``; the
        output must have shape ``(B, H, M, D)``.
        """
        q, k, v, pos = _rand(B, Hkv, H, M, D, S, anchor)
        got = sdpa_midm(q, k, v, pos, scale=1.0)
        ref = sdpa_midm_reference(q, k, v, pos, scale=1.0)
        self.assertEqual(got.shape, (B, H, M, D))
        err = _rel_err(got, ref)
        self.assertLess(err, tol, f"rel_err={err} for M={M} D={D} anchor={anchor}")

    # gemma4 global-attention shape: H=32, HKV=4 (GQA 8), D=512.
    def test_global_layer_verify_window(self):
        for M in (2, 4, 5, 8):
            for anchor in (0, 17, 200, 1000):
                with self.subTest(M=M, anchor=anchor):
                    self._check(1, 4, 32, M, 512, 4096, anchor)

    def test_other_gqa_and_headdim(self):
        # smaller config (head_dim 256, GQA 4) to exercise generality
        for M in (2, 5, 8):
            with self.subTest(M=M):
                self._check(1, 2, 8, M, 256, 2048, 300)

    def test_anchor_zero_single_diagonal(self):
        # anchor 0: row j attends keys [0, j] only
        self._check(1, 4, 32, 4, 512, 1024, 0)

    def test_matches_full_buffer_fsdpa(self):
        # The bounded kernel must equal F.sdpa over the FULL buffer with the
        # model's causal additive mask (the rest masked to -inf).
        q, k, v, pos = _rand(1, 4, 32, 5, 512, 8192, 500)
        # Mask width is read from k so it cannot drift from the buffer size.
        am = self._causal_mask(pos, k.shape[2], q.dtype)
        full = self._masked_fsdpa(q, k, v, am)
        got = sdpa_midm(q, k, v, pos, scale=1.0)
        self.assertLess(_rel_err(got, full), 0.02)

    def test_splitk_large_context(self):
        # Many active splits: 64K buffer, anchors across the range. Exercises the
        # cross-split online-softmax reduce at the lengths that motivated split-K.
        for anchor in (2048, 30000, 60000):
            for M in (2, 5, 8):
                with self.subTest(anchor=anchor, M=M):
                    self._check(1, 4, 32, M, 512, 65536, anchor)

    def test_splitk_masked_and_boundary_splits(self):
        # anchor small vs a large buffer: late key-range splits are fully causal-
        # masked for the early rows (null partials), and a row's cutoff lands mid
        # chunk. Reduce must discard -inf/0 partials cleanly.
        for anchor in (1, 31, 33, 500):
            with self.subTest(anchor=anchor):
                self._check(1, 2, 8, 5, 256, 65536, anchor)

    def test_dispatch_falls_back(self):
        # M=1 and M>MIDM_MAX_M must take the F.sdpa path (not the mid-M kernel).
        # NOTE(review): this only checks numerical agreement with F.sdpa; it does
        # not assert which path ran, and it assumes 16 > MIDM_MAX_M -- confirm.
        for M in (1, 16):
            with self.subTest(M=M):
                q, k, v, pos = _rand(1, 4, 32, M, 512, 1024, 100)
                am = self._causal_mask(pos, k.shape[2], q.dtype)
                out = midm_sdpa(q, k, v, pos, am, scale=1.0, enable=True)
                ref = self._masked_fsdpa(q, k, v, am)
                self.assertLess(_rel_err(out, ref), 0.02)
117+
118+
119+
if __name__ == "__main__":
    # Direct execution: run this module's tests with per-test result lines.
    unittest.main(verbosity=2)

backends/cuda/triton/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
int4_matvec,
1818
)
1919
from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
20+
from executorch.backends.cuda.triton.kernels.sdpa_midm import sdpa_midm
2021
from executorch.backends.cuda.triton.kernels.topk import topk
2122

2223
__all__ = [
@@ -29,6 +30,7 @@
2930
"moe_align_block_size",
3031
"sdpa",
3132
"sdpa_decode_splitk",
33+
"sdpa_midm",
3234
"topk",
3335
]
3436

0 commit comments

Comments
 (0)