@@ -133,12 +133,28 @@ struct webgpu_buf_pool {
133133 // which can run on a different thread than the calling thread.
134134 std::mutex mutex;
135135 std::condition_variable cv;
136+ size_t cur_pool_size;
137+ size_t max_pool_size;
138+ wgpu::Device device;
139+ wgpu::BufferUsage host_buf_usage;
140+ wgpu::BufferUsage dev_buf_usage;
141+ size_t buf_size;
142+ bool should_grow;
136143
137144 void init (wgpu::Device device,
138145 int num_bufs,
139146 size_t buf_size,
140147 wgpu::BufferUsage dev_buf_usage,
141- wgpu::BufferUsage host_buf_usage) {
148+ wgpu::BufferUsage host_buf_usage,
149+               bool should_grow = false,
150+               size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
151+        this->max_pool_size = max_pool_size;
152+        this->cur_pool_size = num_bufs;
153+        this->device        = device;
154+        this->host_buf_usage = host_buf_usage;
155+        this->dev_buf_usage = dev_buf_usage;
156+        this->buf_size      = buf_size;
157+        this->should_grow   = should_grow;
142158 for (int i = 0 ; i < num_bufs; i++) {
143159 wgpu::Buffer host_buf;
144160 wgpu::Buffer dev_buf;
@@ -150,6 +166,25 @@ struct webgpu_buf_pool {
150166
151167 webgpu_pool_bufs alloc_bufs () {
152168 std::unique_lock<std::mutex> lock (mutex);
169+ if (!free.empty ()) {
170+ webgpu_pool_bufs bufs = free.back ();
171+ free.pop_back ();
172+ return bufs;
173+ }
174+
175+ // Try growing the pool if no free buffers
176+ if (free.empty () && cur_pool_size < max_pool_size && should_grow) {
177+ cur_pool_size++;
178+ wgpu::Buffer host_buf;
179+ wgpu::Buffer dev_buf;
180+            ggml_webgpu_create_buffer (device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
181+            ggml_webgpu_create_buffer (device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
182+
183+ if (!(host_buf && dev_buf)) {
184+                GGML_ABORT ("webgpu_buf_pool: failed to allocate buffers");
185+ }
186+ return webgpu_pool_bufs{ host_buf, dev_buf };
187+ }
153188 cv.wait (lock, [this ] { return !free.empty (); });
154189 webgpu_pool_bufs bufs = free.back ();
155190 free.pop_back ();
@@ -243,6 +278,7 @@ struct webgpu_gpu_profile_buf_pool {
243278#endif
244279
245280struct webgpu_command {
281+ uint32_t num_kernels;
246282 wgpu::CommandBuffer commands;
247283 std::vector<webgpu_pool_bufs> params_bufs;
248284 std::optional<webgpu_pool_bufs> set_rows_error_bufs;
@@ -280,7 +316,6 @@ struct webgpu_global_context_struct {
280316
281317 webgpu_buf_pool memset_buf_pool;
282318 std::map<int , webgpu_pipeline> memset_pipelines; // variant or type index
283- std::atomic_uint inflight_threads = 0 ;
284319
285320#ifdef GGML_WEBGPU_CPU_PROFILE
286321 // Profiling: labeled CPU time in ms (total)
@@ -426,13 +461,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
426461static void ggml_backend_webgpu_wait (webgpu_global_context & ctx,
427462 std::vector<webgpu_submission_futures> & futures,
428463 bool block = true ) {
429- // If we have too many in-flight submissions, wait on the oldest one first. If
430- // there are many threads, inflight_max may be 0, meaning that we must wait on
431- // all futures.
432- uint64_t timeout_ms = block ? UINT64_MAX : 0 ;
433- uint32_t inflight_threads = ctx->inflight_threads ;
434- uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max (inflight_threads, 1u );
435- while (futures.size () >= inflight_max && futures.size () > 0 ) {
464+ // If we have too many in-flight submissions, wait on the oldest one first.
465+ uint64_t timeout_ms = block ? UINT64_MAX : 0 ;
466+ while (futures.size () >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
436467 ctx->instance .WaitAny (futures[0 ].futures .size (), futures[0 ].futures .data (), UINT64_MAX);
437468 futures.erase (futures.begin ());
438469 }
@@ -651,6 +682,7 @@ static webgpu_command ggml_backend_webgpu_build_multi(
651682 result.commands = commands;
652683 result.params_bufs = params_bufs_list;
653684 result.set_rows_error_bufs = set_rows_error_bufs;
685+ result.num_kernels = pipelines.size ();
654686#ifdef GGML_WEBGPU_GPU_PROFILE
655687 result.timestamp_query_bufs = ts_bufs;
656688 // TODO: handle multiple pipeline names
@@ -2081,19 +2113,17 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
20812113
20822114 WEBGPU_CPU_PROFILE_TOTAL_START (graph_compute);
20832115
2084- ctx->global_ctx ->inflight_threads ++;
2085-
20862116 std::vector<webgpu_command> commands;
20872117 std::vector<webgpu_submission_futures> futures;
2118+ uint32_t num_batched_kernels = 0 ;
20882119 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
20892120 if (auto cmd = ggml_webgpu_encode_node (ctx, cgraph->nodes [i])) {
20902121 commands.push_back (*cmd);
2122+ num_batched_kernels += cmd.value ().num_kernels ;
20912123 }
2092- // compute the batch size based on the number of inflight threads
2093- uint32_t inflight_threads = ctx->global_ctx ->inflight_threads ;
2094- uint32_t batch_size = std::min (std::max (1u , WEBGPU_NUM_PARAM_BUFS / std::max (inflight_threads, 1u )),
2095- WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
2096- if (commands.size () >= batch_size) {
2124+
2125+ if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
2126+ num_batched_kernels = 0 ;
20972127 futures.push_back (ggml_backend_webgpu_submit (ctx->global_ctx , commands, ctx->param_buf_pool ,
20982128 &ctx->set_rows_error_buf_pool ));
20992129 // Process events and check for completed submissions
@@ -2109,7 +2139,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
21092139 }
21102140
21112141 ggml_backend_webgpu_wait (ctx->global_ctx , futures);
2112- ctx->global_ctx ->inflight_threads --;
21132142 WEBGPU_CPU_PROFILE_TOTAL_END (graph_compute, ctx->global_ctx );
21142143 return GGML_STATUS_SUCCESS;
21152144}
@@ -2727,7 +2756,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
27272756 webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx ->device );
27282757 webgpu_ctx->param_buf_pool .init (webgpu_ctx->global_ctx ->device , WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
27292758 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2730- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2759+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true );
27312760 webgpu_ctx->set_rows_error_buf_pool .init (webgpu_ctx->global_ctx ->device , WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
27322761 WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
27332762 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
0 commit comments