Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions backends/apple/metal/runtime/ops/op_sdpa.mm
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@
#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \
INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128);
INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 256, 256);

INSTANTIATE_SDPA_VECTOR_HEADS(float);
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
Expand Down Expand Up @@ -430,11 +431,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
throw std::runtime_error("Unsupported dtype for Metal SDPA kernel");
}

// Select head_dim - must match exactly one of the supported sizes (64, 96, 128)
// Select head_dim - must match exactly one of the supported sizes (64, 96, 128, 256)
int64_t head_dim = headSize;
if (head_dim != 64 && head_dim != 96 && head_dim != 128) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim);
throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128");
if (head_dim != 64 && head_dim != 96 && head_dim != 128 && head_dim != 256) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, 128, or 256)", head_dim);
throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, 128, or 256");
}

std::string kernel_name = "sdpa_vector_" + type_name + "_" + std::to_string(head_dim) + "_" + std::to_string(head_dim);
Expand Down
11 changes: 3 additions & 8 deletions backends/apple/metal/runtime/shims/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ extern "C" {
bool is_dtype_supported_in_et_metal(int32_t dtype) {
switch (dtype) {
case static_cast<int32_t>(SupportedDTypes::UINT8):
case static_cast<int32_t>(SupportedDTypes::INT32):
case static_cast<int32_t>(SupportedDTypes::INT64):
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
case static_cast<int32_t>(SupportedDTypes::BOOL):
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
return true;
default:
Expand All @@ -35,14 +37,7 @@ AOTITorchError validate_dtype(int32_t dtype) {
return Error::Ok;
}

ET_LOG(
Error,
"Unsupported dtype: %d. Supported dtypes: %d (uint8), %d (int64), %d (float32), %d (bfloat16)",
dtype,
static_cast<int32_t>(SupportedDTypes::UINT8),
static_cast<int32_t>(SupportedDTypes::INT64),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
ET_LOG(Error, "Unsupported dtype: %d", dtype);
return Error::InvalidArgument;
}

Expand Down
4 changes: 2 additions & 2 deletions backends/apple/metal/runtime/shims/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ enum class SupportedDTypes : int32_t {
UINT8 = 0, // PyTorch's uint8 dtype code
// INT8 = 1, // PyTorch's int8 dtype code
// INT16 = 2, // PyTorch's int16 dtype code
// INT32 = 3, // PyTorch's int32 dtype code
INT32 = 3, // PyTorch's int32 dtype code
INT64 = 4, // PyTorch's int64 dtype code
// FLOAT16 = 5, // PyTorch's float16 dtype code
FLOAT32 = 6, // PyTorch's float32 dtype code
// FLOAT64 = 7, // PyTorch's float64 dtype code
// BOOL = 11, // PyTorch's bool dtype code
BOOL = 11, // PyTorch's bool dtype code
BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
};

Expand Down
25 changes: 25 additions & 0 deletions backends/apple/metal/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,31 @@ def __init__(self):
}


# -------------------------------------------------------------------------
# SDPA with head_dim=256 (Qwen 3.5 MoE)
# -------------------------------------------------------------------------


class SDPAHeadDim256(nn.Module):
    """SDPA with head_dim=256, required by Qwen 3.5 MoE full attention layers.

    The module is a thin wrapper around
    ``torch.nn.functional.scaled_dot_product_attention``. It holds no
    parameters and does not itself enforce the head dimension; the value 256
    comes from the input shapes the test harness feeds it.
    """

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """Run scaled dot-product attention on the given tensors.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.

        Returns:
            The attention output produced by
            ``scaled_dot_product_attention``.
        """
        # No attention mask is supplied; dropout is disabled and causal
        # masking is off, so the call is deterministic for fixed inputs.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )


# Register the head_dim=256 SDPA test case. The three input shapes are
# identical (last dimension 256) and line up with the three tensors that
# SDPAHeadDim256.forward takes.
# NOTE(review): the atol_* entries look like per-dtype absolute tolerances
# used by the test harness - confirm against the code that reads them.
MODULE_REGISTRY["sdpa_head_dim_256"] = dict(
    model_class=SDPAHeadDim256,
    input_shapes=[(1, 4, 8, 256)] * 3,
    description="SDPA with head_dim=256 (Qwen 3.5 MoE)",
    atol_float32=1e-4,
    atol_bfloat16=5e-2,
)


# =============================================================================
# Helper Functions
# =============================================================================
Expand Down
Loading