Skip to content

Commit 4c1c3ac

Browse files
ggml-webgpu: only use subgroup-matrix path when head dims are divisible by sg_mat_k / sg_mat_n (ggml-org#23020)
1 parent 7f3f843 commit 4c1c3ac

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -777,15 +777,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
777777
const bool tile_can_dispatch_all_q_rows =
778778
context.max_subgroup_size > 0 &&
779779
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
780-
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
780+
const bool use_subgroup_matrix =
781+
context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
782+
context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0;
783+
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
781784
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
782785
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
783786
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
784787
tile_can_dispatch_all_q_rows && !use_vec;
785788

786789
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
787790
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
788-
context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
791+
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
789792
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
790793

791794
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {

0 commit comments

Comments
 (0)