File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -777,15 +777,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
777777 const bool tile_can_dispatch_all_q_rows =
778778 context.max_subgroup_size > 0 &&
779779 context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size ;
780- const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
780+ const bool use_subgroup_matrix =
781+ context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
782+ context.src0 ->ne [0 ] % context.sg_mat_k == 0 && context.src2 ->ne [0 ] % context.sg_mat_n == 0 ;
783+ const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
781784 V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
782785 (context.src0 ->ne [0 ] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0 ) &&
783786 (context.src2 ->ne [0 ] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0 ) &&
784787 tile_can_dispatch_all_q_rows && !use_vec;
785788
786789 decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
787790 use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
788- context. supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
791+ use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
789792 GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
790793
791794 if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
You can’t perform that action at this time.
0 commit comments