Skip to content

Commit 4d828bd

Browse files
authored
ggml webgpu: Clean up per-thread parameter buffer pool and job submission logic (ggml-org#19772)
* Allow webgpu_buf_pool to resize if needed, remove inflight_threads, and replace inflight_threads with num_kernels for submission * Run clang-format * Keep track of num batched kernels that have not been submitted yet * Run clang-format * Increase buf pool max size * Increase param buf pool init size * Remove webgpu buf pool resizing * Merge with master * Add buffer pool growth * Move buffer pool growth outside of lock * Reduce max pool size to 32 * Run clang-format * Only resize param buf pool
1 parent 36a7a65 commit 4d828bd

1 file changed

Lines changed: 47 additions & 18 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 47 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -133,12 +133,28 @@ struct webgpu_buf_pool {
133133
// which can run on a different thread than the calling thread.
134134
std::mutex mutex;
135135
std::condition_variable cv;
136+
size_t cur_pool_size;
137+
size_t max_pool_size;
138+
wgpu::Device device;
139+
wgpu::BufferUsage host_buf_usage;
140+
wgpu::BufferUsage dev_buf_usage;
141+
size_t buf_size;
142+
bool should_grow;
136143

137144
void init(wgpu::Device device,
138145
int num_bufs,
139146
size_t buf_size,
140147
wgpu::BufferUsage dev_buf_usage,
141-
wgpu::BufferUsage host_buf_usage) {
148+
wgpu::BufferUsage host_buf_usage,
149+
bool should_grow = false,
150+
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
151+
this->max_pool_size = max_pool_size;
152+
this->cur_pool_size = num_bufs;
153+
this->device = device;
154+
this->host_buf_usage = host_buf_usage;
155+
this->dev_buf_usage = dev_buf_usage;
156+
this->buf_size = buf_size;
157+
this->should_grow = should_grow;
142158
for (int i = 0; i < num_bufs; i++) {
143159
wgpu::Buffer host_buf;
144160
wgpu::Buffer dev_buf;
@@ -150,6 +166,25 @@ struct webgpu_buf_pool {
150166

151167
webgpu_pool_bufs alloc_bufs() {
152168
std::unique_lock<std::mutex> lock(mutex);
169+
if (!free.empty()) {
170+
webgpu_pool_bufs bufs = free.back();
171+
free.pop_back();
172+
return bufs;
173+
}
174+
175+
// Try growing the pool if no free buffers
176+
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
177+
cur_pool_size++;
178+
wgpu::Buffer host_buf;
179+
wgpu::Buffer dev_buf;
180+
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
181+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
182+
183+
if (!(host_buf && dev_buf)) {
184+
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
185+
}
186+
return webgpu_pool_bufs{ host_buf, dev_buf };
187+
}
153188
cv.wait(lock, [this] { return !free.empty(); });
154189
webgpu_pool_bufs bufs = free.back();
155190
free.pop_back();
@@ -243,6 +278,7 @@ struct webgpu_gpu_profile_buf_pool {
243278
#endif
244279

245280
struct webgpu_command {
281+
uint32_t num_kernels;
246282
wgpu::CommandBuffer commands;
247283
std::vector<webgpu_pool_bufs> params_bufs;
248284
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
@@ -280,7 +316,6 @@ struct webgpu_global_context_struct {
280316

281317
webgpu_buf_pool memset_buf_pool;
282318
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
283-
std::atomic_uint inflight_threads = 0;
284319

285320
#ifdef GGML_WEBGPU_CPU_PROFILE
286321
// Profiling: labeled CPU time in ms (total)
@@ -426,13 +461,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
426461
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
427462
std::vector<webgpu_submission_futures> & futures,
428463
bool block = true) {
429-
// If we have too many in-flight submissions, wait on the oldest one first. If
430-
// there are many threads, inflight_max may be 0, meaning that we must wait on
431-
// all futures.
432-
uint64_t timeout_ms = block ? UINT64_MAX : 0;
433-
uint32_t inflight_threads = ctx->inflight_threads;
434-
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
435-
while (futures.size() >= inflight_max && futures.size() > 0) {
464+
// If we have too many in-flight submissions, wait on the oldest one first.
465+
uint64_t timeout_ms = block ? UINT64_MAX : 0;
466+
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
436467
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
437468
futures.erase(futures.begin());
438469
}
@@ -651,6 +682,7 @@ static webgpu_command ggml_backend_webgpu_build_multi(
651682
result.commands = commands;
652683
result.params_bufs = params_bufs_list;
653684
result.set_rows_error_bufs = set_rows_error_bufs;
685+
result.num_kernels = pipelines.size();
654686
#ifdef GGML_WEBGPU_GPU_PROFILE
655687
result.timestamp_query_bufs = ts_bufs;
656688
// TODO: handle multiple pipeline names
@@ -2081,19 +2113,17 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
20812113

20822114
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
20832115

2084-
ctx->global_ctx->inflight_threads++;
2085-
20862116
std::vector<webgpu_command> commands;
20872117
std::vector<webgpu_submission_futures> futures;
2118+
uint32_t num_batched_kernels = 0;
20882119
for (int i = 0; i < cgraph->n_nodes; i++) {
20892120
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
20902121
commands.push_back(*cmd);
2122+
num_batched_kernels += cmd.value().num_kernels;
20912123
}
2092-
// compute the batch size based on the number of inflight threads
2093-
uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
2094-
uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
2095-
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
2096-
if (commands.size() >= batch_size) {
2124+
2125+
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
2126+
num_batched_kernels = 0;
20972127
futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
20982128
&ctx->set_rows_error_buf_pool));
20992129
// Process events and check for completed submissions
@@ -2109,7 +2139,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
21092139
}
21102140

21112141
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
2112-
ctx->global_ctx->inflight_threads--;
21132142
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
21142143
return GGML_STATUS_SUCCESS;
21152144
}
@@ -2727,7 +2756,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
27272756
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
27282757
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
27292758
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2730-
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2759+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
27312760
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
27322761
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
27332762
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,

0 commit comments

Comments (0)