Commit febc419
[aoti-cuda] Add SDPA benchmarking script with qwen-3.5-35B-A3B shapes
Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench for timing. Standalone script, no changes to the kernel.
1 parent b24535b commit febc419

1 file changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Benchmark the Triton SDPA kernel against PyTorch SDPA backends.

Measures latency across decode shapes matching the Qwen3.5 MoE model
(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA
(2 KV heads), while Flash/Efficient/Math require pre-expanded KV
(16 heads) since they lack native GQA support.
"""
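
# Example invocation (the filename is illustrative; run this file from wherever
# it is saved). The flags correspond to the argparse options defined in main():
#   python sdpa_benchmark.py --scenario decode --num_warmup 25 --num_iters 100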

import argparse
import warnings
from functools import partial

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from triton.testing import do_bench

from executorch.backends.cuda.triton.kernels.sdpa import sdpa as triton_sdpa


# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly.
# We expand KV heads via repeat_interleave so they can run, matching what
# the test reference does. This is fair: it measures the kernel itself, not
# the GQA dispatch overhead.
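# For the shapes used here (H_q=16, H_kv=2, so num_groups=8), expansion turns
# k/v of shape (B, 2, Lk, D) into (B, 16, Lk, D): each KV head is repeated
# 8 times so every query head sees its own copy.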


def _expand_kv(k, v, num_groups):
    if num_groups > 1:
        k = k.repeat_interleave(num_groups, dim=1)
        v = v.repeat_interleave(num_groups, dim=1)
    return k, v


def _expand_mask(mask, H_q):
    if mask is not None and mask.shape[1] == 1 and H_q > 1:
        mask = mask.expand(-1, H_q, -1, -1)
    return mask


def _run_triton(q, k, v, attn_mask, enable_gqa):
    return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)


def _run_pytorch_default(q, k, v, attn_mask, enable_gqa):
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa,
    )


def _make_pytorch_runner(backend: SDPBackend):
    # KV is pre-expanded before these backends are called, so enable_gqa is
    # deliberately ignored here.
    def run(q, k, v, attn_mask, enable_gqa):
        with sdpa_kernel(backend):
            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    return run


# Flash doesn't support attn_mask at all, only is_causal.
# Our benchmark mask is all-ones, so omitting the mask is equivalent.
def _run_flash(q, k, v, attn_mask, enable_gqa):
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)


BACKENDS = {
    "triton": ("ET Triton (GQA)", _run_triton),
    "pytorch": ("PyTorch", _run_pytorch_default),
    "flash": ("Flash (expanded KV)", _run_flash),
    "efficient": (
        "Efficient (expanded KV)",
        _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION),
    ),
    "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)),
}

# Backends that need KV heads expanded before calling (no native GQA support)
_NEEDS_KV_EXPAND = {"flash", "efficient", "math"}

# -- Shapes ------------------------------------------------------------------

# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256
QWEN35_BASE = dict(B=1, H_q=16, H_kv=2, D=256)

DECODE_SHAPES = [
    dict(**QWEN35_BASE, Lq=1, Lk=64),
    dict(**QWEN35_BASE, Lq=1, Lk=128),
    dict(**QWEN35_BASE, Lq=1, Lk=256),
    dict(**QWEN35_BASE, Lq=1, Lk=512),
    dict(**QWEN35_BASE, Lq=1, Lk=1024),
    dict(**QWEN35_BASE, Lq=1, Lk=2048),
    dict(**QWEN35_BASE, Lq=1, Lk=4096),
    dict(**QWEN35_BASE, Lq=1, Lk=8192),
    dict(**QWEN35_BASE, Lq=1, Lk=16384),
]

SCENARIOS = {
    "decode": DECODE_SHAPES,
}

# -- Helpers -----------------------------------------------------------------


def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16):
    q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype)
    k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
    v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
    # All-True mask: nothing is excluded, so maskless backends compute the
    # same result.
    mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device)
    enable_gqa = H_q != H_kv
    num_groups = H_q // H_kv
    # Pre-expanded versions for backends without native GQA
    k_exp, v_exp = _expand_kv(k, v, num_groups)
    mask_exp = _expand_mask(mask, H_q)
    return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa


def _max_abs_error(out, ref):
    return (out.float() - ref.float()).abs().max().item()


def _bench_us(fn, num_warmup, num_iters):
    """Return median latency in microseconds using triton.testing.do_bench."""
    ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median")
    return ms * 1000.0
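
# Note: triton.testing.do_bench interprets `warmup` and `rep` as target times
# in milliseconds, not iteration counts; the defaults used here (25 / 100)
# match do_bench's own defaults.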


def _try_run(run_fn, q, k, v, mask, enable_gqa):
    """Run a backend, returning output or None on failure."""
    try:
        return run_fn(q, k, v, mask, enable_gqa)
    except RuntimeError:
        return None


def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters):
    """Benchmark a backend, returning median us or None on failure."""
    fn = partial(run_fn, q, k, v, mask, enable_gqa)
    try:
        # Call once eagerly so unsupported configs raise here rather than
        # inside do_bench.
        run_fn(q, k, v, mask, enable_gqa)
        return _bench_us(fn, num_warmup, num_iters)
    except RuntimeError:
        return None


# -- Main --------------------------------------------------------------------


def _shape_label(shape):
    return (
        f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} "
        f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}"
    )


def _short_label(shape, scenario="decode"):
    return f"Lq={shape['Lq']},Lk={shape['Lk']}"


@torch.inference_mode()
def run_benchmark(
    scenario: str = "decode",
    num_warmup: int = 25,
    num_iters: int = 100,
):
    shapes = SCENARIOS[scenario]
    backends = [(name, *BACKENDS[name]) for name in BACKENDS]

    device_name = torch.cuda.get_device_name()
    print()
    print("=" * 100)
    print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
    print(f" Device: {device_name}")
    print(f" Warmup: {num_warmup}, Iters: {num_iters}")
    print(f" Backends: {', '.join(label for _, label, _ in backends)}")
    print("=" * 100)

    # Build column specs: (header_text, unit_text, min_width).
    # Each column gets width = max(len(header), len(unit), min_width).
    max_label = max(len(_short_label(s, scenario)) for s in shapes)
    col_specs = [("Shape", "", max(8, max_label))]
    for _, label, _ in backends:
        col_specs.append((label, "(us)", 8))

    col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs]

    header = " | ".join(
        f"{h:<{w}}" if i == 0 else f"{h:>{w}}"
        for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths))
    )
    units = " | ".join(
        f"{'':>{w}}" if i == 0 else f"{u:>{w}}"
        for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths))
    )
    print(header)
    print(units)
    print("-" * len(header))
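
    # Each row below pairs a shape label with one median latency (us) per
    # backend, or N/A where a backend failed to run for that shape.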

    for shape in shapes:
        q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Validate outputs across backends before benchmarking
            outputs = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa)

            # Use the first backend that produced an output as the reference.
            ref_name, ref_out = None, None
            for name, _, _ in backends:
                if outputs[name] is not None:
                    ref_name, ref_out = name, outputs[name]
                    break

            if ref_out is not None:
                for name, label, _ in backends:
                    if name == ref_name or outputs[name] is None:
                        continue
                    err = _max_abs_error(outputs[name], ref_out)
                    assert err < 1e-2, (
                        f"Output mismatch for {_shape_label(shape)}: "
                        f"{label} vs {BACKENDS[ref_name][0]}, "
                        f"max abs error {err:.3e} >= 1e-2"
                    )
            del outputs

            # Benchmark all backends
            times = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                times[name] = _try_bench(
                    run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters
                )

        # Format row using col_widths
        ci = 0
        row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"]
        ci += 1
        for name, _, _ in backends:
            t = times[name]
            w = col_widths[ci]
            row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}")
            ci += 1
        print(" | ".join(row_parts))

        del q, k, v, k_exp, v_exp, mask, mask_exp
        torch.cuda.empty_cache()

    print("-" * len(header))
    print()


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark Triton SDPA vs PyTorch backends"
    )
    parser.add_argument(
        "--scenario",
        choices=list(SCENARIOS.keys()) + ["all"],
        default="all",
        help="Which shape set to benchmark (default: all)",
    )
    parser.add_argument("--num_warmup", type=int, default=25)
    parser.add_argument("--num_iters", type=int, default=100)
    args = parser.parse_args()

    scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario]
    for s in scenarios:
        run_benchmark(
            scenario=s,
            num_warmup=args.num_warmup,
            num_iters=args.num_iters,
        )


if __name__ == "__main__":
    main()
