Skip to content

Commit 4898af2

Browse files
Metal backend: Add SDPA head_dim=256 support (#18875)
Qwen 3.5 MoE uses head_dim=256 for full attention layers. The existing SDPA Metal kernel only instantiated head_dim 64, 96, 128. At D=256 each thread handles 8 QK elements (8 x 32 threads = 256 dims); register pressure and threadgroup memory are well within Apple GPU limits.
1 parent ad27a45 commit 4898af2

2 files changed

Lines changed: 31 additions & 5 deletions

File tree

backends/apple/metal/runtime/ops/op_sdpa.mm

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@
226226
#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \
227227
INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \
228228
INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \
229-
INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128);
229+
INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); \
230+
INSTANTIATE_SDPA_VECTOR(DTYPE, 256, 256);
230231
231232
INSTANTIATE_SDPA_VECTOR_HEADS(float);
232233
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
@@ -430,11 +431,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
430431
throw std::runtime_error("Unsupported dtype for Metal SDPA kernel");
431432
}
432433

433-
// Select head_dim - must match exactly one of the supported sizes (64, 96, 128)
434+
// Select head_dim - must match exactly one of the supported sizes (64, 96, 128, 256)
434435
int64_t head_dim = headSize;
435-
if (head_dim != 64 && head_dim != 96 && head_dim != 128) {
436-
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim);
437-
throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128");
436+
if (head_dim != 64 && head_dim != 96 && head_dim != 128 && head_dim != 256) {
437+
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, 128, or 256)", head_dim);
438+
throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, 128, or 256");
438439
}
439440

440441
std::string kernel_name = "sdpa_vector_" + type_name + "_" + std::to_string(head_dim) + "_" + std::to_string(head_dim);

backends/apple/metal/tests/test_modules.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,31 @@ def __init__(self):
639639
}
640640

641641

642+
# -------------------------------------------------------------------------
643+
# SDPA with head_dim=256 (Qwen 3.5 MoE)
644+
# -------------------------------------------------------------------------
645+
646+
647+
class SDPAHeadDim256(nn.Module):
648+
"""SDPA with head_dim=256, required by Qwen 3.5 MoE full attention layers."""
649+
650+
def forward(
651+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
652+
) -> torch.Tensor:
653+
return torch.nn.functional.scaled_dot_product_attention(
654+
query, key, value, dropout_p=0.0, is_causal=False
655+
)
656+
657+
658+
MODULE_REGISTRY["sdpa_head_dim_256"] = {
659+
"model_class": SDPAHeadDim256,
660+
"input_shapes": [(1, 4, 8, 256), (1, 4, 8, 256), (1, 4, 8, 256)],
661+
"description": "SDPA with head_dim=256 (Qwen 3.5 MoE)",
662+
"atol_float32": 1e-4,
663+
"atol_bfloat16": 5e-2,
664+
}
665+
666+
642667
# =============================================================================
643668
# Helper Functions
644669
# =============================================================================

0 commit comments

Comments
 (0)