|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +""" |
| 9 | +Benchmark the Triton SDPA kernel against PyTorch SDPA backends. |
| 10 | +
|
| 11 | +Measures latency across decode shapes matching the Qwen3.5 MoE model |
| 12 | +(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA |
| 13 | +(2 KV heads), while Flash/Efficient/Math require pre-expanded KV |
| 14 | +(16 heads) since they lack native GQA support. |
| 15 | +
|
| 16 | +""" |
| 17 | + |
| 18 | +import argparse |
| 19 | +import warnings |
| 20 | +from functools import partial |
| 21 | + |
| 22 | +import torch |
| 23 | +import torch.nn.functional as F |
| 24 | + |
| 25 | +from executorch.backends.cuda.triton.kernels.sdpa import ( |
| 26 | + sdpa as triton_sdpa, |
| 27 | + sdpa_decode_splitk as triton_splitk, |
| 28 | +) |
| 29 | +from torch.nn.attention import sdpa_kernel, SDPBackend |
| 30 | +from triton.testing import do_bench |
| 31 | + |
| 32 | + |
| 33 | +# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. |
| 34 | +# We expand KV heads via repeat_interleave so they can run, matching what |
| 35 | +# the test reference does. This is fair: it measures the kernel itself, not |
| 36 | +# the GQA dispatch overhead. |
| 37 | + |
| 38 | + |
| 39 | +def _expand_kv(k, v, num_groups): |
| 40 | + if num_groups > 1: |
| 41 | + k = k.repeat_interleave(num_groups, dim=1) |
| 42 | + v = v.repeat_interleave(num_groups, dim=1) |
| 43 | + return k, v |
| 44 | + |
| 45 | + |
| 46 | +def _expand_mask(mask, H_q): |
| 47 | + if mask is not None and mask.shape[1] == 1 and H_q > 1: |
| 48 | + mask = mask.expand(-1, H_q, -1, -1) |
| 49 | + return mask |
| 50 | + |
| 51 | + |
def _run_triton(q, k, v, attn_mask, enable_gqa):
    # Thin adapter so all backends share the (q, k, v, mask, gqa) call shape;
    # forwards directly to the ET Triton SDPA kernel.
    return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
| 54 | + |
| 55 | + |
def _run_splitk(q, k, v, attn_mask, enable_gqa):
    # Same adapter shape as _run_triton, targeting the split-K decode kernel.
    return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
| 58 | + |
| 59 | + |
| 60 | +def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): |
| 61 | + return F.scaled_dot_product_attention( |
| 62 | + q, |
| 63 | + k, |
| 64 | + v, |
| 65 | + attn_mask=attn_mask, |
| 66 | + enable_gqa=enable_gqa, |
| 67 | + ) |
| 68 | + |
| 69 | + |
| 70 | +def _make_pytorch_runner(backend: SDPBackend): |
| 71 | + def run(q, k, v, attn_mask, enable_gqa): |
| 72 | + with sdpa_kernel(backend): |
| 73 | + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) |
| 74 | + |
| 75 | + return run |
| 76 | + |
| 77 | + |
| 78 | +# Flash doesn't support attn_mask at all, only is_causal. |
| 79 | +# Our benchmark mask is all-ones, so no mask is equivalent. |
| 80 | +def _run_flash(q, k, v, attn_mask, enable_gqa): |
| 81 | + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): |
| 82 | + return F.scaled_dot_product_attention(q, k, v) |
| 83 | + |
| 84 | + |
# Registry of benchmarkable backends: CLI key -> (display label, runner).
# Every runner shares the (q, k, v, attn_mask, enable_gqa) call signature.
BACKENDS = {
    "triton": ("ET Triton (GQA)", _run_triton),
    "splitk": ("ET Split-K (GQA)", _run_splitk),
    "pytorch": ("PyTorch", _run_pytorch_default),
    "flash": ("Flash (expanded KV)", _run_flash),
    "efficient": (
        "Efficient (expanded KV)",
        _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION),
    ),
    "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)),
}

# Backends that need KV heads expanded before calling (no native GQA support)
_NEEDS_KV_EXPAND = {"flash", "efficient", "math"}

# -- Shapes ------------------------------------------------------------------

# Qwen3.5 MoE attention geometry: B=1, H_q=16, H_kv=2, D=256
QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256}

