Commit 8e2c488

[aoti-cuda] Add SDPA benchmarking script with qwen-3.5-35B-A3B shapes
Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench for timing. Standalone script, no changes to the kernel. This PR was authored with the assistance of Claude.
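Example invocation (the added file's path is not shown in this view, so benchmark_sdpa.py below is a placeholder name; the flags match the script's argparse options):

    python benchmark_sdpa.py --scenario decode --num_warmup 25 --num_iters 100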
1 parent 2c545f8 commit 8e2c488

1 file changed

Lines changed: 300 additions & 0 deletions

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Benchmark the Triton SDPA kernel against PyTorch SDPA backends.

Measures latency across decode shapes matching the Qwen3.5 MoE model
(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA
(2 KV heads), while Flash/Efficient/Math require pre-expanded KV
(16 heads) since they lack native GQA support.
"""

import argparse
import warnings
from functools import partial

import torch
import torch.nn.functional as F

from executorch.backends.cuda.triton.kernels.sdpa import sdpa as triton_sdpa
from torch.nn.attention import sdpa_kernel, SDPBackend
from triton.testing import do_bench


# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly.
# We expand KV heads via repeat_interleave so they can run, matching what
# the test reference does. This is fair: it measures the kernel itself, not
# the GQA dispatch overhead.

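# Shape example: with the Qwen3.5 decode config below (H_q=16, H_kv=2, so
# num_groups=8), the expansion turns K/V of shape (1, 2, Lk, 256) into
# (1, 16, Lk, 256). Illustration only (hypothetical Lk=1024):
#   k = torch.randn(1, 2, 1024, 256)
#   k.repeat_interleave(8, dim=1).shape  # torch.Size([1, 16, 1024, 256])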
def _expand_kv(k, v, num_groups):
    if num_groups > 1:
        k = k.repeat_interleave(num_groups, dim=1)
        v = v.repeat_interleave(num_groups, dim=1)
    return k, v


def _expand_mask(mask, H_q):
    if mask is not None and mask.shape[1] == 1 and H_q > 1:
        mask = mask.expand(-1, H_q, -1, -1)
    return mask


def _run_triton(q, k, v, attn_mask, enable_gqa):
    return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)


def _run_pytorch_default(q, k, v, attn_mask, enable_gqa):
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        enable_gqa=enable_gqa,
    )


def _make_pytorch_runner(backend: SDPBackend):
    def run(q, k, v, attn_mask, enable_gqa):
        with sdpa_kernel(backend):
            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    return run


# Flash doesn't support attn_mask at all, only is_causal.
# Our benchmark mask is all-ones, so no mask is equivalent.
def _run_flash(q, k, v, attn_mask, enable_gqa):
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)


BACKENDS = {
    "triton": ("ET Triton (GQA)", _run_triton),
    "pytorch": ("PyTorch", _run_pytorch_default),
    "flash": ("Flash (expanded KV)", _run_flash),
    "efficient": (
        "Efficient (expanded KV)",
        _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION),
    ),
    "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)),
}

# Backends that need KV heads expanded before calling (no native GQA support)
_NEEDS_KV_EXPAND = {"flash", "efficient", "math"}

# -- Shapes ------------------------------------------------------------------

# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256
QWEN35_BASE = dict(B=1, H_q=16, H_kv=2, D=256)

DECODE_SHAPES = [
    dict(**QWEN35_BASE, Lq=1, Lk=64),
    dict(**QWEN35_BASE, Lq=1, Lk=128),
    dict(**QWEN35_BASE, Lq=1, Lk=256),
    dict(**QWEN35_BASE, Lq=1, Lk=512),
    dict(**QWEN35_BASE, Lq=1, Lk=1024),
    dict(**QWEN35_BASE, Lq=1, Lk=2048),
    dict(**QWEN35_BASE, Lq=1, Lk=4096),
    dict(**QWEN35_BASE, Lq=1, Lk=8192),
    dict(**QWEN35_BASE, Lq=1, Lk=16384),
]

SCENARIOS = {
    "decode": DECODE_SHAPES,
}

# -- Helpers -----------------------------------------------------------------

def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16):
    q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype)
    k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
    v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
    mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device)
    enable_gqa = H_q != H_kv
    num_groups = H_q // H_kv
    # Pre-expanded versions for backends without native GQA
    k_exp, v_exp = _expand_kv(k, v, num_groups)
    mask_exp = _expand_mask(mask, H_q)
    return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa


def _max_abs_error(out, ref):
    return (out.float() - ref.float()).abs().max().item()


# Cross-backend validation tolerance (bf16 vs bf16).
MAX_ABS_TOL = 1e-2


def _bench_us(fn, num_warmup, num_iters):
    """Return median latency in microseconds using triton.testing.do_bench."""
    ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median")
    return ms * 1000.0


def _try_run(run_fn, q, k, v, mask, enable_gqa):
    """Run a backend, returning output or None on failure."""
    try:
        return run_fn(q, k, v, mask, enable_gqa)
    except RuntimeError:
        return None


def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters):
    """Benchmark a backend, returning median us or None on failure."""
    fn = partial(run_fn, q, k, v, mask, enable_gqa)
    try:
        run_fn(q, k, v, mask, enable_gqa)
        return _bench_us(fn, num_warmup, num_iters)
    except RuntimeError:
        return None


# -- Main --------------------------------------------------------------------


def _shape_label(shape):
    return (
        f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} "
        f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}"
    )


def _short_label(shape, scenario="decode"):
    return f"Lq={shape['Lq']},Lk={shape['Lk']}"

@torch.inference_mode()
def run_benchmark(
    scenario: str = "decode",
    num_warmup: int = 25,
    num_iters: int = 100,
):
    shapes = SCENARIOS[scenario]
    backends = [(name, *BACKENDS[name]) for name in BACKENDS]

    device_name = torch.cuda.get_device_name()
    print()
    print("=" * 100)
    print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
    print(f" Device: {device_name}")
    print(f" Warmup: {num_warmup}, Iters: {num_iters}")
    print(f" Backends: {', '.join(label for _, label, _ in backends)}")
    print("=" * 100)

    # Build column specs: (header_text, unit_text, min_width)
    # Each column gets width = max(len(header), len(unit), min_width)
    max_label = max(len(_short_label(s, scenario)) for s in shapes)
    col_specs = [("Shape", "", max(8, max_label))]
    for _, label, _ in backends:
        col_specs.append((label, "(us)", 8))

    col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs]

    header = " | ".join(
        f"{h:<{w}}" if i == 0 else f"{h:>{w}}"
        for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths))
    )
    units = " | ".join(
        f"{'':>{w}}" if i == 0 else f"{u:>{w}}"
        for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths))
    )
    print(header)
    print(units)
    print("-" * len(header))

    for shape in shapes:
        q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Validate outputs across backends before benchmarking
            outputs = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa)

            # Use PyTorch F.sdpa as the trusted reference — never validate
            # against our own Triton kernels.
            ref_name, ref_out = None, None
            if outputs.get("pytorch") is not None:
                ref_name, ref_out = "pytorch", outputs["pytorch"]

            if ref_out is not None:
                for name, label, _ in backends:
                    if name == ref_name or outputs[name] is None:
                        continue
                    err = _max_abs_error(outputs[name], ref_out)
                    assert err < MAX_ABS_TOL, (
                        f"Output mismatch for {_shape_label(shape)}: "
                        f"{label} vs {BACKENDS[ref_name][0]}, "
                        f"max abs error {err:.3e} >= 1e-2"
                    )
            del outputs

            # Benchmark all backends
            times = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                times[name] = _try_bench(
                    run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters
                )

        # Format row using col_widths
        ci = 0
        row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"]
        ci += 1
        for name, _, _ in backends:
            t = times[name]
            w = col_widths[ci]
            row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}")
            ci += 1
        print(" | ".join(row_parts))

        del q, k, v, k_exp, v_exp, mask, mask_exp
        torch.cuda.empty_cache()

    print("-" * len(header))
    print()

def main():
    parser = argparse.ArgumentParser(
        description="Benchmark Triton SDPA vs PyTorch backends"
    )
    parser.add_argument(
        "--scenario",
        choices=list(SCENARIOS.keys()) + ["all"],
        default="all",
        help="Which shape set to benchmark (default: all)",
    )
    parser.add_argument("--num_warmup", type=int, default=25)
    parser.add_argument("--num_iters", type=int, default=100)
    args = parser.parse_args()

    scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario]
    for s in scenarios:
        run_benchmark(
            scenario=s,
            num_warmup=args.num_warmup,
            num_iters=args.num_iters,
        )


if __name__ == "__main__":
    main()
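For reference, a minimal sketch of reproducing a single data point by hand. It assumes the same import path used by the script, a CUDA device with bf16 support, and uses the Qwen3.5 decode shape with Lk=1024:

    import torch
    import torch.nn.functional as F

    from executorch.backends.cuda.triton.kernels.sdpa import sdpa as triton_sdpa

    # Qwen3.5 decode shape: B=1, H_q=16, H_kv=2, D=256, Lq=1, Lk=1024
    q = torch.randn(1, 16, 1, 256, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 2, 1024, 256, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 2, 1024, 256, device="cuda", dtype=torch.bfloat16)
    mask = torch.ones(1, 1, 1, 1024, dtype=torch.bool, device="cuda")

    # ET Triton kernel with native GQA (2 KV heads)
    out_triton = triton_sdpa(q, k, v, attn_mask=mask, enable_gqa=True)

    # PyTorch reference with KV and mask expanded to 16 heads, mirroring what
    # the script does for the Flash/Efficient/Math backends
    k_exp = k.repeat_interleave(8, dim=1)
    v_exp = v.repeat_interleave(8, dim=1)
    out_ref = F.scaled_dot_product_attention(
        q, k_exp, v_exp, attn_mask=mask.expand(-1, 16, -1, -1)
    )

    print((out_triton.float() - out_ref.float()).abs().max().item())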
