Skip to content

Commit 72dd6a6

Browse files
committed
Add split-K decode SDPA dispatch to Qwen3.5 MoE attention
Dual-method export (decode T=1, prefill T>=2) lets the model use a simple if/else on T instead of torch.cond, eliminating the GPU-to-CPU sync overhead that torch.cond's predicate evaluation requires. Decode calls sdpa_decode_splitk (split-K flash-decoding for high KV occupancy), prefill calls tiled sdpa. Guard sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing with symbolic shapes doesn't trip the L_q==1 check. Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal, enable_gqa) for consistent API; unsupported args fail with clear messages.
1 parent 151692c commit 72dd6a6

2 files changed

Lines changed: 60 additions & 17 deletions

File tree

backends/cuda/triton/kernels/sdpa.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,35 +1390,69 @@ def sdpa_decode_splitk(
13901390
key: torch.Tensor,
13911391
value: torch.Tensor,
13921392
attn_mask: Optional[torch.Tensor] = None,
1393+
dropout_p: float = 0.0,
1394+
is_causal: bool = False,
13931395
scale: float = 0.0,
1396+
enable_gqa: bool = False,
13941397
) -> torch.Tensor:
1398+
"""Split-K flash-decoding SDPA for L_q=1 (decode step).
1399+
1400+
Signature mirrors sdpa() for drop-in use at the decode/prefill dispatch site.
1401+
enable_gqa is accepted but ignored — GQA is handled natively via
1402+
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
1403+
"""
1404+
_validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa)
1405+
13951406
B, H_q, L_q, D = query.shape
13961407
_, H_kv, L_kv, _ = key.shape
13971408

1398-
if L_q != 1:
1399-
raise RuntimeError(
1400-
f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
1401-
)
1402-
if H_q % H_kv != 0:
1403-
raise RuntimeError(
1404-
f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
1405-
)
1406-
if not _is_power_of_2(D):
1409+
out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
1410+
1411+
if is_causal:
14071412
raise RuntimeError(
1408-
f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
1413+
"sdpa_decode_splitk does not support is_causal=True "
1414+
"(causal masking is a no-op at L_q=1; pass attn_mask instead)"
14091415
)
14101416

1417+
# Validation — only check at runtime (concrete shapes), not during AOTI
1418+
# tracing where shapes are symbolic. The exported methods trace with a
1419+
# symbolic L_q, so L_q is not necessarily a concrete int equal to 1 here.
1420+
if isinstance(L_q, int):
1421+
if L_q != 1:
1422+
raise RuntimeError(
1423+
f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
1424+
)
1425+
if H_q % H_kv != 0:
1426+
raise RuntimeError(
1427+
f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
1428+
)
1429+
if not _is_power_of_2(D):
1430+
raise RuntimeError(
1431+
f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
1432+
)
1433+
14111434
num_groups = H_q // H_kv
1412-
out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
14131435
sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale
14141436
HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params(
14151437
attn_mask, B, L_q, L_kv
14161438
)
14171439

14181440
_launch_decode_splitk(
1419-
query, key, value, out,
1420-
B, H_q, H_kv, L_kv, D, sm_scale,
1421-
HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk,
1441+
query,
1442+
key,
1443+
value,
1444+
out,
1445+
B,
1446+
H_q,
1447+
H_kv,
1448+
L_kv,
1449+
D,
1450+
sm_scale,
1451+
HAS_MASK,
1452+
Mask_ptr,
1453+
stride_mb,
1454+
stride_mq,
1455+
stride_mk,
14221456
num_groups,
14231457
)
14241458
return out
@@ -1430,7 +1464,10 @@ def _sdpa_decode_splitk_abstract(
14301464
key: torch.Tensor,
14311465
value: torch.Tensor,
14321466
attn_mask: Optional[torch.Tensor] = None,
1467+
dropout_p: float = 0.0,
1468+
is_causal: bool = False,
14331469
scale: float = 0.0,
1470+
enable_gqa: bool = False,
14341471
) -> torch.Tensor:
14351472
assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
14361473
B, H_q, L_q, D = query.shape

examples/models/qwen3_5_moe/model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import torch
2222
import torch.nn as nn
23+
24+
from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
2325
from torch.nn import functional as F
2426

2527

@@ -285,9 +287,13 @@ def forward(self, x, input_pos):
285287
)
286288
else:
287289
k, v = self.kv_cache.update(input_pos, k, v)
288-
y = F.scaled_dot_product_attention(
289-
q, k, v, attn_mask=attn_mask, enable_gqa=True
290-
)
290+
# The export produces two methods — decode (T=1, static) and
291+
# prefill (T>=2, dynamic). Each traces only one branch, so no
292+
# torch.cond is needed and we avoid GPU→CPU sync overhead.
293+
if T == 1:
294+
y = sdpa_decode_splitk(q, k, v, attn_mask=attn_mask)
295+
else:
296+
y = sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=True)
291297

292298
y = y.transpose(1, 2).contiguous().view(B, T, -1)
293299

0 commit comments

Comments
 (0)