Skip to content

Commit 06a811d

Browse files
authored
add performance-portable tuning for register-tile and subgroup matmul (#22241)
1 parent 78433f6 commit 06a811d

1 file changed

Lines changed: 28 additions & 8 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 28 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,20 +26,23 @@
2626
// Matrix multiplication parameters
2727

2828
// Register tiling parameters
29-
#define WEBGPU_MUL_MAT_TILE_M 8
30-
#define WEBGPU_MUL_MAT_TILE_N 8
29+
#define WEBGPU_MUL_MAT_TILE_M 4
30+
#define WEBGPU_MUL_MAT_TILE_N 4
3131
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
3232
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
33-
#define WEBGPU_MUL_MAT_TILE_K 32
33+
#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
34+
#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
3435

3536
// Subgroup matrix parameters
3637
// The number of subgroups in the M dimension
3738
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
3839
// The number of subgroups in the N dimension
39-
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
40+
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
4041
// The number of subgroup matrices each subgroup accumulates over
4142
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
4243
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
44+
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
45+
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
4346

4447
// Matrix-vector multiplication parameters
4548
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
@@ -1734,13 +1737,24 @@ class ggml_webgpu_shader_lib {
17341737
// VEC/SCALAR controls
17351738
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
17361739

1740+
const bool is_quant = ggml_is_quantized(context.src0->type);
1741+
1742+
uint32_t tile_k;
1743+
if (key.use_subgroup_matrix) {
1744+
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
1745+
: WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
1746+
} else {
1747+
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
1748+
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
1749+
}
1750+
17371751
// Tiles
17381752
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
17391753
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
1740-
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
17411754

17421755
// Subgroup matrix specifics
17431756
if (key.use_subgroup_matrix) {
1757+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
17441758
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
17451759
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
17461760
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
@@ -1760,12 +1774,13 @@ class ggml_webgpu_shader_lib {
17601774
if (!key.use_subgroup_matrix) {
17611775
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
17621776
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
1777+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
17631778
}
17641779

17651780
auto processed = preprocessor.preprocess(shader_src, defines);
17661781

17671782
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
1768-
decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
1783+
decisions->tile_k = tile_k;
17691784
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
17701785
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
17711786
decisions->use_subgroup_matrix = key.use_subgroup_matrix;
@@ -1962,10 +1977,15 @@ class ggml_webgpu_shader_lib {
19621977

19631978
defines.push_back("SCALAR");
19641979

1980+
// mul_mat_id is register-tile only.
1981+
const uint32_t tile_k = ggml_is_quantized(context.src0->type)
1982+
? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
1983+
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
1984+
19651985
// Tiles
19661986
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
19671987
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
1968-
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
1988+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
19691989

19701990
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
19711991
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
@@ -1976,7 +1996,7 @@ class ggml_webgpu_shader_lib {
19761996
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
19771997

19781998
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
1979-
decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
1999+
decisions->tile_k = tile_k;
19802000
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
19812001
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
19822002
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;

0 commit comments

Comments
 (0)