
Commit 87e65ac

SDPA decode perf improvements for qwen-3.5-35B-A3B (#18759)
### Performance Improvements for SDPA

Improves SDPA performance for decode sequences where $L_q = 1$.

#### Benchmark: qwen3.5-35B-A3B

* **Config:** Avg of 3 runs on A100.

### Decode Performance (tok/s)

| Prompt | Decode Len | Baseline | Split-K | Speedup |
|---|---|---|---|---|
| P2 (1 tok) | 16 | 86.3 | 105.3 | +22% |
| P2 | 64 | 87.1 | 108.4 | +24% |
| P2 | 256 | 89.2 | 108.3 | +21% |
| P2 | 1024 | 89.4 | 108.1 | +21% |
| P15 (15 tok) | 16 | 85.9 | 102.8 | +20% |
| P15 | 64 | 85.1 | 104.7 | +23% |
| P15 | 256 | 88.4 | 108.8 | +23% |
| P15 | 1024 | 90.0 | 107.2 | +19% |
| P59 (59 tok) | 16 | 86.8 | 96.1 | +11% |
| P59 | 64 | 89.5 | 99.8 | +12% |
| P59 | 256 | 88.9 | 108.5 | +22% |
| P59 | 1024 | 90.0 | 108.6 | +21% |
| P120 (143 tok) | 16 | 87.5 | 105.0 | +20% |
| P120 | 64 | 88.8 | 107.7 | +21% |
| P120 | 256 | 90.3 | 107.6 | +19% |
| P120 | 1024 | 89.4 | 109.3 | +22% |
| P1000 (1694 tok) | 16 | 86.4 | 103.2 | +19% |
| P1000 | 64 | 89.2 | 106.7 | +20% |
| P1000 | 256 | 90.2 | 108.0 | +20% |
| P1000 | 1024 | 89.7 | 108.0 | +20% |

---

### Prefill Performance (tok/s)

| Prompt | Baseline | Split-K | Delta |
|---|---|---|---|
| P2 (1 tok) | 19.4 | 19.3 | ~same |
| P15 (15 tok) | 192.8 | 191.6 | ~same |
| P59 (59 tok) | 390.2 | 368.1 | -6% |
| P120 (143 tok) | 512.4 | 481.9 | -6% |
| P1000 (1694 tok) | 585.6 | 591.4 | +1% |

> *Note: Prefill is averaged across all four decode lengths per prompt, since prefill is independent of decode length.*

---

### Summary

* **Decode:** Split-K delivers a +20% average speedup (88.6 → 106.5 tok/s).
* **Prefill:** Similar between variants (both use tiled SDPA).
* **Quality:** Verified identical at `temperature=0`. At the SDPA op level we saw a ~25x speedup: across ~10.2K calls (1024 tokens x 10 layers), 5.3 s → 209 ms.

#### Implementation Details

* **Max Context Length:** 4K
* **Kernel Constraints:**
  * **Baseline:** Updates example input shapes to remove the 64-token cap.
  * **Prefill:** Baseline and Split-K should be equivalent for prefill (both use `_sdpa_fwd_kernel_m64`).
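For reference, the split-K decode idea can be sketched in a few lines of plain PyTorch. The usual motivation is that with a single query token there is little parallelism across the query dimension, so the work is split across the KV length instead. This is an illustrative reference only, not the Triton kernel: it assumes $L_q = 1$, omits GQA head grouping and masking, and just shows how the per-chunk softmax statistics are recombined exactly.

```python
import torch
import torch.nn.functional as F


def splitk_decode_reference(q, k, v, num_splits=4):
    """Illustrative split-K attention for a single query token (L_q == 1).

    q: (B, H, 1, D), k/v: (B, H, Lk, D). Each split computes attention over
    its own K/V chunk with a chunk-local softmax max/sum; the partials are
    then rescaled by exp(m_i - m_global) and combined, reproducing the full
    softmax exactly.
    """
    scale = q.shape[-1] ** -0.5
    chunk_idx = torch.chunk(torch.arange(k.shape[2]), num_splits)
    outs, maxes, sums = [], [], []
    for idx in chunk_idx:
        s = (q * scale) @ k[:, :, idx].transpose(-1, -2)  # (B, H, 1, chunk)
        m = s.amax(dim=-1, keepdim=True)                  # chunk-local max
        p = (s - m).exp()
        outs.append(p @ v[:, :, idx])                     # unnormalized partial output
        maxes.append(m)
        sums.append(p.sum(dim=-1, keepdim=True))          # chunk-local softmax denominator
    m_global = torch.stack(maxes).amax(dim=0)
    num = sum(o * (m - m_global).exp() for o, m in zip(outs, maxes))
    den = sum(l * (m - m_global).exp() for l, m in zip(sums, maxes))
    return num / den


# Self-check against the full-softmax reference (fp32, no mask, no GQA).
q = torch.randn(1, 16, 1, 256)
k = torch.randn(1, 16, 4096, 256)
v = torch.randn(1, 16, 4096, 256)
torch.testing.assert_close(
    splitk_decode_reference(q, k, v),
    F.scaled_dot_product_attention(q, k, v),
    atol=1e-4,
    rtol=1e-4,
)
```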
1 parent ecf49ca commit 87e65ac

7 files changed

Lines changed: 1125 additions & 26 deletions

Lines changed: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Benchmark the Triton SDPA kernel against PyTorch SDPA backends.

Measures latency across decode shapes matching the Qwen3.5 MoE model
(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA
(2 KV heads), while Flash/Efficient/Math require pre-expanded KV
(16 heads) since they lack native GQA support.
"""

import argparse
import warnings
from functools import partial

import torch
import torch.nn.functional as F

from executorch.backends.cuda.triton.kernels.sdpa import (
    sdpa as triton_sdpa,
    sdpa_decode_splitk as triton_splitk,
)
from torch.nn.attention import sdpa_kernel, SDPBackend
from triton.testing import do_bench


# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly.
# We expand KV heads via repeat_interleave so they can run, matching what
# the test reference does. This is fair: it measures the kernel itself, not
# the GQA dispatch overhead.


def _expand_kv(k, v, num_groups):
    if num_groups > 1:
        k = k.repeat_interleave(num_groups, dim=1)
        v = v.repeat_interleave(num_groups, dim=1)
    return k, v


def _expand_mask(mask, H_q):
    if mask is not None and mask.shape[1] == 1 and H_q > 1:
        mask = mask.expand(-1, H_q, -1, -1)
    return mask


def _run_triton(q, k, v, attn_mask, enable_gqa):
    return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)


def _run_splitk(q, k, v, attn_mask, enable_gqa):
    return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)


def _run_pytorch_default(q, k, v, attn_mask, enable_gqa):
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        enable_gqa=enable_gqa,
    )


def _make_pytorch_runner(backend: SDPBackend):
    def run(q, k, v, attn_mask, enable_gqa):
        with sdpa_kernel(backend):
            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    return run


# Flash doesn't support attn_mask at all, only is_causal.
# Our benchmark mask is all-ones, so no mask is equivalent.
def _run_flash(q, k, v, attn_mask, enable_gqa):
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)


BACKENDS = {
    "triton": ("ET Triton (GQA)", _run_triton),
    "splitk": ("ET Split-K (GQA)", _run_splitk),
    "pytorch": ("PyTorch", _run_pytorch_default),
    "flash": ("Flash (expanded KV)", _run_flash),
    "efficient": (
        "Efficient (expanded KV)",
        _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION),
    ),
    "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)),
}

# Backends that need KV heads expanded before calling (no native GQA support)
_NEEDS_KV_EXPAND = {"flash", "efficient", "math"}

# -- Shapes ------------------------------------------------------------------

# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256
QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256}

DECODE_SHAPES = [
    dict(**QWEN35_BASE, Lq=1, Lk=64),
    dict(**QWEN35_BASE, Lq=1, Lk=128),
    dict(**QWEN35_BASE, Lq=1, Lk=256),
    dict(**QWEN35_BASE, Lq=1, Lk=512),
    dict(**QWEN35_BASE, Lq=1, Lk=1024),
    dict(**QWEN35_BASE, Lq=1, Lk=2048),
    dict(**QWEN35_BASE, Lq=1, Lk=4096),
    dict(**QWEN35_BASE, Lq=1, Lk=8192),
    dict(**QWEN35_BASE, Lq=1, Lk=16384),
]

SCENARIOS = {
    "decode": DECODE_SHAPES,
}

# -- Helpers -----------------------------------------------------------------


def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16):
    q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype)
    k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
    v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
    mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device)
    enable_gqa = H_q != H_kv
    num_groups = H_q // H_kv
    # Pre-expanded versions for backends without native GQA
    k_exp, v_exp = _expand_kv(k, v, num_groups)
    mask_exp = _expand_mask(mask, H_q)
    return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa


def _max_abs_error(out, ref):
    return (out.float() - ref.float()).abs().max().item()


# Cross-backend validation tolerance (bf16 vs bf16).
MAX_ABS_TOL = 1e-2


def _bench_us(fn, num_warmup, num_iters):
    """Return median latency in microseconds using triton.testing.do_bench."""
    ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median")
    return ms * 1000.0


def _try_run(run_fn, q, k, v, mask, enable_gqa):
    """Run a backend, returning output or None on failure."""
    try:
        return run_fn(q, k, v, mask, enable_gqa)
    except RuntimeError:
        return None


def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters):
    """Benchmark a backend, returning median us or None on failure."""
    fn = partial(run_fn, q, k, v, mask, enable_gqa)
    try:
        run_fn(q, k, v, mask, enable_gqa)
        return _bench_us(fn, num_warmup, num_iters)
    except RuntimeError:
        return None


# -- Main --------------------------------------------------------------------


def _shape_label(shape):
    return (
        f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} "
        f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}"
    )


def _short_label(shape, scenario="decode"):
    return f"Lq={shape['Lq']},Lk={shape['Lk']}"


@torch.inference_mode()
def run_benchmark(
    scenario: str = "decode",
    num_warmup: int = 25,
    num_iters: int = 100,
):
    shapes = SCENARIOS[scenario]
    backends = [(name, *BACKENDS[name]) for name in BACKENDS]

    device_name = torch.cuda.get_device_name()
    print()
    print("=" * 100)
    print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
    print(f"  Device: {device_name}")
    print(f"  Warmup: {num_warmup}, Iters: {num_iters}")
    print(f"  Backends: {', '.join(label for _, label, _ in backends)}")
    print("=" * 100)

    # Build column specs: (header_text, unit_text, min_width)
    # Each column gets width = max(len(header), len(unit), min_width)
    max_label = max(len(_short_label(s, scenario)) for s in shapes)
    col_specs = [("Shape", "", max(8, max_label))]
    for _, label, _ in backends:
        col_specs.append((label, "(us)", 8))

    col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs]

    header = " | ".join(
        f"{h:<{w}}" if i == 0 else f"{h:>{w}}"
        for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths))
    )
    units = " | ".join(
        f"{'':>{w}}" if i == 0 else f"{u:>{w}}"
        for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths))
    )
    print(header)
    print(units)
    print("-" * len(header))

    for shape in shapes:
        q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Validate outputs across backends before benchmarking
            outputs = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa)

            # Use PyTorch F.sdpa as the trusted reference — never validate
            # against our own Triton kernels.
            ref_name, ref_out = None, None
            if outputs.get("pytorch") is not None:
                ref_name, ref_out = "pytorch", outputs["pytorch"]

            if ref_out is not None:
                for name, label, _ in backends:
                    if name == ref_name or outputs[name] is None:
                        continue
                    err = _max_abs_error(outputs[name], ref_out)
                    assert err < MAX_ABS_TOL, (
                        f"Output mismatch for {_shape_label(shape)}: "
                        f"{label} vs {BACKENDS[ref_name][0]}, "
                        f"max abs error {err:.3e} >= 1e-2"
                    )
            del outputs

            # Benchmark all backends
            times = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                times[name] = _try_bench(
                    run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters
                )

        # Format row using col_widths
        ci = 0
        row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"]
        ci += 1
        for name, _, _ in backends:
            t = times[name]
            w = col_widths[ci]
            row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}")
            ci += 1
        print(" | ".join(row_parts))

        del q, k, v, k_exp, v_exp, mask, mask_exp
        torch.cuda.empty_cache()

    print("-" * len(header))
    print()


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark Triton SDPA vs PyTorch backends"
    )
    parser.add_argument(
        "--scenario",
        choices=list(SCENARIOS.keys()) + ["all"],
        default="all",
        help="Which shape set to benchmark (default: all)",
    )
    parser.add_argument("--num_warmup", type=int, default=25)
    parser.add_argument("--num_iters", type=int, default=100)
    args = parser.parse_args()

    scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario]
    for s in scenarios:
        run_benchmark(
            scenario=s,
            num_warmup=args.num_warmup,
            num_iters=args.num_iters,
        )


if __name__ == "__main__":
    main()
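The decode sweep above can also be driven programmatically instead of through the CLI; a minimal sketch (the module name below is hypothetical, since the new file's path isn't shown in this view, and a CUDA device plus the ET Triton SDPA kernels are required):

```python
# Hypothetical module name; import from wherever the benchmark script is saved.
from sdpa_benchmark import run_benchmark

# Same defaults as the argparse flags above (--num_warmup 25, --num_iters 100).
run_benchmark(scenario="decode", num_warmup=25, num_iters=100)
```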