# Decode: a single query token (Lq=1) against KV-cache lengths from 64 to 16K.
DECODE_SHAPES = [
    dict(**QWEN35_BASE, Lq=1, Lk=64),
    dict(**QWEN35_BASE, Lq=1, Lk=128),
    dict(**QWEN35_BASE, Lq=1, Lk=256),
    dict(**QWEN35_BASE, Lq=1, Lk=512),
    dict(**QWEN35_BASE, Lq=1, Lk=1024),
    dict(**QWEN35_BASE, Lq=1, Lk=2048),
    dict(**QWEN35_BASE, Lq=1, Lk=4096),
    dict(**QWEN35_BASE, Lq=1, Lk=8192),
    dict(**QWEN35_BASE, Lq=1, Lk=16384),
]

# Scenario name -> list of shape dicts; selected via the --scenario CLI flag.
SCENARIOS = {
    "decode": DECODE_SHAPES,
}
| 120 | + |
| 121 | +# -- Helpers ----------------------------------------------------------------- |
| 122 | + |
| 123 | + |
| 124 | +def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): |
| 125 | + q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype) |
| 126 | + k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) |
| 127 | + v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) |
| 128 | + mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) |
| 129 | + enable_gqa = H_q != H_kv |
| 130 | + num_groups = H_q // H_kv |
| 131 | + # Pre-expanded versions for backends without native GQA |
| 132 | + k_exp, v_exp = _expand_kv(k, v, num_groups) |
| 133 | + mask_exp = _expand_mask(mask, H_q) |
| 134 | + return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa |
| 135 | + |
| 136 | + |
| 137 | +def _max_abs_error(out, ref): |
| 138 | + return (out.float() - ref.float()).abs().max().item() |
| 139 | + |
| 140 | + |
# Cross-backend validation tolerance. Outputs are bf16 computed via
# different kernels, so exact equality is impossible; both sides are upcast
# to fp32 before differencing (see _max_abs_error).
MAX_ABS_TOL = 1e-2
| 143 | + |
| 144 | + |
def _bench_us(fn, num_warmup, num_iters):
    """Return median latency in microseconds using triton.testing.do_bench.

    NOTE(review): triton's do_bench interprets `warmup` and `rep` as target
    durations in *milliseconds*, not iteration counts, while the parameter
    names here suggest counts — confirm the intent (the defaults 25/100
    happen to coincide with do_bench's own defaults).
    """
    # do_bench reports milliseconds; convert to microseconds for the table.
    ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median")
    return ms * 1000.0
| 149 | + |
| 150 | + |
| 151 | +def _try_run(run_fn, q, k, v, mask, enable_gqa): |
| 152 | + """Run a backend, returning output or None on failure.""" |
| 153 | + try: |
| 154 | + return run_fn(q, k, v, mask, enable_gqa) |
| 155 | + except RuntimeError: |
| 156 | + return None |
| 157 | + |
| 158 | + |
def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters):
    """Benchmark a backend; return median microseconds or None on failure.

    A single probe call runs first so configurations the backend rejects
    (RuntimeError) fail fast and report N/A instead of entering the
    timing loop.
    """
    bench_target = partial(run_fn, q, k, v, mask, enable_gqa)
    try:
        run_fn(q, k, v, mask, enable_gqa)
        return _bench_us(bench_target, num_warmup, num_iters)
    except RuntimeError:
        return None
