2626// Matrix multiplication parameters
2727
2828// Register tiling parameters
29- #define WEBGPU_MUL_MAT_TILE_M 8
30- #define WEBGPU_MUL_MAT_TILE_N 8
29+ #define WEBGPU_MUL_MAT_TILE_M 4
30+ #define WEBGPU_MUL_MAT_TILE_N 4
3131#define WEBGPU_MUL_MAT_WG_SIZE_M 8
3232#define WEBGPU_MUL_MAT_WG_SIZE_N 8
33- #define WEBGPU_MUL_MAT_TILE_K 32
33+ #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
34+ #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
3435
3536// Subgroup matrix parameters
3637// The number of subgroups in the M dimension
3738#define WEBGPU_MUL_MAT_SUBGROUP_M 2
3839// The number of subgroups in the N dimension
39- #define WEBGPU_MUL_MAT_SUBGROUP_N 2
40+ #define WEBGPU_MUL_MAT_SUBGROUP_N 4
4041// The number of subgroup matrices each subgroup accumulates over
4142#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
4243#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
44+ #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
45+ #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
4346
4447// Matrix-vector multiplication parameters
4548#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
@@ -1734,13 +1737,24 @@ class ggml_webgpu_shader_lib {
17341737 // VEC/SCALAR controls
17351738 defines.push_back (key.vectorized ? " VEC" : " SCALAR" );
17361739
1740+ const bool is_quant = ggml_is_quantized (context.src0 ->type );
1741+
1742+ uint32_t tile_k;
1743+ if (key.use_subgroup_matrix ) {
1744+ tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
1745+ : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
1746+ } else {
1747+ tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
1748+ : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
1749+ }
1750+
17371751 // Tiles
17381752 defines.push_back (" TILE_M=" + std::to_string (WEBGPU_MUL_MAT_TILE_M) + " u" );
17391753 defines.push_back (" TILE_N=" + std::to_string (WEBGPU_MUL_MAT_TILE_N) + " u" );
1740- defines.push_back (" TILE_K=" + std::to_string (WEBGPU_MUL_MAT_TILE_K) + " u" );
17411754
17421755 // Subgroup matrix specifics
17431756 if (key.use_subgroup_matrix ) {
1757+ defines.push_back (" TILE_K=" + std::to_string (tile_k) + " u" );
17441758 defines.push_back (" MAX_SUBGROUP_SIZE=" + std::to_string (context.max_subgroup_size ) + " u" );
17451759 defines.push_back (" SUBGROUP_M=" + std::to_string (WEBGPU_MUL_MAT_SUBGROUP_M) + " u" );
17461760 defines.push_back (" SUBGROUP_N=" + std::to_string (WEBGPU_MUL_MAT_SUBGROUP_N) + " u" );
@@ -1760,12 +1774,13 @@ class ggml_webgpu_shader_lib {
17601774 if (!key.use_subgroup_matrix ) {
17611775 defines.push_back (" WORKGROUP_SIZE_M=" + std::to_string (WEBGPU_MUL_MAT_WG_SIZE_M) + " u" );
17621776 defines.push_back (" WORKGROUP_SIZE_N=" + std::to_string (WEBGPU_MUL_MAT_WG_SIZE_N) + " u" );
1777+ defines.push_back (" TILE_K=" + std::to_string (tile_k) + " u" );
17631778 }
17641779
17651780 auto processed = preprocessor.preprocess (shader_src, defines);
17661781
17671782 auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
1768- decisions->tile_k = WEBGPU_MUL_MAT_TILE_K ;
1783+ decisions->tile_k = tile_k ;
17691784 decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
17701785 decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
17711786 decisions->use_subgroup_matrix = key.use_subgroup_matrix ;
@@ -1962,10 +1977,15 @@ class ggml_webgpu_shader_lib {
19621977
19631978 defines.push_back (" SCALAR" );
19641979
1980+ // mul_mat_id is register-tile only.
1981+ const uint32_t tile_k = ggml_is_quantized (context.src0 ->type )
1982+ ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
1983+ : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
1984+
19651985 // Tiles
19661986 defines.push_back (" TILE_M=" + std::to_string (WEBGPU_MUL_MAT_TILE_M) + " u" );
19671987 defines.push_back (" TILE_N=" + std::to_string (WEBGPU_MUL_MAT_TILE_N) + " u" );
1968- defines.push_back (" TILE_K=" + std::to_string (WEBGPU_MUL_MAT_TILE_K ) + " u" );
1988+ defines.push_back (" TILE_K=" + std::to_string (tile_k ) + " u" );
19691989
19701990 defines.push_back (" WORKGROUP_SIZE_M=" + std::to_string (WEBGPU_MUL_MAT_WG_SIZE_M) + " u" );
19711991 defines.push_back (" WORKGROUP_SIZE_N=" + std::to_string (WEBGPU_MUL_MAT_WG_SIZE_N) + " u" );
@@ -1976,7 +1996,7 @@ class ggml_webgpu_shader_lib {
19761996 auto processed = preprocessor.preprocess (wgsl_mul_mat_id, defines);
19771997
19781998 auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
1979- decisions->tile_k = WEBGPU_MUL_MAT_TILE_K ;
1999+ decisions->tile_k = tile_k ;
19802000 decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
19812001 decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
19822002 decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
0 commit comments