Skip to content

Commit 55ac090

Browse files
authored
vulkan: don't hold the device mutex while compiling pipelines (ggml-org#23641)
* vulkan: don't hold the device mutex while compiling pipelines We need to hold a lock while we traverse all pipelines and lazily initialize them, but we don't need to hold it while the pipeline is being compiled. And it doesn't need to be the same lock as the device mutex. We call load_shaders each time a pipeline is needed, so we only need to compile that one pipeline (and, for example, don't want to end up compiling a pipeline that another thread should be compiling). * remove 'needed'
1 parent bef69f1 commit 55ac090

1 file changed

Lines changed: 99 additions & 45 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 99 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
6565
#include <shared_mutex>
6666
#include <mutex>
6767
#include <future>
68+
#include <condition_variable>
6869
#include <thread>
6970

7071
#if defined(_MSC_VER)
@@ -159,8 +160,9 @@ struct vk_pipeline_struct {
159160
uint32_t align;
160161
// true if fields have been set by ggml_vk_create_pipeline
161162
bool initialized {};
162-
// set to true to request the pipeline is compiled
163-
std::atomic<bool> needed {};
163+
// true while a compile is in flight, used to dedupe concurrent claims.
164+
// Protected by device->compile_mutex.
165+
bool compile_pending {};
164166
// set to true when the shader has been compiled
165167
std::atomic<bool> compiled {};
166168
// number of registers used, extracted from pipeline executable properties
@@ -621,6 +623,13 @@ struct vk_device_struct {
621623
std::recursive_mutex mutex;
622624
mutable std::shared_mutex pinned_memory_mutex;
623625

626+
// Guards compile_pending, all_pipelines, and the dynamic pipeline maps
627+
// (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile
628+
// runs with no lock held, so different pipelines can compile in parallel.
629+
// Lock order is device->mutex -> compile_mutex, never the reverse.
630+
std::mutex compile_mutex;
631+
std::condition_variable compile_cv;
632+
624633
vk::PhysicalDevice physical_device;
625634
vk::PhysicalDeviceProperties properties;
626635
std::string name;
@@ -1729,7 +1738,7 @@ struct ggml_vk_garbage_collector {
17291738
};
17301739

17311740
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
1732-
static void ggml_vk_load_shaders(vk_device& device);
1741+
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr);
17331742
static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
17341743

17351744
static bool vk_memory_logger_enabled = false;
@@ -2196,11 +2205,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
21962205
ctx->device->device.resetFences({ ctx->fence });
21972206
}
21982207

2199-
// variables to track number of compiles in progress
2200-
static uint32_t compile_count = 0;
2201-
static std::mutex compile_count_mutex;
2202-
static std::condition_variable compile_count_cond;
2203-
22042208
static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
22052209
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
22062210
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
@@ -2495,7 +2499,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
24952499
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
24962500
throw e;
24972501
}
2498-
pipeline->compiled = true;
24992502

25002503
if (vk_instance.debug_utils_support) {
25012504
vk::DebugUtilsObjectNameInfoEXT duoni;
@@ -2544,14 +2547,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
25442547
}
25452548
}
25462549

2547-
device->all_pipelines.push_back(pipeline);
2548-
25492550
{
2550-
std::lock_guard<std::mutex> guard(compile_count_mutex);
2551-
assert(compile_count > 0);
2552-
compile_count--;
2551+
std::lock_guard<std::mutex> guard(device->compile_mutex);
2552+
device->all_pipelines.push_back(pipeline);
2553+
pipeline->compiled = true;
2554+
pipeline->compile_pending = false;
25532555
}
2554-
compile_count_cond.notify_all();
2556+
device->compile_cv.notify_all();
25552557
}
25562558

25572559
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -2567,8 +2569,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx,
25672569
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
25682570
ctx->pipeline_descriptor_set_requirements += n;
25692571
if (!pipeline->compiled) {
2570-
pipeline->needed = true;
2571-
ggml_vk_load_shaders(ctx->device);
2572+
ggml_vk_load_shaders(ctx->device, pipeline);
25722573
}
25732574
ggml_pipeline_allocate_descriptor_sets(ctx);
25742575
}
@@ -3567,10 +3568,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type
35673568
#endif
35683569
}
35693570

3570-
static void ggml_vk_load_shaders(vk_device& device) {
3571+
// load_shaders walks the pipeline list under compile_mutex and either claims
3572+
// the requested pipeline for compilation or, if another thread is already
3573+
// compiling it, drops the lock and waits on compile_cv. Compiles themselves
3574+
// run unlocked.
3575+
struct CompileTask {
3576+
vk_pipeline pipeline;
3577+
size_t spv_size;
3578+
const void * spv_data;
3579+
std::string entrypoint;
3580+
uint32_t parameter_count;
3581+
std::array<uint32_t, 3> wg_denoms;
3582+
std::vector<uint32_t> specialization_constants;
3583+
bool disable_robustness;
3584+
bool require_full_subgroups;
3585+
uint32_t required_subgroup_size;
3586+
};
3587+
3588+
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
35713589
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
35723590

3573-
std::lock_guard<std::recursive_mutex> guard(device->mutex);
35743591
// some shaders have a minimum subgroup size
35753592
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
35763593
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
@@ -3600,6 +3617,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
36003617
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
36013618

36023619
uint32_t l_align, m_align, s_align;
3620+
3621+
vk_pipeline wait_pipeline;
3622+
CompileTask claimed_task {};
3623+
bool has_claimed_task = false;
3624+
3625+
// The rest of the walk reads and writes shared device state, so hold the
3626+
// lock until we're done deciding what to compile.
3627+
std::unique_lock<std::mutex> compile_lock(device->compile_mutex);
3628+
36033629
if (device->coopmat2) {
36043630
// spec constants and tile sizes for non-quant matmul/matmul_id
36053631
l_warptile = { 256, 128, 256, 64, 1 };
@@ -3785,7 +3811,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
37853811
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
37863812
}
37873813

3788-
std::vector<std::future<void>> compiles;
37893814
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
37903815
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
37913816
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@@ -3819,23 +3844,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
38193844
#endif
38203845
}
38213846

