Commit ebe61e8

digantdesai authored and committed
Add torch.cond split-K decode dispatch to Qwen3.5 MoE attention
Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding for decode (L_q==1) and standard tiled SDPA for prefill (L_q>1). Guard sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing with symbolic shapes doesn't trip the L_q==1 check. Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal, enable_gqa) for drop-in use with torch.cond; unsupported args fail with clear messages.
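Why the isinstance guard works, as a standalone sketch (not part of the diff; the tensor shapes are illustrative): in eager mode query.shape[2] is a plain Python int, but when torch.export / AOTI traces a dynamic sequence dimension it is a torch.SymInt, which is not an int subclass, so the decode-only checks are skipped during tracing and still run on concrete eager inputs.

    import torch

    def check_decode_shape(query: torch.Tensor) -> None:
        # Mirrors the guard in sdpa_decode_splitk: in eager mode L_q is a
        # plain Python int and the check runs; under torch.export / AOTI
        # tracing with a dynamic sequence dim, L_q is a torch.SymInt (not
        # an int subclass), so isinstance(L_q, int) is False and the
        # L_q == 1 check never fires while torch.cond traces both branches.
        L_q = query.shape[2]
        if isinstance(L_q, int) and L_q != 1:
            raise RuntimeError(f"decode kernel requires L_q == 1; got L_q={L_q}")

    check_decode_shape(torch.randn(1, 8, 1, 64))    # ok: L_q == 1
    # check_decode_shape(torch.randn(1, 8, 17, 64))  # raises in eager mode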
1 parent 35c7a18 commit ebe61e8

2 files changed

Lines changed: 46 additions & 12 deletions

backends/cuda/triton/kernels/sdpa.py

Lines changed: 36 additions & 9 deletions
@@ -1372,26 +1372,50 @@ def sdpa_decode_splitk(
     key: torch.Tensor,
     value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
     scale: float = 0.0,
+    enable_gqa: bool = False,
 ) -> torch.Tensor:
+    """Split-K flash-decoding SDPA for L_q=1 (decode step).
+
+    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
+    enable_gqa is accepted but ignored — GQA is handled natively via
+    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
+    """
     B, H_q, L_q, D = query.shape
     _, H_kv, L_kv, _ = key.shape
 
-    if L_q != 1:
-        raise RuntimeError(
-            f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
-        )
-    if H_q % H_kv != 0:
+    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
+
+    if dropout_p != 0.0:
         raise RuntimeError(
-            f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
+            f"sdpa_decode_splitk does not support dropout; got dropout_p={dropout_p}"
         )
-    if not _is_power_of_2(D):
+    if is_causal:
         raise RuntimeError(
-            f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
+            "sdpa_decode_splitk does not support is_causal=True "
+            "(causal masking is a no-op at L_q=1; pass attn_mask instead)"
         )
 
+    # Validation — only check at runtime (concrete shapes), not during AOTI
+    # tracing where shapes are symbolic. torch.cond traces both branches with
+    # the same symbolic L_q, so L_q is not necessarily 1 during tracing.
+    if isinstance(L_q, int):
+        if L_q != 1:
+            raise RuntimeError(
+                f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
+            )
+        if H_q % H_kv != 0:
+            raise RuntimeError(
+                f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
+            )
+        if not _is_power_of_2(D):
+            raise RuntimeError(
+                f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
+            )
+
     num_groups = H_q // H_kv
-    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
     sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale
     HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params(
         attn_mask, B, L_q, L_kv

@@ -1412,7 +1436,10 @@ def _sdpa_decode_splitk_abstract(
     key: torch.Tensor,
     value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
     scale: float = 0.0,
+    enable_gqa: bool = False,
 ) -> torch.Tensor:
     assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
     B, H_q, L_q, D = query.shape
examples/models/qwen3_5_moe/model.py

Lines changed: 10 additions & 3 deletions
@@ -22,6 +22,8 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
+from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
+
 
 # ---------------------------------------------------------------------------
 # Config
@@ -267,10 +269,15 @@ def forward(self, x, input_pos):
         # KV cache
         k, v = self.kv_cache.update(input_pos, k, v)
 
-        # SDPA with GQA — kernel maps Q heads to KV heads internally
+        # SDPA with GQA — runtime dispatch via torch.cond:
+        #   decode (L_q==1): split-K flash-decoding for high KV occupancy
+        #   prefill (L_q>1): standard tiled SDPA (m32/m64)
         attn_mask = self.mask[input_pos].unsqueeze(0).unsqueeze(0)
-        y = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, enable_gqa=True
+        y = torch.cond(
+            q.shape[2] == 1,
+            lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask),
+            lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True),
+            [q, k, v, attn_mask],
         )
 
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
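How the dispatch plays out across a generation step, as a hypothetical driver (the model handle, vocab size, and prompt length are illustrative, not from the diff):

    import torch

    def greedy_step(model: torch.nn.Module, device: str = "cuda") -> None:
        # Prefill: q.shape[2] == 17 > 1, so torch.cond takes the tiled sdpa branch.
        tokens = torch.randint(0, 32000, (1, 17), device=device)  # vocab size illustrative
        logits = model(tokens, torch.arange(17, device=device))

        # Decode: q.shape[2] == 1, so torch.cond takes the sdpa_decode_splitk
        # branch; the same program serves both phases, chosen per call.
        next_tok = logits[:, -1:].argmax(-1)
        model(next_tok, torch.tensor([17], device=device))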
