Skip to content

Commit ff207ea

Browse files
committed
Add split-K decode SDPA dispatch to Qwen3.5 MoE attention
Dual-method export (decode T=1, prefill T>=2) lets the model use a simple if/else on T instead of torch.cond, eliminating the GPU-to-CPU sync overhead that torch.cond's predicate evaluation requires. Decode calls sdpa_decode_splitk (split-K flash-decoding for high KV occupancy), prefill calls tiled sdpa. Guard sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing with symbolic shapes doesn't trip the L_q==1 check. Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal, enable_gqa) for consistent API; unsupported args fail with clear messages. This PR was authored with the assistance of Claude
1 parent f2bcffd commit ff207ea

2 files changed

Lines changed: 61 additions & 19 deletions

File tree

backends/cuda/triton/kernels/sdpa.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,39 +1390,67 @@ def sdpa_decode_splitk(
13901390
key: torch.Tensor,
13911391
value: torch.Tensor,
13921392
attn_mask: Optional[torch.Tensor] = None,
1393+
dropout_p: float = 0.0,
1394+
is_causal: bool = False,
13931395
scale: float = 0.0,
1396+
enable_gqa: bool = False,
13941397
) -> torch.Tensor:
1398+
"""Split-K flash-decoding SDPA for L_q=1 (decode step).
1399+
1400+
Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
1401+
enable_gqa is accepted but ignored — GQA is handled natively via
1402+
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
1403+
"""
1404+
_validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa)
1405+
13951406
B, H_q, L_q, D = query.shape
13961407
_, H_kv, L_kv, _ = key.shape
13971408

1409+
out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
1410+
13981411
# is_causal is a no-op at L_q=1 (single query can't attend to future
13991412
# positions), so we accept it silently for API compatibility with callers
14001413
# that always pass is_causal=True for decode.
14011414

1402-
if L_q != 1:
1403-
raise RuntimeError(
1404-
f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
1405-
)
1406-
if H_q % H_kv != 0:
1407-
raise RuntimeError(
1408-
f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
1409-
)
1410-
if not _is_power_of_2(D):
1411-
raise RuntimeError(
1412-
f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
1413-
)
1415+
# Validation — only check at runtime (concrete shapes), not during AOTI
1416+
# tracing where shapes are symbolic. torch.cond traces both branches with
1417+
# the same symbolic L_q, so L_q is not necessarily 1 during tracing.
1418+
if isinstance(L_q, int):
1419+
if L_q != 1:
1420+
raise RuntimeError(
1421+
f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
1422+
)
1423+
if H_q % H_kv != 0:
1424+
raise RuntimeError(
1425+
f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
1426+
)
1427+
if not _is_power_of_2(D):
1428+
raise RuntimeError(
1429+
f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
1430+
)
14141431

14151432
num_groups = H_q // H_kv
1416-
out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
14171433
sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale
14181434
HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params(
14191435
attn_mask, B, L_q, L_kv
14201436
)
14211437

14221438
_launch_decode_splitk(
1423-
query, key, value, out,
1424-
B, H_q, H_kv, L_kv, D, sm_scale,
1425-
HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk,
1439+
query,
1440+
key,
1441+
value,
1442+
out,
1443+
B,
1444+
H_q,
1445+
H_kv,
1446+
L_kv,
1447+
D,
1448+
sm_scale,
1449+
HAS_MASK,
1450+
Mask_ptr,
1451+
stride_mb,
1452+
stride_mq,
1453+
stride_mk,
14261454
num_groups,
14271455
)
14281456
return out
@@ -1434,7 +1462,10 @@ def _sdpa_decode_splitk_abstract(
14341462
key: torch.Tensor,
14351463
value: torch.Tensor,
14361464
attn_mask: Optional[torch.Tensor] = None,
1465+
dropout_p: float = 0.0,
1466+
is_causal: bool = False,
14371467
scale: float = 0.0,
1468+
enable_gqa: bool = False,
14381469
) -> torch.Tensor:
14391470
assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
14401471
B, H_q, L_q, D = query.shape

examples/models/qwen3_5_moe/model.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import torch
2222
import torch.nn as nn
23+
2324
from torch.nn import functional as F
2425

2526

@@ -285,9 +286,19 @@ def forward(self, x, input_pos):
285286
)
286287
else:
287288
k, v = self.kv_cache.update(input_pos, k, v)
288-
y = F.scaled_dot_product_attention(
289-
q, k, v, attn_mask=attn_mask, enable_gqa=True
290-
)
289+
# The export produces two methods — decode (T=1, static) and
290+
# prefill (T>=2, dynamic). Each traces only one branch, so no
291+
# torch.cond is needed and we avoid GPU→CPU sync overhead.
292+
if T == 1:
293+
from executorch.backends.cuda.triton.kernels.sdpa import (
294+
sdpa_decode_splitk,
295+
)
296+
297+
y = sdpa_decode_splitk(q, k, v, attn_mask=attn_mask)
298+
else:
299+
from executorch.backends.cuda.triton.kernels.sdpa import sdpa
300+
301+
y = sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=True)
291302

292303
y = y.transpose(1, 2).contiguous().view(B, T, -1)
293304

0 commit comments

Comments (0)