Skip to content

Commit c9ccaba

Browse files
Enable fused SDPA vector kernel for asymmetric Q/V head dims (192, 128) (#3637)
Co-authored-by: Cheng <git@zcbenz.com>
1 parent 1bf65e3 commit c9ccaba

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

mlx/backend/metal/kernels/scaled_dot_product_attention.metal

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ using namespace metal;
3232
instantiate_sdpa_vector(type, 64, 64) \
3333
instantiate_sdpa_vector(type, 96, 96) \
3434
instantiate_sdpa_vector(type, 128, 128) \
35+
instantiate_sdpa_vector(type, 192, 128) \
3536
instantiate_sdpa_vector(type, 256, 256) \
3637
instantiate_sdpa_vector_aggregation(type, 64) \
3738
instantiate_sdpa_vector_aggregation(type, 96) \

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -619,9 +619,10 @@ bool ScaledDotProductAttention::use_fallback(
619619
const int gqa_factor = num_query_heads / num_kv_heads;
620620

621621
const bool sdpa_vector_supported_head_dim =
622-
query_head_dim == value_head_dim &&
623-
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
624-
query_head_dim == 256);
622+
(query_head_dim == value_head_dim &&
623+
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
624+
query_head_dim == 256)) ||
625+
(query_head_dim == 192 && value_head_dim == 128);
625626
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
626627
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
627628

0 commit comments

Comments (0)