Skip to content

Commit c40006a

Browse files
authored
ggml-webgpu: Fix how to dispatch WG to some ops (ggml-org#23750)
1 parent c6e4088 commit c40006a

3 files changed

Lines changed: 57 additions & 49 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src
749749
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
750750
};
751751

752-
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
753-
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
752+
uint32_t wg_x;
753+
uint32_t wg_y;
754+
uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size);
755+
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
756+
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
754757
}
755758

756759
static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
@@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
974977

975978
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
976979

980+
uint32_t wg_x;
981+
uint32_t wg_y;
977982
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
978-
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
979-
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
983+
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
980984

981985
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
982986
}
@@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
10641068

10651069
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
10661070

1071+
uint32_t wg_x;
1072+
uint32_t wg_y;
10671073
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
1068-
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
1069-
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
1074+
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
10701075

10711076
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
10721077
}
@@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
16891694
gathered_count_ids_binding_size),
16901695
};
16911696

1692-
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1693-
1694-
const uint32_t gather_total_wg = param_n_expert;
1695-
const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
1696-
const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);
1697+
// n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_expert=256 for Qwen3.5-35B-A3B)
1698+
const uint32_t gather_wg_x = param_n_expert;
16971699

16981700
dispatches.push_back({
1699-
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y }
1701+
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 }
17001702
});
17011703

17021704
// params for mul_mat_id.wgsl
@@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
17481750
uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
17491751
uint32_t total_wg = wg_m * max_wg_n;
17501752

1751-
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1753+
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
17521754

17531755
dispatches.push_back({
17541756
main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y }
@@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
27712773
block_size, npr, nrows
27722774
};
27732775

2774-
const uint32_t total_wg_init = npr * nrows;
2775-
const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2776-
const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
2777-
const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
2776+
uint32_t wg_x_init;
2777+
uint32_t wg_y_init;
2778+
const uint32_t total_wg_init = npr * nrows;
2779+
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2780+
compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init);
2781+
27782782
std::vector<wgpu::BindGroupEntry> init_entries = {
27792783
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
27802784
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size)
@@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
28312835
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out)
28322836
};
28332837

2838+
uint32_t wg_x_merge;
2839+
uint32_t wg_y_merge;
28342840
const uint32_t total_wg_merge = nm * nrows;
2835-
const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
2836-
const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
2841+
compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge);
2842+
28372843
dispatches.push_back({
28382844
argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge }
28392845
});
@@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s
29532959

29542960
webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx);
29552961
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
2956-
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
2957-
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
2958-
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
2962+
2963+
uint32_t wg_x;
2964+
uint32_t wg_y;
2965+
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
2966+
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
2967+
29592968
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
29602969
}
29612970

ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,22 @@ struct Params{
4949
var<uniform> params: Params;
5050

5151
@compute @workgroup_size(WG_SIZE)
52-
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
53-
if (gid.x >= params.ne) {
52+
fn main(
53+
@builtin(global_invocation_index) gindex: u32,
54+
) {
55+
if (gindex >= params.ne) {
5456
return;
5557
}
5658

57-
var i = gid.x;
59+
var i = gindex;
5860
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
5961
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
6062
let i2 = i / (params.src_ne1 * params.src_ne0);
6163
i = i % (params.src_ne1 * params.src_ne0);
6264
let i1 = i / params.src_ne0;
6365
let i0 = i % params.src_ne0;
6466

65-
var j = gid.x;
67+
var j = gindex;
6668
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
6769
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
6870
let j2 = j / (params.dst_ne1 * params.dst_ne0);

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,32 @@ var<workgroup> count:atomic<u32>;
2121

2222
@compute @workgroup_size(WG_SIZE)
2323
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
24-
@builtin(local_invocation_id) local_id: vec3<u32>,
25-
@builtin(num_workgroups) num_wg: vec3<u32>) {
24+
@builtin(local_invocation_id) local_id: vec3<u32>) {
2625

2726
let thread_id = local_id.x;
28-
let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup
27+
let own_expert = wg_id.x; // the expert assigned to this workgroup
2928

30-
if (own_expert < params.n_expert) {
31-
if (thread_id == 0u) {
32-
atomicStore(&count, 0);
33-
}
29+
if (thread_id == 0u) {
30+
atomicStore(&count, 0);
31+
}
3432

35-
workgroupBarrier();
36-
37-
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
38-
let row = i / params.n_expert_used;
39-
let col = i % params.n_expert_used;
40-
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
41-
if (own_expert == expert) {
42-
let pos = atomicAdd(&count, 1u);
43-
let gathered_id = own_expert * params.n_tokens + pos;
44-
global_gathered_expert_used[gathered_id] = col;
45-
global_gathered_tokens[gathered_id] = row;
46-
}
33+
workgroupBarrier();
34+
35+
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
36+
let row = i / params.n_expert_used;
37+
let col = i % params.n_expert_used;
38+
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
39+
if (own_expert == expert) {
40+
let pos = atomicAdd(&count, 1u);
41+
let gathered_id = own_expert * params.n_tokens + pos;
42+
global_gathered_expert_used[gathered_id] = col;
43+
global_gathered_tokens[gathered_id] = row;
4744
}
45+
}
4846

49-
workgroupBarrier();
47+
workgroupBarrier();
5048

51-
if (thread_id == 0u) {
52-
gathered_count_ids[own_expert] = atomicLoad(&count);
53-
}
49+
if (thread_id == 0u) {
50+
gathered_count_ids[own_expert] = atomicLoad(&count);
5451
}
5552
}

0 commit comments

Comments
 (0)