Skip to content

Commit 40e361b

Browse files
committed
Add split-K decode SDPA as standalone triton_op
Register `triton::sdpa_decode_splitk` as an independent op so AOTI can trace and compile it without the runtime L_kv conditional that prevents the split-K path from appearing in the standard `sdpa` op. The split-K (flash-decoding) approach partitions the KV sequence across CTAs and reduces the partial softmax results in a second kernel. The benchmark script now includes a split-K column for comparison.

BLOCK_G (the GQA group tile) uses _next_power_of_2_unclamped() to avoid inflating small group counts to 16. Phantom rows from over-sized tiles change register pressure and instruction scheduling, altering the fp32 accumulation order enough to degrade output quality over long autoregressive sequences.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16, H_kv=2, D=256, bf16):

| Lk    | ET Tiled (us) | ET Split-K (us) | Speedup |
|-------|---------------|-----------------|---------|
| 64    | 131.8         | 259.5           | 0.5x    |
| 512   | 98.9          | 221.5           | 0.4x    |
| 4096  | 199.9         | 214.4           | 0.9x    |
| 8192  | 392.2         | 211.3           | 1.9x    |
| 16384 | 775.3         | 211.8           | 3.7x    |

Split-K breaks even around Lk=4096 and dominates at longer sequences, where the tiled kernel's single-CTA-per-head bottleneck becomes severe.

This PR was authored with the assistance of Claude
1 parent 8e2c488 commit 40e361b

5 files changed

Lines changed: 750 additions & 19 deletions

File tree

backends/cuda/benchmarks/benchmark_sdpa.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import torch
2323
import torch.nn.functional as F
2424

25-
from executorch.backends.cuda.triton.kernels.sdpa import sdpa as triton_sdpa
25+
from executorch.backends.cuda.triton.kernels.sdpa import (
26+
sdpa as triton_sdpa,
27+
sdpa_decode_splitk as triton_splitk,
28+
)
2629
from torch.nn.attention import sdpa_kernel, SDPBackend
2730
from triton.testing import do_bench
2831

@@ -50,6 +53,10 @@ def _run_triton(q, k, v, attn_mask, enable_gqa):
5053
return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
5154

5255

