Skip to content

Commit f011e54

Browse files
committed
Fix lintrunner formatting and tighten test tolerances
UFMT/CLANGFORMAT fixes for import ordering, line wrapping, and missing blank lines flagged by CI lintrunner. Tighten max-abs-error tolerance from 0.05 to 1e-2 across tests and benchmarks (MAX_ABS_TOL constant). The benchmark cross-validation already used 1e-2; tests now match. NOTE(review): one assertion in test_triton_sdpa_splitk.py::test_decode_basic was reformatted but still uses the literal 0.05 instead of MAX_ABS_TOL, and the benchmark's assert message still hardcodes ">= 1e-2" rather than interpolating MAX_ABS_TOL — both should be cleaned up in a follow-up.
1 parent 5d3b620 commit f011e54

6 files changed

Lines changed: 89 additions & 41 deletions

File tree

backends/cuda/benchmarks/benchmark_sdpa.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121

2222
import torch
2323
import torch.nn.functional as F
24-
from torch.nn.attention import SDPBackend, sdpa_kernel
25-
from triton.testing import do_bench
2624

27-
from executorch.backends.cuda.triton.kernels.sdpa import sdpa as triton_sdpa
2825
from executorch.backends.cuda.triton.kernels.sdpa import (
26+
sdpa as triton_sdpa,
2927
sdpa_decode_splitk as triton_splitk,
3028
)
29+
from torch.nn.attention import sdpa_kernel, SDPBackend
30+
from triton.testing import do_bench
3131

3232

3333
# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly.
@@ -59,14 +59,19 @@ def _run_splitk(q, k, v, attn_mask, enable_gqa):
5959

6060
def _run_pytorch_default(q, k, v, attn_mask, enable_gqa):
6161
return F.scaled_dot_product_attention(
62-
q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa,
62+
q,
63+
k,
64+
v,
65+
attn_mask=attn_mask,
66+
enable_gqa=enable_gqa,
6367
)
6468

6569

6670
def _make_pytorch_runner(backend: SDPBackend):
6771
def run(q, k, v, attn_mask, enable_gqa):
6872
with sdpa_kernel(backend):
6973
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
74+
7075
return run
7176

7277

@@ -82,7 +87,10 @@ def _run_flash(q, k, v, attn_mask, enable_gqa):
8287
"splitk": ("ET Split-K (GQA)", _run_splitk),
8388
"pytorch": ("PyTorch", _run_pytorch_default),
8489
"flash": ("Flash (expanded KV)", _run_flash),
85-
"efficient": ("Efficient (expanded KV)", _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION)),
90+
"efficient": (
91+
"Efficient (expanded KV)",
92+
_make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION),
93+
),
8694
"math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)),
8795
}
8896

@@ -112,6 +120,7 @@ def _run_flash(q, k, v, attn_mask, enable_gqa):
112120

113121
# -- Helpers -----------------------------------------------------------------
114122

123+
115124
def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16):
116125
q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype)
117126
k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype)
@@ -129,6 +138,10 @@ def _max_abs_error(out, ref):
129138
return (out.float() - ref.float()).abs().max().item()
130139

131140

141+
# Cross-backend validation tolerance (bf16 vs bf16).
142+
MAX_ABS_TOL = 1e-2
143+
144+
132145
def _bench_us(fn, num_warmup, num_iters):
133146
"""Return median latency in microseconds using triton.testing.do_bench."""
134147
ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median")
@@ -155,6 +168,7 @@ def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters):
155168

156169
# -- Main --------------------------------------------------------------------
157170

