Skip to content

Commit 49a7564

Browse files
ggml webgpu: fix workgroup dispatch limit for large batch sizes (ggml-org#19965)
* ggml-webgpu: fix workgroup dispatch limit for large batch sizes WebGPU limits workgroup sizes to 65535 per dimension. Large MUL_MAT operations with batch sizes exceedeing this limi would fail. * add compute_2d_workgroups() helper to split total workgroup ID across X/Y dimensions * update mul_mat_reg_tile.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat_subgroup_matrix.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat.wgsl to compute global index from 2D workgroup coordinates * refactor all three mul_mat dispatch paths to use the shared helper * ggml-webgpu: add bounds checking for over-dispatched workgroups 2D workgroup dispatch can over-dispatch when total workgroups don't divide evenly into the 65535 per-dimension limit. Extra workgroups would compute invalid batch indices, causing memory corruption. * add batch_idx bound check to mul_mat_reg_tile.wgsl and mul_mat_subgroup_matrix.wgsl to prevent over-dispatched workgroups from accessing invalid memory * fixes test failures with large batch sizes (eg., bs=[128, 1024]) * ggml-webgpu: add back TODO for spliting large sizes into batches * Optimize 2d workgroup provisioning * Set some parameters that increase speed --------- Co-authored-by: Reese Levine <reeselevine1@gmail.com>
1 parent 4d828bd commit 49a7564

4 files changed

Lines changed: 49 additions & 20 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
3232
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
3333

34+
// Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
35+
// Assumes that the total number of workgroups does not exceed max_per_dim^2.
36+
static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
37+
wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
38+
wg_x = CEIL_DIV(total_wg, wg_y);
39+
}
40+
3441
#ifdef GGML_WEBGPU_DEBUG
3542
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
3643
# define WEBGPU_DEBUG_BUF_ELEMS 512
@@ -69,8 +76,8 @@
6976

7077
/* Constants */
7178

72-
#define WEBGPU_NUM_PARAM_BUFS 16u
73-
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
79+
#define WEBGPU_NUM_PARAM_BUFS 48u
80+
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
7481
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
7582
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
7683
// parameter buffer pool
@@ -1146,18 +1153,17 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
11461153
};
11471154

11481155
// Calculate workgroup dimensions
1149-
uint32_t wg_x = 1;
1150-
uint32_t wg_y = 1;
1156+
uint32_t wg_x = 1;
1157+
uint32_t wg_y = 1;
1158+
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
11511159

11521160
if (use_fast && is_vec) {
11531161
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
11541162

11551163
uint32_t batches = dst->ne[2] * dst->ne[3];
11561164
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
11571165
uint32_t total_wg = output_groups * batches;
1158-
// TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
1159-
wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
1160-
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
1166+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
11611167
} else if (use_fast) {
11621168
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
11631169

@@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
11761182
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
11771183
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
11781184
}
1179-
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1185+
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1186+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1187+
11801188
} else { // legacy
11811189
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
11821190
uint32_t wg_size = decisions->wg_size;
1183-
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
1184-
wg_y = 1;
1191+
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
1192+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
11851193
}
11861194

11871195
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -679,19 +679,24 @@ struct MulMatParams {
679679
@group(0) @binding(3) var<uniform> params: MulMatParams;
680680

681681
@compute @workgroup_size(256)
682-
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
682+
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
683+
@builtin(workgroup_id) wg_id: vec3<u32>,
684+
@builtin(num_workgroups) num_wg: vec3<u32>) {
685+
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
686+
let global_idx = wg_linear * 256u + local_id.x;
687+
683688
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
684-
if (global_id.x >= total) {
689+
if (global_idx >= total) {
685690
return;
686691
}
687692

688693
let dst2_stride = params.m * params.n;
689694
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
690695

691-
let dst3_idx = global_id.x / dst3_stride;
696+
let dst3_idx = global_idx / dst3_stride;
692697
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
693698
let src13_idx = dst3_idx; // src1 is not broadcast
694-
let dst3_rem = global_id.x % dst3_stride;
699+
let dst3_rem = global_idx % dst3_stride;
695700

696701
let dst2_idx = dst3_rem / dst2_stride;
697702
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
5454

5555
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
5656
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
57-
@builtin(local_invocation_id) local_id: vec3<u32>) {
57+
@builtin(local_invocation_id) local_id: vec3<u32>,
58+
@builtin(num_workgroups) num_wg: vec3<u32>) {
5859

5960
let thread_id = local_id.x;
6061
let local_m = get_local_m(thread_id);
@@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
6465
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
6566
let wg_per_matrix = wg_m_count * wg_n_count;
6667

67-
let batch_idx = wg_id.x / wg_per_matrix;
68+
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
6869

69-
let wg_in_batch = wg_id.x % wg_per_matrix;
70+
let batch_idx = wg_linear / wg_per_matrix;
71+
72+
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
73+
if (batch_idx >= total_batches) {
74+
return;
75+
}
76+
77+
let wg_in_batch = wg_linear % wg_per_matrix;
7078
let wg_m = wg_in_batch % wg_m_count;
7179
let wg_n = wg_in_batch / wg_m_count;
7280

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ var<workgroup> shmem: array<f16, SHMEM_SIZE>;
6969
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
7070
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
7171
@builtin(local_invocation_id) local_id: vec3<u32>,
72-
@builtin(subgroup_id) subgroup_id: u32) {
72+
@builtin(subgroup_id) subgroup_id: u32,
73+
@builtin(num_workgroups) num_wg: vec3<u32>) {
7374

7475
let thread_id = local_id.x;
7576
let subgroup_m = subgroup_id % SUBGROUP_M;
@@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
7980
let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
8081
let wg_per_matrix = wg_m_count * wg_n_count;
8182

82-
let batch_idx = wg_id.x / wg_per_matrix;
83+
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
8384

84-
let wg_in_batch = wg_id.x % wg_per_matrix;
85+
let batch_idx = wg_linear / wg_per_matrix;
86+
87+
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
88+
if (batch_idx >= total_batches) {
89+
return;
90+
}
91+
92+
let wg_in_batch = wg_linear % wg_per_matrix;
8593
let wg_m = wg_in_batch % wg_m_count;
8694
let wg_n = wg_in_batch / wg_m_count;
8795

0 commit comments

Comments
 (0)