Skip to content

Commit 745adbe

Browse files
nikhilJain17 authored and liparetejas committed
Remove pipeline cache mutexes (ggml-org#19195)
* Remove mutex for pipeline caches, since they are now per-thread.
* Add comment
* Run clang-format
* Cleanup
* Run CI again
* Run CI once more
* Run clang-format
1 parent 08b8c17 commit 745adbe

1 file changed

Lines changed: 94 additions & 121 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 94 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,13 @@ struct webgpu_submission_futures {
146146
struct webgpu_buf_pool {
147147
std::vector<webgpu_pool_bufs> free;
148148

149-
std::mutex mutex;
150-
149+
// The pool must be synchronized because
150+
// 1. The memset pool is shared globally by every ggml buffer,
151+
// since allocating a pool per ggml buffer would consume too much memory.
152+
// 2. For the per-thread buffer pools in webgpu_context,
153+
// buffers are allocated and freed in Dawn callbacks,
154+
// which can run on a different thread than the calling thread.
155+
std::mutex mutex;
151156
std::condition_variable cv;
152157

153158
void init(wgpu::Device device,
@@ -266,7 +271,7 @@ struct webgpu_command {
266271
#endif
267272
};
268273

269-
struct webgpu_capabilities_base {
274+
struct webgpu_capabilities {
270275
wgpu::Limits limits;
271276
bool supports_subgroup_matrix = false;
272277

@@ -286,11 +291,11 @@ struct webgpu_global_context_struct {
286291
wgpu::Device device;
287292
wgpu::Queue queue;
288293

289-
webgpu_capabilities_base capabilities;
294+
webgpu_capabilities capabilities;
290295
// Shared buffer to move data from device to host
291-
wgpu::Buffer get_tensor_staging_buf;
296+
wgpu::Buffer get_tensor_staging_buf;
292297
// Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
293-
std::recursive_mutex mutex;
298+
std::recursive_mutex mutex;
294299

295300
webgpu_buf_pool memset_buf_pool;
296301
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
@@ -361,7 +366,6 @@ struct webgpu_context_struct {
361366
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
362367

363368
size_t memset_bytes_per_thread;
364-
365369
};
366370

367371
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -383,9 +387,8 @@ struct ggml_backend_webgpu_device_context {
383387

384388
// Per-thread data required to actually run WebGPU operations in a backend instance
385389
struct ggml_backend_webgpu_context {
386-
webgpu_context webgpu_ctx;
387-
std::once_flag init_once;
388-
std::string name;
390+
webgpu_context webgpu_ctx;
391+
std::string name;
389392
};
390393

391394
// Per-thread data related to buffers
@@ -861,20 +864,15 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
861864
};
862865

863866
webgpu_pipeline pipeline;
864-
{
865-
// TODO: remove guard once pipeline caches are per-thread
866-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
867-
auto it = ctx->pad_pipelines.find(pipeline_key);
868-
if (it != ctx->pad_pipelines.end()) {
869-
pipeline = it->second;
870-
} else {
871-
ggml_webgpu_processed_shader processed =
872-
ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
873-
pipeline =
874-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
875-
pipeline.context = processed.decisions;
876-
ctx->pad_pipelines.emplace(pipeline_key, pipeline);
877-
}
867+
auto it = ctx->pad_pipelines.find(pipeline_key);
868+
if (it != ctx->pad_pipelines.end()) {
869+
pipeline = it->second;
870+
} else {
871+
ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
872+
pipeline =
873+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
874+
pipeline.context = processed.decisions;
875+
ctx->pad_pipelines.emplace(pipeline_key, pipeline);
878876
}
879877

880878
ggml_webgpu_generic_shader_decisions decisions =
@@ -944,20 +942,16 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
944942
};
945943

946944
webgpu_pipeline pipeline;
947-
// TODO: remove guard once pipeline caches are per-thread
948-
{
949-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
950-
auto it = ctx->set_rows_pipelines.find(key);
951-
if (it != ctx->set_rows_pipelines.end()) {
952-
pipeline = it->second;
953-
} else {
954-
ggml_webgpu_processed_shader processed =
955-
ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
956-
pipeline =
957-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
958-
pipeline.context = processed.decisions;
959-
ctx->set_rows_pipelines.emplace(key, pipeline);
960-
}
945+
auto it = ctx->set_rows_pipelines.find(key);
946+
if (it != ctx->set_rows_pipelines.end()) {
947+
pipeline = it->second;
948+
} else {
949+
ggml_webgpu_processed_shader processed =
950+
ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
951+
pipeline =
952+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
953+
pipeline.context = processed.decisions;
954+
ctx->set_rows_pipelines.emplace(key, pipeline);
961955
}
962956

963957
ggml_webgpu_generic_shader_decisions decisions =
@@ -1261,29 +1255,25 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
12611255
};
12621256

12631257
webgpu_pipeline pipeline;
1264-
// TODO: remove guard once pipeline caches are per-thread
1265-
{
1266-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
1267-
auto it = ctx->flash_attn_pipelines.find(key);
1268-
if (it != ctx->flash_attn_pipelines.end()) {
1269-
pipeline = it->second;
1270-
} else {
1271-
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
1272-
.key = key,
1273-
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1274-
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1275-
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1276-
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1277-
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
1278-
};
1279-
1280-
ggml_webgpu_processed_shader processed =
1281-
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
1282-
pipeline =
1283-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1284-
pipeline.context = processed.decisions;
1285-
ctx->flash_attn_pipelines.emplace(key, pipeline);
1286-
}
1258+
auto it = ctx->flash_attn_pipelines.find(key);
1259+
if (it != ctx->flash_attn_pipelines.end()) {
1260+
pipeline = it->second;
1261+
} else {
1262+
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
1263+
.key = key,
1264+
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1265+
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1266+
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1267+
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1268+
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
1269+
};
1270+
1271+
ggml_webgpu_processed_shader processed =
1272+
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
1273+
pipeline =
1274+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1275+
pipeline.context = processed.decisions;
1276+
ctx->flash_attn_pipelines.emplace(key, pipeline);
12871277
}
12881278

12891279
ggml_webgpu_flash_attn_shader_decisions decisions =
@@ -1308,20 +1298,16 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
13081298
};
13091299

13101300
webgpu_pipeline pipeline;
1311-
{
1312-
// TODO: remove guard once pipeline caches are per-thread
1313-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
1314-
auto it = ctx->unary_pipelines.find(pipeline_key);
1315-
if (it != ctx->unary_pipelines.end()) {
1316-
pipeline = it->second;
1317-
} else {
1318-
ggml_webgpu_processed_shader processed =
1319-
ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
1320-
pipeline =
1321-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1322-
pipeline.context = processed.decisions;
1323-
ctx->unary_pipelines.emplace(pipeline_key, pipeline);
1324-
}
1301+
auto it = ctx->unary_pipelines.find(pipeline_key);
1302+
if (it != ctx->unary_pipelines.end()) {
1303+
pipeline = it->second;
1304+
} else {
1305+
ggml_webgpu_processed_shader processed =
1306+
ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
1307+
pipeline =
1308+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1309+
pipeline.context = processed.decisions;
1310+
ctx->unary_pipelines.emplace(pipeline_key, pipeline);
13251311
}
13261312

13271313
ggml_webgpu_generic_shader_decisions decisions =
@@ -1743,19 +1729,15 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
17431729
};
17441730

17451731
webgpu_pipeline pipeline;
1746-
{
1747-
// TODO: remove guard once pipeline caches are per-thread
1748-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
1749-
auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
1750-
if (it != ctx->argmax_pipelines.end()) {
1751-
pipeline = it->second;
1752-
} else {
1753-
ggml_webgpu_processed_shader processed =
1754-
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
1755-
pipeline =
1756-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1757-
ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
1758-
}
1732+
auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
1733+
if (it != ctx->argmax_pipelines.end()) {
1734+
pipeline = it->second;
1735+
} else {
1736+
ggml_webgpu_processed_shader processed =
1737+
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
1738+
pipeline =
1739+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1740+
ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
17591741
}
17601742
uint32_t wg_x = ggml_nelements(dst);
17611743
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@@ -1772,9 +1754,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
17721754
.order = order
17731755
};
17741756

1775-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
1776-
webgpu_pipeline argsort_pipeline;
1777-
auto it = ctx->argsort_pipelines.find(order);
1757+
webgpu_pipeline argsort_pipeline;
1758+
auto it = ctx->argsort_pipelines.find(order);
17781759
if (it != ctx->argsort_pipelines.end()) {
17791760
argsort_pipeline = it->second;
17801761
} else {
@@ -1963,19 +1944,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
19631944
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
19641945
};
19651946
webgpu_pipeline pipeline;
1966-
// TODO: remove guard once pipeline caches are per-thread
1967-
{
1968-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
1969-
auto it = ctx->cumsum_pipelines.find(1);
1970-
if (it != ctx->cumsum_pipelines.end()) {
1971-
pipeline = it->second;
1972-
} else {
1973-
ggml_webgpu_processed_shader processed =
1974-
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
1975-
pipeline =
1976-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1977-
ctx->cumsum_pipelines.emplace(1, pipeline);
1978-
}
1947+
auto it = ctx->cumsum_pipelines.find(1);
1948+
if (it != ctx->cumsum_pipelines.end()) {
1949+
pipeline = it->second;
1950+
} else {
1951+
ggml_webgpu_processed_shader processed =
1952+
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
1953+
pipeline =
1954+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1955+
ctx->cumsum_pipelines.emplace(1, pipeline);
19791956
}
19801957
uint32_t wg_x = ggml_nrows(dst);
19811958
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@@ -2009,19 +1986,15 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
20091986
};
20101987

20111988
webgpu_pipeline pipeline;
2012-
{
2013-
// TODO: remove guard once pipeline caches are per-thread
2014-
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
2015-
auto it = ctx->sum_rows_pipelines.find(1);
2016-
if (it != ctx->sum_rows_pipelines.end()) {
2017-
pipeline = it->second;
2018-
} else {
2019-
ggml_webgpu_processed_shader processed =
2020-
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
2021-
pipeline =
2022-
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
2023-
ctx->sum_rows_pipelines.emplace(1, pipeline);
2024-
}
1989+
auto it = ctx->sum_rows_pipelines.find(1);
1990+
if (it != ctx->sum_rows_pipelines.end()) {
1991+
pipeline = it->second;
1992+
} else {
1993+
ggml_webgpu_processed_shader processed =
1994+
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
1995+
pipeline =
1996+
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1997+
ctx->sum_rows_pipelines.emplace(1, pipeline);
20251998
}
20261999
uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
20272000
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@@ -3016,10 +2989,10 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
30162989

30172990
#ifdef GGML_WEBGPU_GPU_PROFILE
30182991
// Initialize buffer pool for timestamp queries, used for profiling
3019-
ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
3020-
WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
3021-
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
3022-
wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2992+
ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
2993+
ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2994+
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2995+
wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
30232996
#endif
30242997

30252998
GGML_LOG_INFO(

0 commit comments

Comments (0)