| 167 | + |
| 168 | + |
| 169 | +# -- Main -------------------------------------------------------------------- |
| 170 | + |
| 171 | + |
| 172 | +def _shape_label(shape): |
| 173 | + return ( |
| 174 | + f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " |
| 175 | + f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" |
| 176 | + ) |
| 177 | + |
| 178 | + |
| 179 | +def _short_label(shape, scenario="decode"): |
| 180 | + return f"Lq={shape['Lq']},Lk={shape['Lk']}" |
| 181 | + |
| 182 | + |
@torch.inference_mode()
def run_benchmark(
    scenario: str = "decode",
    num_warmup: int = 25,
    num_iters: int = 100,
):
    """Validate then benchmark every registered backend over one shape set.

    For each shape in SCENARIOS[scenario]: run all backends once, compare
    their outputs against the PyTorch F.sdpa reference (when it succeeded),
    then print one table row of median latencies in microseconds. Backends
    that raise RuntimeError on a configuration are shown as N/A.

    Args:
        scenario: Key into SCENARIOS selecting the shape list.
        num_warmup: Forwarded to _bench_us / do_bench's `warmup`.
        num_iters: Forwarded to _bench_us / do_bench's `rep`.

    Raises:
        AssertionError: If a backend's output diverges from the PyTorch
            reference by more than MAX_ABS_TOL.
    """
    shapes = SCENARIOS[scenario]
    backends = [(name, *BACKENDS[name]) for name in BACKENDS]

    device_name = torch.cuda.get_device_name()
    print()
    print("=" * 100)
    print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
    print(f" Device: {device_name}")
    print(f" Warmup: {num_warmup}, Iters: {num_iters}")
    print(f" Backends: {', '.join(label for _, label, _ in backends)}")
    print("=" * 100)

    # Build column specs: (header_text, unit_text, min_width)
    # Each column gets width = max(len(header), len(unit), min_width)
    max_label = max(len(_short_label(s, scenario)) for s in shapes)
    col_specs = [("Shape", "", max(8, max_label))]
    for _, label, _ in backends:
        col_specs.append((label, "(us)", 8))

    col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs]

    # Column 0 (shape label) is left-aligned; latency columns right-aligned.
    header = " | ".join(
        f"{h:<{w}}" if i == 0 else f"{h:>{w}}"
        for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths))
    )
    units = " | ".join(
        f"{'':>{w}}" if i == 0 else f"{u:>{w}}"
        for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths))
    )
    print(header)
    print(units)
    print("-" * len(header))

    for shape in shapes:
        q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Validate outputs across backends before benchmarking
            outputs = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa)

            # Use PyTorch F.sdpa as the trusted reference — never validate
            # against our own Triton kernels.
            ref_name, ref_out = None, None
            if outputs.get("pytorch") is not None:
                ref_name, ref_out = "pytorch", outputs["pytorch"]

            if ref_out is not None:
                for name, label, _ in backends:
                    if name == ref_name or outputs[name] is None:
                        continue
                    err = _max_abs_error(outputs[name], ref_out)
                    # Explicit raise (not `assert`) so validation survives
                    # python -O; the message reports the actual tolerance
                    # constant instead of a hardcoded "1e-2".
                    if err >= MAX_ABS_TOL:
                        raise AssertionError(
                            f"Output mismatch for {_shape_label(shape)}: "
                            f"{label} vs {BACKENDS[ref_name][0]}, "
                            f"max abs error {err:.3e} >= {MAX_ABS_TOL:g}"
                        )
            del outputs

            # Benchmark all backends
            times = {}
            for name, _label, run_fn in backends:
                if name in _NEEDS_KV_EXPAND:
                    bk, bv, bmask = k_exp, v_exp, mask_exp
                else:
                    bk, bv, bmask = k, v, mask
                times[name] = _try_bench(
                    run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters
                )

            # Format row using col_widths (col 0 is the shape label; the
            # remaining columns line up with `backends` order).
            row_parts = [f"{_short_label(shape, scenario):<{col_widths[0]}}"]
            for (name, _, _), w in zip(backends, col_widths[1:]):
                t = times[name]
                row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}")
            print(" | ".join(row_parts))

        # Free the per-shape tensors before allocating the next (larger) set.
        del q, k, v, k_exp, v_exp, mask, mask_exp
        torch.cuda.empty_cache()

    print("-" * len(header))
    print()
| 282 | + |
| 283 | + |
def main():
    """CLI entry point: parse arguments and run the selected scenario(s)."""
    parser = argparse.ArgumentParser(
        description="Benchmark Triton SDPA vs PyTorch backends"
    )
    parser.add_argument(
        "--scenario",
        choices=list(SCENARIOS.keys()) + ["all"],
        default="all",
        help="Which shape set to benchmark (default: all)",
    )
    parser.add_argument("--num_warmup", type=int, default=25)
    parser.add_argument("--num_iters", type=int, default=100)
    args = parser.parse_args()

    # "all" expands to every registered scenario, in registration order.
    if args.scenario == "all":
        selected = list(SCENARIOS)
    else:
        selected = [args.scenario]

    for name in selected:
        run_benchmark(
            scenario=name,
            num_warmup=args.num_warmup,
            num_iters=args.num_iters,
        )
| 305 | + |
| 306 | + |
# Script entry point: `python <this file> [--scenario decode] [...]`.
if __name__ == "__main__":
    main()