@@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src
749749 ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , dst),
750750 };
751751
752- uint32_t wg_x = CEIL_DIV (ne, decisions->wg_size );
753- return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x);
752+ uint32_t wg_x;
753+ uint32_t wg_y;
754+ uint32_t total_wg = CEIL_DIV (ne, decisions->wg_size );
755+ compute_2d_workgroups (total_wg, ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , wg_x, wg_y);
756+ return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x, wg_y);
754757}
755758
756759static webgpu_encoded_op ggml_webgpu_set (webgpu_context & ctx,
@@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
974977
975978 auto * decisions = static_cast <ggml_webgpu_generic_shader_decisions *>(pipeline.context .get ());
976979
980+ uint32_t wg_x;
981+ uint32_t wg_y;
977982 uint32_t total_wg = CEIL_DIV ((uint32_t ) ggml_nelements (dst), decisions->wg_size );
978- uint32_t wg_x = std::min (ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , total_wg);
979- uint32_t wg_y = CEIL_DIV (total_wg, wg_x);
983+ compute_2d_workgroups (total_wg, ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , wg_x, wg_y);
980984
981985 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x, wg_y);
982986}
@@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
10641068
10651069 auto * decisions = static_cast <ggml_webgpu_generic_shader_decisions *>(pipeline.context .get ());
10661070
1071+ uint32_t wg_x;
1072+ uint32_t wg_y;
10671073 uint32_t total_wg = CEIL_DIV ((uint32_t ) ggml_nelements (dst), decisions->wg_size );
1068- uint32_t wg_x = std::min (ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , total_wg);
1069- uint32_t wg_y = CEIL_DIV (total_wg, wg_x);
1074+ compute_2d_workgroups (total_wg, ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , wg_x, wg_y);
10701075
10711076 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x, wg_y);
10721077}
@@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
16891694 gathered_count_ids_binding_size),
16901695 };
16911696
1692- const uint32_t max_wg_per_dim = ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension ;
1693-
1694- const uint32_t gather_total_wg = param_n_expert;
1695- const uint32_t gather_wg_x = std::min (gather_total_wg, max_wg_per_dim);
1696- const uint32_t gather_wg_y = CEIL_DIV (gather_total_wg, gather_wg_x);
1697+ // n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B)
1698+ const uint32_t gather_wg_x = param_n_expert;
16971699
16981700 dispatches.push_back ({
1699- gather_pipeline, std::move (gather_params), std::move (gather_entries), { gather_wg_x, gather_wg_y }
1701+ gather_pipeline, std::move (gather_params), std::move (gather_entries), { gather_wg_x, 1 }
17001702 });
17011703
17021704 // params for mul_mat_id.wgsl
@@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
17481750 uint32_t max_wg_n = CEIL_DIV (total_gathered, tile_n_s) + max_active_experts;
17491751 uint32_t total_wg = wg_m * max_wg_n;
17501752
1751- compute_2d_workgroups (total_wg, max_wg_per_dim , wg_x, wg_y);
1753+ compute_2d_workgroups (total_wg, ctx-> global_ctx -> capabilities . limits . maxComputeWorkgroupsPerDimension , wg_x, wg_y);
17521754
17531755 dispatches.push_back ({
17541756 main_pipeline, std::move (main_params), std::move (main_entries), { wg_x, wg_y }
@@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
27712773 block_size, npr, nrows
27722774 };
27732775
2774- const uint32_t total_wg_init = npr * nrows;
2775- const uint32_t max_wg = ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension ;
2776- const uint32_t wg_x_init = std::min (total_wg_init, max_wg);
2777- const uint32_t wg_y_init = CEIL_DIV (total_wg_init, wg_x_init);
2776+ uint32_t wg_x_init;
2777+ uint32_t wg_y_init;
2778+ const uint32_t total_wg_init = npr * nrows;
2779+ const uint32_t max_wg_per_dim = ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension ;
2780+ compute_2d_workgroups (total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init);
2781+
27782782 std::vector<wgpu::BindGroupEntry> init_entries = {
27792783 ggml_webgpu_make_tensor_bind_group_entry (ctx, 0 , src),
27802784 ggml_webgpu_make_bind_group_entry (1 , ggml_webgpu_tensor_buf (dst), init_align_offset, init_binding_size)
@@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
28312835 ggml_webgpu_make_bind_group_entry (2 , ggml_webgpu_tensor_buf (dst), align_out, size_out)
28322836 };
28332837
2838+ uint32_t wg_x_merge;
2839+ uint32_t wg_y_merge;
28342840 const uint32_t total_wg_merge = nm * nrows;
2835- const uint32_t wg_x_merge = std::min (total_wg_merge, max_wg );
2836- const uint32_t wg_y_merge = CEIL_DIV (total_wg_merge, wg_x_merge);
2841+ compute_2d_workgroups (total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge );
2842+
28372843 dispatches.push_back ({
28382844 argsort_merge_pipeline, std::move (merge_params), std::move (merge_entries), { wg_x_merge, wg_y_merge }
28392845 });
@@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s
29532959
29542960 webgpu_pipeline pipeline = ctx->shader_lib ->get_upscale_pipeline (shader_lib_ctx);
29552961 auto * decisions = static_cast <ggml_webgpu_generic_shader_decisions *>(pipeline.context .get ());
2956- uint32_t total_wg = CEIL_DIV ((uint32_t ) ggml_nelements (dst), decisions->wg_size );
2957- uint32_t wg_x = std::min (ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , total_wg);
2958- uint32_t wg_y = CEIL_DIV (total_wg, wg_x);
2962+
2963+ uint32_t wg_x;
2964+ uint32_t wg_y;
2965+ uint32_t total_wg = CEIL_DIV ((uint32_t ) ggml_nelements (dst), decisions->wg_size );
2966+ compute_2d_workgroups (total_wg, ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupsPerDimension , wg_x, wg_y);
2967+
29592968 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x, wg_y);
29602969}
29612970
0 commit comments