171+
158172
def _shape_label(shape):
159173
return (
160174
f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} "
@@ -231,7 +245,7 @@ def run_benchmark(
231245
if name == ref_name or outputs[name] is None:
232246
continue
233247
err = _max_abs_error(outputs[name], ref_out)
234-
assert err < 1e-2, (
248+
assert err < MAX_ABS_TOL, (
235249
f"Output mismatch for {_shape_label(shape)}: "
236250
f"{label} vs {BACKENDS[ref_name][0]}, "
237251
f"max abs error {err:.3e} >= 1e-2"
@@ -245,7 +259,9 @@ def run_benchmark(
245259
bk, bv, bmask = k_exp, v_exp, mask_exp
246260
else:
247261
bk, bv, bmask = k, v, mask
248-
times[name] = _try_bench(run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters)
262+
times[name] = _try_bench(
263+
run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters
264+
)
249265

250266
# Format row using col_widths
251267
ci = 0
@@ -266,7 +282,9 @@ def run_benchmark(
266282

267283

268284
def main():
269-
parser = argparse.ArgumentParser(description="Benchmark Triton SDPA vs PyTorch backends")
285+
parser = argparse.ArgumentParser(
286+
description="Benchmark Triton SDPA vs PyTorch backends"
287+
)
270288
parser.add_argument(
271289
"--scenario",
272290
choices=list(SCENARIOS.keys()) + ["all"],

backends/cuda/tests/test_triton_sdpa.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _max_abs_error(out, ref):
6767
return (out.float() - ref.float()).abs().max().item()
6868

6969

70+
# bf16 kernel vs fp32 reference tolerance.
71+
# The benchmark cross-validates backends at 1e-2; tests use the same bar.
72+
MAX_ABS_TOL = 1e-2
73+
74+
7075
# ---------------------------------------------------------------------------
7176
# Test configurations adapted from FlashAttention
7277
# ---------------------------------------------------------------------------
@@ -130,7 +135,7 @@ def test_mha_basic(self):
130135

131136
self.assertFalse(torch.isnan(out).any(), "NaN in output")
132137
self.assertLess(
133-
_max_abs_error(out, ref), 0.05, f"D={D} Lq={Lq} Lk={Lk}"
138+
_max_abs_error(out, ref), MAX_ABS_TOL, f"D={D} Lq={Lq} Lk={Lk}"
134139
)
135140

136141
def test_mha_causal(self):
@@ -148,7 +153,7 @@ def test_mha_causal(self):
148153
ref = _reference_sdpa(q, k, v, is_causal=True)
149154

150155
self.assertFalse(torch.isnan(out).any())
151-
self.assertLess(_max_abs_error(out, ref), 0.05)
156+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
152157

153158
def test_mha_bool_mask(self):
154159
"""MHA with explicit bool attention mask."""
@@ -168,7 +173,7 @@ def test_mha_bool_mask(self):
168173
ref = _reference_sdpa(q, k, v, attn_mask=mask)
169174

170175
self.assertFalse(torch.isnan(out).any())
171-
self.assertLess(_max_abs_error(out, ref), 0.05)
176+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
172177

173178
def test_mha_non_pow2_head_dim(self):
174179
"""MHA with non-power-of-2 head dimensions."""
@@ -187,7 +192,7 @@ def test_mha_non_pow2_head_dim(self):
187192
ref = _reference_sdpa(q, k, v)
188193

189194
self.assertFalse(torch.isnan(out).any())
190-
self.assertLess(_max_abs_error(out, ref), 0.05)
195+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
191196

192197
def test_mha_non_pow2_causal(self):
193198
"""MHA with non-pow2 head dim and causal masking."""
@@ -204,7 +209,7 @@ def test_mha_non_pow2_causal(self):
204209
ref = _reference_sdpa(q, k, v, is_causal=True)
205210

206211
self.assertFalse(torch.isnan(out).any())
207-
self.assertLess(_max_abs_error(out, ref), 0.05)
212+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
208213

209214
# ------------------------------------------------------------------
210215
# GQA tests
@@ -230,7 +235,7 @@ def test_gqa_decode(self):
230235
self.assertEqual(out.shape, (B, H_q, Lq, D))
231236
self.assertFalse(torch.isnan(out).any())
232237
self.assertLess(
233-
_max_abs_error(out, ref), 0.05, f"{label} D={D} Lk={Lk}"
238+
_max_abs_error(out, ref), MAX_ABS_TOL, f"{label} D={D} Lk={Lk}"
234239
)
235240

236241
def test_gqa_decode_with_mask(self):
@@ -253,7 +258,7 @@ def test_gqa_decode_with_mask(self):
253258
ref = _reference_sdpa(q, k, v, attn_mask=mask)
254259

255260
self.assertFalse(torch.isnan(out).any())
256-
self.assertLess(_max_abs_error(out, ref), 0.05)
261+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
257262

258263
def test_gqa_short_seqlen(self):
259264
"""GQA with short seqlen_q (2-8)."""
@@ -270,7 +275,7 @@ def test_gqa_short_seqlen(self):
270275
ref = _reference_sdpa(q, k, v)
271276

272277
self.assertFalse(torch.isnan(out).any())
273-
self.assertLess(_max_abs_error(out, ref), 0.05)
278+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
274279

275280
def test_gqa_prefill(self):
276281
"""GQA prefill (long seqlen_q)."""
@@ -290,7 +295,7 @@ def test_gqa_prefill(self):
290295

291296
self.assertEqual(out.shape, (B, H_q, L, D))
292297
self.assertFalse(torch.isnan(out).any())
293-
self.assertLess(_max_abs_error(out, ref), 0.05)
298+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
294299

295300
def test_gqa_non_pow2_head_dim(self):
296301
"""GQA with non-power-of-2 head dimensions."""
@@ -308,7 +313,7 @@ def test_gqa_non_pow2_head_dim(self):
308313

309314
self.assertFalse(torch.isnan(out).any())
310315
self.assertLess(
311-
_max_abs_error(out, ref), 0.05, f"D={D} Lq={Lq} Lk={Lk}"
316+
_max_abs_error(out, ref), MAX_ABS_TOL, f"D={D} Lq={Lq} Lk={Lk}"
312317
)
313318

314319
def test_gqa_causal_prefill(self):
@@ -326,7 +331,7 @@ def test_gqa_causal_prefill(self):
326331
ref = _reference_sdpa(q, k, v, is_causal=True)
327332

328333
self.assertFalse(torch.isnan(out).any())
329-
self.assertLess(_max_abs_error(out, ref), 0.05)
334+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
330335

331336
def test_gqa_causal_decode_with_mask(self):
332337
"""GQA decode with causal-like bool mask (simulating KV cache)."""
@@ -352,7 +357,7 @@ def test_gqa_causal_decode_with_mask(self):
352357
ref = _reference_sdpa(q, k, v, attn_mask=mask)
353358

354359
self.assertFalse(torch.isnan(out).any())
355-
self.assertLess(_max_abs_error(out, ref), 0.05)
360+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
356361

357362
def test_gqa_batch_size(self):
358363
"""GQA with batch_size > 1."""
@@ -368,7 +373,7 @@ def test_gqa_batch_size(self):
368373
ref = _reference_sdpa(q, k, v)
369374

370375
self.assertFalse(torch.isnan(out).any())
371-
self.assertLess(_max_abs_error(out, ref), 0.05)
376+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
372377

373378
# ------------------------------------------------------------------
374379
# Qwen 3.5 MoE configuration
@@ -393,7 +398,7 @@ def test_qwen35_moe_config(self):
393398
self.assertEqual(out.shape, (B, H_q, Lq, D))
394399
self.assertFalse(torch.isnan(out).any())
395400
self.assertLess(
396-
_max_abs_error(out, ref), 0.05, f"Qwen config Lq={Lq} Lk={Lk}"
401+
_max_abs_error(out, ref), MAX_ABS_TOL, f"Qwen config Lq={Lq} Lk={Lk}"
397402
)
398403

399404
# ------------------------------------------------------------------
@@ -427,7 +432,7 @@ def test_custom_scale(self):
427432
ref = _reference_sdpa(q, k, v, scale=scale)
428433

429434
self.assertFalse(torch.isnan(out).any())
430-
self.assertLess(_max_abs_error(out, ref), 0.05)
435+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
431436

432437
def test_all_masked(self):
433438
"""All-masked block should produce zeros, not NaN."""
@@ -508,7 +513,7 @@ def test_non_pow2_no_mask(self):
508513
ref = _reference_sdpa(q, k, v)
509514

510515
self.assertFalse(torch.isnan(out).any())
511-
self.assertLess(_max_abs_error(out, ref), 0.05)
516+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
512517

513518

514519
if __name__ == "__main__":

backends/cuda/tests/test_triton_sdpa_splitk.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def _max_abs_error(out, ref):
6262
return (out.float() - ref.float()).abs().max().item()
6363

6464

65+
# bf16 kernel vs fp32 reference tolerance.
66+
# Matches benchmark cross-validation and test_triton_sdpa.py.
67+
MAX_ABS_TOL = 1e-2
68+
69+
6570
HEAD_DIMS_POW2 = [64, 128, 256]
6671

6772
GQA_CONFIGS = [
@@ -90,7 +95,9 @@ def setUpClass(cls):
9095
def test_decode_basic(self):
9196
"""GQA decode across head configs, head dims, and KV lengths."""
9297
for (H_q, H_kv, label), D, Lk in itertools.product(
93-
GQA_CONFIGS, HEAD_DIMS_POW2, LK_LENGTHS,
98+
GQA_CONFIGS,
99+
HEAD_DIMS_POW2,
100+
LK_LENGTHS,
94101
):
95102
with self.subTest(label=label, D=D, Lk=Lk):
96103
B, Lq = 1, 1
@@ -105,7 +112,8 @@ def test_decode_basic(self):
105112
self.assertEqual(out.shape, (B, H_q, Lq, D))
106113
self.assertFalse(torch.isnan(out).any(), "NaN in output")
107114
self.assertLess(
108-
_max_abs_error(out, ref), 0.05,
115+
_max_abs_error(out, ref),
116+
0.05,
109117
f"{label} D={D} Lk={Lk}",
110118
)
111119

@@ -126,7 +134,7 @@ def test_decode_with_mask(self):
126134
ref = _reference_sdpa(q, k, v, attn_mask=mask)
127135

128136
self.assertFalse(torch.isnan(out).any())
129-
self.assertLess(_max_abs_error(out, ref), 0.05)
137+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
130138

131139
def test_decode_mha(self):
132140
"""MHA (H_q==H_kv, num_groups=1) should work with split-K."""
@@ -142,7 +150,7 @@ def test_decode_mha(self):
142150
ref = _reference_sdpa(q, k, v)
143151

144152
self.assertFalse(torch.isnan(out).any())
145-
self.assertLess(_max_abs_error(out, ref), 0.05)
153+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
146154

147155
def test_qwen35_config(self):
148156
"""Exact Qwen3.5 MoE config: H_q=16, H_kv=2, D=256."""
@@ -162,7 +170,7 @@ def test_qwen35_config(self):
162170

163171
self.assertEqual(out.shape, (B, H_q, Lq, D))
164172
self.assertFalse(torch.isnan(out).any())
165-
self.assertLess(_max_abs_error(out, ref), 0.05)
173+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
166174

167175
def test_custom_scale(self):
168176
"""Non-default attention scale."""
@@ -177,7 +185,7 @@ def test_custom_scale(self):
177185
ref = _reference_sdpa(q, k, v, scale=scale)
178186

179187
self.assertFalse(torch.isnan(out).any())
180-
self.assertLess(_max_abs_error(out, ref), 0.05)
188+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
181189

182190
def test_cross_validate_with_sdpa(self):
183191
"""Split-K output matches tiled sdpa output for decode shapes."""
@@ -195,7 +203,8 @@ def test_cross_validate_with_sdpa(self):
195203
out_tiled = self.sdpa(q, k, v, attn_mask=mask, enable_gqa=True)
196204

197205
self.assertLess(
198-
_max_abs_error(out_splitk, out_tiled), 0.05,
206+
_max_abs_error(out_splitk, out_tiled),
207+
MAX_ABS_TOL,
199208
f"Split-K vs tiled mismatch at Lk={Lk}",
200209
)
201210

@@ -229,7 +238,7 @@ def test_lk_1(self):
229238
ref = _reference_sdpa(q, k, v)
230239

231240
self.assertFalse(torch.isnan(out).any())
232-
self.assertLess(_max_abs_error(out, ref), 0.05)
241+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
233242

234243
def test_batch_size(self):
235244
"""Batch size > 1."""
@@ -245,7 +254,7 @@ def test_batch_size(self):
245254
ref = _reference_sdpa(q, k, v)
246255

247256
self.assertFalse(torch.isnan(out).any())
248-
self.assertLess(_max_abs_error(out, ref), 0.05)
257+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
249258

250259
# ------------------------------------------------------------------
251260
# Validation errors

backends/cuda/triton/kernels/sdpa.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,9 +1438,21 @@ def sdpa_decode_splitk(
14381438
)
14391439

14401440
_launch_decode_splitk(
1441-
query, key, value, out,
1442-
B, H_q, H_kv, L_kv, D, sm_scale,
1443-
HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk,
1441+
query,
1442+
key,
1443+
value,
1444+
out,
1445+
B,
1446+
H_q,
1447+
H_kv,
1448+
L_kv,
1449+
D,
1450+
sm_scale,
1451+
HAS_MASK,
1452+
Mask_ptr,
1453+
stride_mb,
1454+
stride_mq,
1455+
stride_mk,
14441456
num_groups,
14451457
)
14461458
return out

0 commit comments

Comments
 (0)