Skip to content

Commit f12cc6d

Browse files
authored
ggml-webgpu: remove legacy constants (ggml-org#23672)
1 parent aa50b2c commit f12cc6d

1 file changed

Lines changed: 4 additions & 12 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,6 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
9494
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
9595
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
9696

97-
// For operations which process a row in parallel, this seems like a reasonable
98-
// default
99-
#define WEBGPU_ROW_SPLIT_WG_SIZE 64
100-
101-
// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
102-
// implementations so this can be removed, necessary only for get_rows right now
103-
#define WEBGPU_MAX_WG_SIZE 288
104-
10597
/* End Constants */
10698

10799
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
@@ -631,7 +623,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
631623
size_t size) {
632624
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
633625
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
634-
size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
626+
size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
635627
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
636628

637629
ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
@@ -1366,7 +1358,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
13661358
shader_lib_ctx.src0 = src;
13671359
shader_lib_ctx.src1 = nullptr;
13681360
shader_lib_ctx.dst = dst;
1369-
shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE;
1361+
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
13701362

13711363
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
13721364
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -3716,13 +3708,13 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
37163708

37173709
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
37183710
// we use the maximum workgroup size for the memset pipeline
3719-
size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
3711+
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
37203712
// Size the bytes_per_thread so that the largest buffer size can be handled
37213713
ctx->capabilities.memset_bytes_per_thread =
37223714
CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
37233715
std::vector<wgpu::ConstantEntry> constants(2);
37243716
constants[0].key = "wg_size";
3725-
constants[0].value = WEBGPU_MAX_WG_SIZE;
3717+
constants[0].value = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
37263718
constants[1].key = "bytes_per_thread";
37273719
constants[1].value = ctx->capabilities.memset_bytes_per_thread;
37283720
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);

0 commit comments

Comments
 (0)