@@ -146,8 +146,13 @@ struct webgpu_submission_futures {
146146struct webgpu_buf_pool {
147147 std::vector<webgpu_pool_bufs> free;
148148
149- std::mutex mutex;
150-
149+ // The pool must be synchronized because
150+ // 1. The memset pool is shared globally by every ggml buffer,
151+ // since allocating a pool per ggml buffer would consume too much memory.
152+ // 2. For the per-thread buffer pools in webgpu_context,
153+ // buffers are allocated and freed in Dawn callbacks,
154+ // which can run on a different thread than the calling thread.
155+ std::mutex mutex;
151156 std::condition_variable cv;
152157
153158 void init (wgpu::Device device,
@@ -266,7 +271,7 @@ struct webgpu_command {
266271#endif
267272};
268273
269- struct webgpu_capabilities_base {
274+ struct webgpu_capabilities {
270275 wgpu::Limits limits;
271276 bool supports_subgroup_matrix = false ;
272277
@@ -286,11 +291,11 @@ struct webgpu_global_context_struct {
286291 wgpu::Device device;
287292 wgpu::Queue queue;
288293
289- webgpu_capabilities_base capabilities;
294+ webgpu_capabilities capabilities;
290295 // Shared buffer to move data from device to host
291- wgpu::Buffer get_tensor_staging_buf;
296+ wgpu::Buffer get_tensor_staging_buf;
292297 // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
293- std::recursive_mutex mutex;
298+ std::recursive_mutex mutex;
294299
295300 webgpu_buf_pool memset_buf_pool;
296301 std::map<int , webgpu_pipeline> memset_pipelines; // variant or type index
@@ -361,7 +366,6 @@ struct webgpu_context_struct {
361366 std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
362367
363368 size_t memset_bytes_per_thread;
364-
365369};
366370
367371typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -383,9 +387,8 @@ struct ggml_backend_webgpu_device_context {
383387
384388// Per-thread data required to actually run WebGPU operations in a backend instance
385389struct ggml_backend_webgpu_context {
386- webgpu_context webgpu_ctx;
387- std::once_flag init_once;
388- std::string name;
390+ webgpu_context webgpu_ctx;
391+ std::string name;
389392};
390393
391394// Per-thread data related to buffers
@@ -861,20 +864,15 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
861864 };
862865
863866 webgpu_pipeline pipeline;
864- {
865- // TODO: remove guard once pipeline caches are per-thread
866- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
867- auto it = ctx->pad_pipelines .find (pipeline_key);
868- if (it != ctx->pad_pipelines .end ()) {
869- pipeline = it->second ;
870- } else {
871- ggml_webgpu_processed_shader processed =
872- ggml_webgpu_preprocess_pad_shader (ctx->p , wgsl_pad, shader_lib_ctx);
873- pipeline =
874- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
875- pipeline.context = processed.decisions ;
876- ctx->pad_pipelines .emplace (pipeline_key, pipeline);
877- }
867+ auto it = ctx->pad_pipelines .find (pipeline_key);
868+ if (it != ctx->pad_pipelines .end ()) {
869+ pipeline = it->second ;
870+ } else {
871+ ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader (ctx->p , wgsl_pad, shader_lib_ctx);
872+ pipeline =
873+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
874+ pipeline.context = processed.decisions ;
875+ ctx->pad_pipelines .emplace (pipeline_key, pipeline);
878876 }
879877
880878 ggml_webgpu_generic_shader_decisions decisions =
@@ -944,20 +942,16 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
944942 };
945943
946944 webgpu_pipeline pipeline;
947- // TODO: remove guard once pipeline caches are per-thread
948- {
949- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
950- auto it = ctx->set_rows_pipelines .find (key);
951- if (it != ctx->set_rows_pipelines .end ()) {
952- pipeline = it->second ;
953- } else {
954- ggml_webgpu_processed_shader processed =
955- ggml_webgpu_preprocess_set_rows_shader (ctx->p , wgsl_set_rows, shader_lib_ctx);
956- pipeline =
957- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
958- pipeline.context = processed.decisions ;
959- ctx->set_rows_pipelines .emplace (key, pipeline);
960- }
945+ auto it = ctx->set_rows_pipelines .find (key);
946+ if (it != ctx->set_rows_pipelines .end ()) {
947+ pipeline = it->second ;
948+ } else {
949+ ggml_webgpu_processed_shader processed =
950+ ggml_webgpu_preprocess_set_rows_shader (ctx->p , wgsl_set_rows, shader_lib_ctx);
951+ pipeline =
952+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
953+ pipeline.context = processed.decisions ;
954+ ctx->set_rows_pipelines .emplace (key, pipeline);
961955 }
962956
963957 ggml_webgpu_generic_shader_decisions decisions =
@@ -1261,29 +1255,25 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
12611255 };
12621256
12631257 webgpu_pipeline pipeline;
1264- // TODO: remove guard once pipeline caches are per-thread
1265- {
1266- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
1267- auto it = ctx->flash_attn_pipelines .find (key);
1268- if (it != ctx->flash_attn_pipelines .end ()) {
1269- pipeline = it->second ;
1270- } else {
1271- ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
1272- .key = key,
1273- .sg_mat_m = ctx->global_ctx ->capabilities .sg_mat_m ,
1274- .sg_mat_n = ctx->global_ctx ->capabilities .sg_mat_n ,
1275- .sg_mat_k = ctx->global_ctx ->capabilities .sg_mat_k ,
1276- .wg_mem_limit_bytes = ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupStorageSize ,
1277- .max_subgroup_size = ctx->global_ctx ->capabilities .max_subgroup_size
1278- };
1279-
1280- ggml_webgpu_processed_shader processed =
1281- ggml_webgpu_preprocess_flash_attn_shader (ctx->p , wgsl_flash_attn, shader_lib_ctx);
1282- pipeline =
1283- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1284- pipeline.context = processed.decisions ;
1285- ctx->flash_attn_pipelines .emplace (key, pipeline);
1286- }
1258+ auto it = ctx->flash_attn_pipelines .find (key);
1259+ if (it != ctx->flash_attn_pipelines .end ()) {
1260+ pipeline = it->second ;
1261+ } else {
1262+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
1263+ .key = key,
1264+ .sg_mat_m = ctx->global_ctx ->capabilities .sg_mat_m ,
1265+ .sg_mat_n = ctx->global_ctx ->capabilities .sg_mat_n ,
1266+ .sg_mat_k = ctx->global_ctx ->capabilities .sg_mat_k ,
1267+ .wg_mem_limit_bytes = ctx->global_ctx ->capabilities .limits .maxComputeWorkgroupStorageSize ,
1268+ .max_subgroup_size = ctx->global_ctx ->capabilities .max_subgroup_size
1269+ };
1270+
1271+ ggml_webgpu_processed_shader processed =
1272+ ggml_webgpu_preprocess_flash_attn_shader (ctx->p , wgsl_flash_attn, shader_lib_ctx);
1273+ pipeline =
1274+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1275+ pipeline.context = processed.decisions ;
1276+ ctx->flash_attn_pipelines .emplace (key, pipeline);
12871277 }
12881278
12891279 ggml_webgpu_flash_attn_shader_decisions decisions =
@@ -1308,20 +1298,16 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
13081298 };
13091299
13101300 webgpu_pipeline pipeline;
1311- {
1312- // TODO: remove guard once pipeline caches are per-thread
1313- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
1314- auto it = ctx->unary_pipelines .find (pipeline_key);
1315- if (it != ctx->unary_pipelines .end ()) {
1316- pipeline = it->second ;
1317- } else {
1318- ggml_webgpu_processed_shader processed =
1319- ggml_webgpu_preprocess_unary_shader (ctx->p , wgsl_unary, shader_lib_ctx);
1320- pipeline =
1321- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1322- pipeline.context = processed.decisions ;
1323- ctx->unary_pipelines .emplace (pipeline_key, pipeline);
1324- }
1301+ auto it = ctx->unary_pipelines .find (pipeline_key);
1302+ if (it != ctx->unary_pipelines .end ()) {
1303+ pipeline = it->second ;
1304+ } else {
1305+ ggml_webgpu_processed_shader processed =
1306+ ggml_webgpu_preprocess_unary_shader (ctx->p , wgsl_unary, shader_lib_ctx);
1307+ pipeline =
1308+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1309+ pipeline.context = processed.decisions ;
1310+ ctx->unary_pipelines .emplace (pipeline_key, pipeline);
13251311 }
13261312
13271313 ggml_webgpu_generic_shader_decisions decisions =
@@ -1743,19 +1729,15 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
17431729 };
17441730
17451731 webgpu_pipeline pipeline;
1746- {
1747- // TODO: remove guard once pipeline caches are per-thread
1748- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
1749- auto it = ctx->argmax_pipelines .find (shader_lib_ctx.vec4 );
1750- if (it != ctx->argmax_pipelines .end ()) {
1751- pipeline = it->second ;
1752- } else {
1753- ggml_webgpu_processed_shader processed =
1754- ggml_webgpu_preprocess_generic_shader (ctx->p , wgsl_argmax, shader_lib_ctx, " argmax" );
1755- pipeline =
1756- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1757- ctx->argmax_pipelines .emplace (shader_lib_ctx.vec4 , pipeline);
1758- }
1732+ auto it = ctx->argmax_pipelines .find (shader_lib_ctx.vec4 );
1733+ if (it != ctx->argmax_pipelines .end ()) {
1734+ pipeline = it->second ;
1735+ } else {
1736+ ggml_webgpu_processed_shader processed =
1737+ ggml_webgpu_preprocess_generic_shader (ctx->p , wgsl_argmax, shader_lib_ctx, " argmax" );
1738+ pipeline =
1739+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1740+ ctx->argmax_pipelines .emplace (shader_lib_ctx.vec4 , pipeline);
17591741 }
17601742 uint32_t wg_x = ggml_nelements (dst);
17611743 return ggml_backend_webgpu_build (ctx->global_ctx , ctx->param_buf_pool , pipeline, params, entries, wg_x);
@@ -1772,9 +1754,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
17721754 .order = order
17731755 };
17741756
1775- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
1776- webgpu_pipeline argsort_pipeline;
1777- auto it = ctx->argsort_pipelines .find (order);
1757+ webgpu_pipeline argsort_pipeline;
1758+ auto it = ctx->argsort_pipelines .find (order);
17781759 if (it != ctx->argsort_pipelines .end ()) {
17791760 argsort_pipeline = it->second ;
17801761 } else {
@@ -1963,19 +1944,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
19631944 .max_wg_size = ctx->global_ctx ->capabilities .limits .maxComputeInvocationsPerWorkgroup ,
19641945 };
19651946 webgpu_pipeline pipeline;
1966- // TODO: remove guard once pipeline caches are per-thread
1967- {
1968- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
1969- auto it = ctx->cumsum_pipelines .find (1 );
1970- if (it != ctx->cumsum_pipelines .end ()) {
1971- pipeline = it->second ;
1972- } else {
1973- ggml_webgpu_processed_shader processed =
1974- ggml_webgpu_preprocess_generic_shader (ctx->p , wgsl_cumsum, shader_lib_ctx, " cumsum" );
1975- pipeline =
1976- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1977- ctx->cumsum_pipelines .emplace (1 , pipeline);
1978- }
1947+ auto it = ctx->cumsum_pipelines .find (1 );
1948+ if (it != ctx->cumsum_pipelines .end ()) {
1949+ pipeline = it->second ;
1950+ } else {
1951+ ggml_webgpu_processed_shader processed =
1952+ ggml_webgpu_preprocess_generic_shader (ctx->p , wgsl_cumsum, shader_lib_ctx, " cumsum" );
1953+ pipeline =
1954+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1955+ ctx->cumsum_pipelines .emplace (1 , pipeline);
19791956 }
19801957 uint32_t wg_x = ggml_nrows (dst);
19811958 return ggml_backend_webgpu_build (ctx->global_ctx , ctx->param_buf_pool , pipeline, params, entries, wg_x);
@@ -2009,19 +1986,15 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
20091986 };
20101987
20111988 webgpu_pipeline pipeline;
2012- {
2013- // TODO: remove guard once pipeline caches are per-thread
2014- std::lock_guard<std::recursive_mutex> lock (ctx->global_ctx ->mutex );
2015- auto it = ctx->sum_rows_pipelines .find (1 );
2016- if (it != ctx->sum_rows_pipelines .end ()) {
2017- pipeline = it->second ;
2018- } else {
2019- ggml_webgpu_processed_shader processed =
2020- ggml_webgpu_preprocess_generic_shader (ctx->p , wgsl_sum_rows, shader_lib_ctx, " sum_rows" );
2021- pipeline =
2022- ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
2023- ctx->sum_rows_pipelines .emplace (1 , pipeline);
2024- }
1989+ auto it = ctx->sum_rows_pipelines .find (1 );
1990+ if (it != ctx->sum_rows_pipelines .end ()) {
1991+ pipeline = it->second ;
1992+ } else {
1993+ ggml_webgpu_processed_shader processed =
1994+ ggml_webgpu_preprocess_generic_shader (ctx->p , wgsl_sum_rows, shader_lib_ctx, " sum_rows" );
1995+ pipeline =
1996+ ggml_webgpu_create_pipeline (ctx->global_ctx ->device , processed.wgsl .c_str (), processed.variant .c_str ());
1997+ ctx->sum_rows_pipelines .emplace (1 , pipeline);
20251998 }
20261999 uint32_t wg_x = total_sum ? 1 : ggml_nrows (dst);
20272000 return ggml_backend_webgpu_build (ctx->global_ctx , ctx->param_buf_pool , pipeline, params, entries, wg_x);
@@ -3016,10 +2989,10 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
30162989
30172990#ifdef GGML_WEBGPU_GPU_PROFILE
30182991 // Initialize buffer pool for timestamp queries, used for profiling
3019- ctx->webgpu_global_ctx ->timestamp_query_buf_pool .init (ctx-> webgpu_global_ctx -> device , WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
3020- WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
3021- wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
3022- wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2992+ ctx->webgpu_global_ctx ->timestamp_query_buf_pool .init (
2993+ ctx-> webgpu_global_ctx -> device , WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2994+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2995+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
30232996#endif
30242997
30252998 GGML_LOG_INFO (
0 commit comments