56+
def _run_splitk(q, k, v, attn_mask, enable_gqa):
57+
return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
58+
59+
5360
def _run_pytorch_default(q, k, v, attn_mask, enable_gqa):
5461
return F.scaled_dot_product_attention(
5562
q,
@@ -77,6 +84,7 @@ def _run_flash(q, k, v, attn_mask, enable_gqa):
7784

7885
BACKENDS = {
7986
"triton": ("ET Triton (GQA)", _run_triton),
87+
"splitk": ("ET Split-K (GQA)", _run_splitk),
8088
"pytorch": ("PyTorch", _run_pytorch_default),
8189
"flash": ("Flash (expanded KV)", _run_flash),
8290
"efficient": (

backends/cuda/tests/test_triton_sdpa.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _max_abs_error(out, ref):
6767
return (out.float() - ref.float()).abs().max().item()
6868

6969

70+
# bf16 kernel vs fp32 reference tolerance.
71+
# The benchmark cross-validates backends at 1e-2; tests use the same bar.
72+
MAX_ABS_TOL = 1e-2
73+
74+
7075
# ---------------------------------------------------------------------------
7176
# Test configurations adapted from FlashAttention
7277
# ---------------------------------------------------------------------------
@@ -130,7 +135,7 @@ def test_mha_basic(self):
130135

131136
self.assertFalse(torch.isnan(out).any(), "NaN in output")
132137
self.assertLess(
133-
_max_abs_error(out, ref), 0.05, f"D={D} Lq={Lq} Lk={Lk}"
138+
_max_abs_error(out, ref), MAX_ABS_TOL, f"D={D} Lq={Lq} Lk={Lk}"
134139
)
135140

136141
def test_mha_causal(self):
@@ -148,7 +153,7 @@ def test_mha_causal(self):
148153
ref = _reference_sdpa(q, k, v, is_causal=True)
149154

150155
self.assertFalse(torch.isnan(out).any())
151-
self.assertLess(_max_abs_error(out, ref), 0.05)
156+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
152157

153158
def test_mha_bool_mask(self):
154159
"""MHA with explicit bool attention mask."""
@@ -168,7 +173,7 @@ def test_mha_bool_mask(self):
168173
ref = _reference_sdpa(q, k, v, attn_mask=mask)
169174

170175
self.assertFalse(torch.isnan(out).any())
171-
self.assertLess(_max_abs_error(out, ref), 0.05)
176+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
172177

173178
def test_mha_non_pow2_head_dim(self):
174179
"""MHA with non-power-of-2 head dimensions."""
@@ -187,7 +192,7 @@ def test_mha_non_pow2_head_dim(self):
187192
ref = _reference_sdpa(q, k, v)
188193

189194
self.assertFalse(torch.isnan(out).any())
190-
self.assertLess(_max_abs_error(out, ref), 0.05)
195+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
191196

192197
def test_mha_non_pow2_causal(self):
193198
"""MHA with non-pow2 head dim and causal masking."""
@@ -204,7 +209,7 @@ def test_mha_non_pow2_causal(self):
204209
ref = _reference_sdpa(q, k, v, is_causal=True)
205210

206211
self.assertFalse(torch.isnan(out).any())
207-
self.assertLess(_max_abs_error(out, ref), 0.05)
212+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
208213

209214
# ------------------------------------------------------------------
210215
# GQA tests
@@ -230,7 +235,7 @@ def test_gqa_decode(self):
230235
self.assertEqual(out.shape, (B, H_q, Lq, D))
231236
self.assertFalse(torch.isnan(out).any())
232237
self.assertLess(
233-
_max_abs_error(out, ref), 0.05, f"{label} D={D} Lk={Lk}"
238+
_max_abs_error(out, ref), MAX_ABS_TOL, f"{label} D={D} Lk={Lk}"
234239
)
235240

236241
def test_gqa_decode_with_mask(self):
@@ -253,7 +258,7 @@ def test_gqa_decode_with_mask(self):
253258
ref = _reference_sdpa(q, k, v, attn_mask=mask)
254259

255260
self.assertFalse(torch.isnan(out).any())
256-
self.assertLess(_max_abs_error(out, ref), 0.05)
261+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
257262

258263
def test_gqa_short_seqlen(self):
259264
"""GQA with short seqlen_q (2-8)."""
@@ -270,7 +275,7 @@ def test_gqa_short_seqlen(self):
270275
ref = _reference_sdpa(q, k, v)
271276

272277
self.assertFalse(torch.isnan(out).any())
273-
self.assertLess(_max_abs_error(out, ref), 0.05)
278+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
274279

275280
def test_gqa_prefill(self):
276281
"""GQA prefill (long seqlen_q)."""
@@ -290,7 +295,7 @@ def test_gqa_prefill(self):
290295

291296
self.assertEqual(out.shape, (B, H_q, L, D))
292297
self.assertFalse(torch.isnan(out).any())
293-
self.assertLess(_max_abs_error(out, ref), 0.05)
298+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
294299

295300
def test_gqa_non_pow2_head_dim(self):
296301
"""GQA with non-power-of-2 head dimensions."""
@@ -308,7 +313,7 @@ def test_gqa_non_pow2_head_dim(self):
308313

309314
self.assertFalse(torch.isnan(out).any())
310315
self.assertLess(
311-
_max_abs_error(out, ref), 0.05, f"D={D} Lq={Lq} Lk={Lk}"
316+
_max_abs_error(out, ref), MAX_ABS_TOL, f"D={D} Lq={Lq} Lk={Lk}"
312317
)
313318

314319
def test_gqa_causal_prefill(self):
@@ -326,7 +331,7 @@ def test_gqa_causal_prefill(self):
326331
ref = _reference_sdpa(q, k, v, is_causal=True)
327332

328333
self.assertFalse(torch.isnan(out).any())
329-
self.assertLess(_max_abs_error(out, ref), 0.05)
334+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
330335

331336
def test_gqa_causal_decode_with_mask(self):
332337
"""GQA decode with causal-like bool mask (simulating KV cache)."""
@@ -352,7 +357,7 @@ def test_gqa_causal_decode_with_mask(self):
352357
ref = _reference_sdpa(q, k, v, attn_mask=mask)
353358

354359
self.assertFalse(torch.isnan(out).any())
355-
self.assertLess(_max_abs_error(out, ref), 0.05)
360+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
356361

357362
def test_gqa_batch_size(self):
358363
"""GQA with batch_size > 1."""
@@ -368,7 +373,7 @@ def test_gqa_batch_size(self):
368373
ref = _reference_sdpa(q, k, v)
369374

370375
self.assertFalse(torch.isnan(out).any())
371-
self.assertLess(_max_abs_error(out, ref), 0.05)
376+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
372377

373378
# ------------------------------------------------------------------
374379
# Qwen 3.5 MoE configuration
@@ -393,7 +398,7 @@ def test_qwen35_moe_config(self):
393398
self.assertEqual(out.shape, (B, H_q, Lq, D))
394399
self.assertFalse(torch.isnan(out).any())
395400
self.assertLess(
396-
_max_abs_error(out, ref), 0.05, f"Qwen config Lq={Lq} Lk={Lk}"
401+
_max_abs_error(out, ref), MAX_ABS_TOL, f"Qwen config Lq={Lq} Lk={Lk}"
397402
)
398403

399404
# ------------------------------------------------------------------
@@ -427,7 +432,7 @@ def test_custom_scale(self):
427432
ref = _reference_sdpa(q, k, v, scale=scale)
428433

429434
self.assertFalse(torch.isnan(out).any())
430-
self.assertLess(_max_abs_error(out, ref), 0.05)
435+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
431436

432437
def test_all_masked(self):
433438
"""All-masked block should produce zeros, not NaN."""
@@ -508,7 +513,7 @@ def test_non_pow2_no_mask(self):
508513
ref = _reference_sdpa(q, k, v)
509514

510515
self.assertFalse(torch.isnan(out).any())
511-
self.assertLess(_max_abs_error(out, ref), 0.05)
516+
self.assertLess(_max_abs_error(out, ref), MAX_ABS_TOL)
512517

513518

514519
if __name__ == "__main__":

0 commit comments

Comments
 (0)