4949#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
5050#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 1024
5151
52- #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 8
52+ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
5353#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 1024
5454
5555// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
56- #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
56+ #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
5757// Requires at least two (and multiple of 2) k-quant blocks per tile
5858#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 2048
5959
@@ -613,7 +613,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
613613
614614struct ggml_webgpu_mul_mat_vec_shader_decisions {
615615 uint32_t wg_size;
616- uint32_t tile_k;
617616 uint32_t outputs_per_wg;
618617 uint32_t vec_size;
619618};
@@ -1335,17 +1334,10 @@ class ggml_webgpu_shader_lib {
13351334 }
13361335
13371336 webgpu_pipeline get_mul_mat_vec_pipeline (const ggml_webgpu_shader_lib_context & context) {
1338- const bool use_row_tiled =
1339- context.src0 ->type == GGML_TYPE_F32 || context.src0 ->type == GGML_TYPE_F16 || context.src0 ->type == GGML_TYPE_Q4_0 ||
1340- context.src0 ->type == GGML_TYPE_Q4_1 || context.src0 ->type == GGML_TYPE_Q5_0 || context.src0 ->type == GGML_TYPE_Q5_1 ||
1341- context.src0 ->type == GGML_TYPE_Q8_0 || context.src0 ->type == GGML_TYPE_Q8_1 || context.src0 ->type == GGML_TYPE_Q6_K ||
1342- context.src0 ->type == GGML_TYPE_Q4_K || context.src0 ->type == GGML_TYPE_Q5_K || context.src0 ->type == GGML_TYPE_Q3_K ||
1343- context.src0 ->type == GGML_TYPE_Q2_K;
13441337 ggml_webgpu_mul_mat_vec_pipeline_key key = {
13451338 .src0_type = context.src0 ->type ,
13461339 .src1_type = context.src1 ->type ,
13471340 .vectorized = (context.src0 ->ne [0 ] % 4 == 0 &&
1348- (use_row_tiled || context.dst ->ne [0 ] % 4 == 0 ) &&
13491341 (context.src0 ->type == GGML_TYPE_F32 || context.src0 ->type == GGML_TYPE_F16)) ?
13501342 1 :
13511343 0
@@ -1357,8 +1349,8 @@ class ggml_webgpu_shader_lib {
13571349 }
13581350
13591351 std::vector<std::string> defines;
1360- std::string variant = use_row_tiled ? " mul_mat_vec_row_tiled " : " mul_mat_vec" ;
1361- const char * shader_src = use_row_tiled ? wgsl_mul_mat_vec_row_tiled : wgsl_mul_mat_vec;
1352+ std::string variant = " mul_mat_vec" ;
1353+ const char * shader_src = wgsl_mul_mat_vec;
13621354
13631355 // src0 type (matrix row)
13641356 switch (context.src0 ->type ) {
@@ -1407,33 +1399,25 @@ class ggml_webgpu_shader_lib {
14071399 defines.push_back (key.vectorized ? " VEC" : " SCALAR" );
14081400
14091401 uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1410- uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
14111402 uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
14121403
14131404 if (key.src0_type >= GGML_TYPE_Q2_K) {
1414- tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
14151405 outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
14161406 } else if (key.src0_type >= GGML_TYPE_Q4_0) {
1417- tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
14181407 outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
14191408 }
14201409
14211410 defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
14221411 defines.push_back (std::string (" OUTPUTS_PER_WG=" ) + std::to_string (outputs_per_wg));
1423- if (use_row_tiled) {
1424- defines.push_back (context.supports_subgroups ? " USE_SUBGROUP_REDUCTION" : " USE_WORKGROUP_REDUCTION" );
1425- variant += context.supports_subgroups ? " _sg_reduce" : " _wg_reduce" ;
1426- } else {
1427- defines.push_back (std::string (" TILE_K=" ) + std::to_string (tile_k));
1428- }
1412+ defines.push_back (context.supports_subgroups ? " USE_SUBGROUP_REDUCTION" : " USE_WORKGROUP_REDUCTION" );
1413+ variant += context.supports_subgroups ? " _sg_reduce" : " _wg_reduce" ;
14291414 if (key.vectorized ) {
14301415 variant += " _vectorized" ;
14311416 }
14321417
14331418 auto processed = preprocessor.preprocess (shader_src, defines);
14341419 auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
14351420 decisions->wg_size = wg_size;
1436- decisions->tile_k = tile_k;
14371421 decisions->outputs_per_wg = outputs_per_wg;
14381422 decisions->vec_size = key.vectorized ? 4 : 1 ;
14391423
0 commit comments