@@ -94,14 +94,6 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
9494#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
9595#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
9696
97- // For operations which process a row in parallel, this seems like a reasonable
98- // default
99- #define WEBGPU_ROW_SPLIT_WG_SIZE 64
100-
101- // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
102- // implementations so this can be removed, necessary only for get_rows right now
103- #define WEBGPU_MAX_WG_SIZE 288
104-
10597/* End Constants */
10698
10799// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
@@ -631,7 +623,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
631623 size_t size) {
632624 std::vector<uint32_t > params = { (uint32_t ) offset, (uint32_t ) size, value };
633625 std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry (0 , buf, 0 , buf.GetSize ()) };
634- size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities .memset_bytes_per_thread ;
626+ size_t bytes_per_wg = ctx-> capabilities . limits . maxComputeInvocationsPerWorkgroup * ctx->capabilities .memset_bytes_per_thread ;
635627 uint32_t wg_x = CEIL_DIV (size + 3 , bytes_per_wg);
636628
637629 ctx->queue .WriteBuffer (ctx->memset_params_buf , 0 , params.data (), params.size () * sizeof (uint32_t ));
@@ -1366,7 +1358,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
13661358 shader_lib_ctx.src0 = src;
13671359 shader_lib_ctx.src1 = nullptr ;
13681360 shader_lib_ctx.dst = dst;
1369- shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE ;
1361+ shader_lib_ctx.max_wg_size = ctx-> global_ctx -> capabilities . limits . maxComputeInvocationsPerWorkgroup ;
13701362
13711363 webgpu_pipeline pipeline = ctx->shader_lib ->get_get_rows_pipeline (shader_lib_ctx);
13721364 auto * decisions = static_cast <ggml_webgpu_generic_shader_decisions *>(pipeline.context .get ());
@@ -3716,13 +3708,13 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
37163708
37173709static void ggml_webgpu_init_memset_pipeline (webgpu_global_context & ctx) {
37183710 // we use the maximum workgroup size for the memset pipeline
3719- size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities .limits .maxComputeWorkgroupsPerDimension ;
3711+ size_t max_threads = ctx-> capabilities . limits . maxComputeInvocationsPerWorkgroup * ctx->capabilities .limits .maxComputeWorkgroupsPerDimension ;
37203712 // Size the bytes_per_thread so that the largest buffer size can be handled
37213713 ctx->capabilities .memset_bytes_per_thread =
37223714 CEIL_DIV (ctx->capabilities .limits .maxStorageBufferBindingSize , max_threads);
37233715 std::vector<wgpu::ConstantEntry> constants (2 );
37243716 constants[0 ].key = " wg_size" ;
3725- constants[0 ].value = WEBGPU_MAX_WG_SIZE ;
3717+ constants[0 ].value = ctx-> capabilities . limits . maxComputeInvocationsPerWorkgroup ;
37263718 constants[1 ].key = " bytes_per_thread" ;
37273719 constants[1 ].value = ctx->capabilities .memset_bytes_per_thread ;
37283720 ctx->memset_pipeline = ggml_webgpu_create_pipeline (ctx->device , wgsl_memset, " memset" , constants);
0 commit comments