3822-
if (!pipeline->needed || pipeline->compiled) {
3847+
// We only care about the pipeline this call asked for; the rest
3848+
// (including the 64-bit indexing variant) are handled by their
3849+
// own request_descriptor_sets / load_shaders calls.
3850+
if (pipeline.get() != requested.get()) {
38233851
continue;
38243852
}
3825-
// TODO: We're no longer benefitting from the async compiles (shaders are
3826-
// compiled individually, as needed) and this complexity can be removed.
3827-
{
3828-
// wait until fewer than N compiles are in progress
3829-
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
3830-
std::unique_lock<std::mutex> guard(compile_count_mutex);
3831-
while (compile_count >= N) {
3832-
compile_count_cond.wait(guard);
3833-
}
3834-
compile_count++;
3853+
3854+
if (pipeline->compiled) {
3855+
continue;
38353856
}
38363857

3837-
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
3838-
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
3858+
wait_pipeline = pipeline;
3859+
3860+
if (!pipeline->compile_pending) {
3861+
pipeline->compile_pending = true;
3862+
claimed_task.pipeline = pipeline;
3863+
claimed_task.spv_size = spv_size;
3864+
claimed_task.spv_data = spv_data;
3865+
claimed_task.entrypoint = entrypoint;
3866+
claimed_task.parameter_count = parameter_count;
3867+
claimed_task.wg_denoms = wg_denoms;
3868+
claimed_task.specialization_constants = specialization_constants;
3869+
claimed_task.disable_robustness = disable_robustness;
3870+
claimed_task.require_full_subgroups = require_full_subgroups;
3871+
claimed_task.required_subgroup_size = required_subgroup_size;
3872+
has_claimed_task = true;
3873+
}
38393874
}
38403875
};
38413876

@@ -5332,8 +5367,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
53325367
}
53335368
}
53345369

5335-
for (auto &c : compiles) {
5336-
c.wait();
5370+
// Drop compile_mutex so other threads can walk while we compile.
5371+
compile_lock.unlock();
5372+
5373+
// Compile what we claimed; create_pipeline_func reacquires compile_mutex
5374+
// at the end to flip compile_pending/compiled and notify waiters.
5375+
if (has_claimed_task) {
5376+
auto & task = claimed_task;
5377+
ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data,
5378+
task.entrypoint, task.parameter_count, task.wg_denoms,
5379+
task.specialization_constants, task.disable_robustness,
5380+
task.require_full_subgroups, task.required_subgroup_size);
5381+
}
5382+
5383+
// Another thread may be compiling the pipeline we need; block on it here.
5384+
if (wait_pipeline) {
5385+
std::unique_lock<std::mutex> wait_lock(device->compile_mutex);
5386+
device->compile_cv.wait(wait_lock, [&] {
5387+
return wait_pipeline->compiled.load();
5388+
});
53375389
}
53385390
}
53395391

@@ -9722,7 +9774,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
97229774
vk_pipeline pipeline = nullptr;
97239775

97249776
{
9725-
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
9777+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
97269778
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
97279779
auto it = pipelines.find(fa_pipeline_state);
97289780
if (it != pipelines.end()) {
@@ -9786,13 +9838,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
97869838

97879839
vk_pipeline pipeline_fa_mask_opt = nullptr;
97889840
if (use_mask_opt) {
9789-
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
9790-
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
9791-
auto it = pipelines.find({Br, Bc});
9792-
if (it != pipelines.end()) {
9793-
pipeline_fa_mask_opt = it->second;
9794-
} else {
9795-
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
9841+
{
9842+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
9843+
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
9844+
auto it = pipelines.find({Br, Bc});
9845+
if (it != pipelines.end()) {
9846+
pipeline_fa_mask_opt = it->second;
9847+
} else {
9848+
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
9849+
}
97969850
}
97979851
assert(pipeline_fa_mask_opt);
97989852
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
@@ -10326,7 +10380,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1032610380
vk_pipeline pipeline = nullptr;
1032710381

1032810382
{
10329-
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
10383+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
1033010384
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
1033110385
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
1033210386
pipeline = it->second;
@@ -10485,7 +10539,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1048510539
vk_pipeline pipeline = nullptr;
1048610540

1048710541
{
10488-
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
10542+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
1048910543
auto it = pipelines->find(conv2d_pipeline_state);
1049010544
if (it != pipelines->end()) {
1049110545
pipeline = it->second;

0 commit comments

Comments
 (0)