@@ -65,6 +65,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
6565#include <shared_mutex>
6666#include <mutex>
6767#include <future>
68+ #include <condition_variable>
6869#include <thread>
6970
7071#if defined(_MSC_VER)
@@ -159,8 +160,9 @@ struct vk_pipeline_struct {
159160 uint32_t align;
160161 // true if fields have been set by ggml_vk_create_pipeline
161162 bool initialized {};
162- // set to true to request the pipeline is compiled
163- std::atomic<bool> needed {};
163+ // true while a compile is in flight, used to dedupe concurrent claims.
164+ // Protected by device->compile_mutex.
165+ bool compile_pending {};
164166 // set to true when the shader has been compiled
165167 std::atomic<bool> compiled {};
166168 // number of registers used, extracted from pipeline executable properties
@@ -621,6 +623,13 @@ struct vk_device_struct {
621623 std::recursive_mutex mutex;
622624 mutable std::shared_mutex pinned_memory_mutex;
623625
626+ // Guards compile_pending, all_pipelines, and the dynamic pipeline maps
627+ // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile
628+ // runs with no lock held, so different pipelines can compile in parallel.
629+ // Lock order is device->mutex -> compile_mutex, never the reverse.
630+ std::mutex compile_mutex;
631+ std::condition_variable compile_cv;
632+
624633 vk::PhysicalDevice physical_device;
625634 vk::PhysicalDeviceProperties properties;
626635 std::string name;
@@ -1729,7 +1738,7 @@ struct ggml_vk_garbage_collector {
17291738};
17301739
17311740static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
1732- static void ggml_vk_load_shaders(vk_device& device);
1741+ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr );
17331742static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
17341743
17351744static bool vk_memory_logger_enabled = false;
@@ -2196,11 +2205,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
21962205 ctx->device->device.resetFences({ ctx->fence });
21972206}
21982207
2199- // variables to track number of compiles in progress
2200- static uint32_t compile_count = 0;
2201- static std::mutex compile_count_mutex;
2202- static std::condition_variable compile_count_cond;
2203-
22042208static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
22052209static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
22062210static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
@@ -2495,7 +2499,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
24952499 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
24962500 throw e;
24972501 }
2498- pipeline->compiled = true;
24992502
25002503 if (vk_instance.debug_utils_support) {
25012504 vk::DebugUtilsObjectNameInfoEXT duoni;
@@ -2544,14 +2547,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
25442547 }
25452548 }
25462549
2547- device->all_pipelines.push_back(pipeline);
2548-
25492550 {
2550- std::lock_guard<std::mutex> guard(compile_count_mutex);
2551- assert(compile_count > 0);
2552- compile_count--;
2551+ std::lock_guard<std::mutex> guard(device->compile_mutex);
2552+ device->all_pipelines.push_back(pipeline);
2553+ pipeline->compiled = true;
2554+ pipeline->compile_pending = false;
25532555 }
2554- compile_count_cond .notify_all();
2556+ device->compile_cv .notify_all();
25552557}
25562558
25572559static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -2567,8 +2569,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx,
25672569 VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
25682570 ctx->pipeline_descriptor_set_requirements += n;
25692571 if (!pipeline->compiled) {
2570- pipeline->needed = true;
2571- ggml_vk_load_shaders(ctx->device);
2572+ ggml_vk_load_shaders(ctx->device, pipeline);
25722573 }
25732574 ggml_pipeline_allocate_descriptor_sets(ctx);
25742575}
@@ -3567,10 +3568,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type
35673568#endif
35683569}
35693570
3570- static void ggml_vk_load_shaders(vk_device& device) {
3571+ // load_shaders walks the pipeline list under compile_mutex and either claims
3572+ // the requested pipeline for compilation or, if another thread is already
3573+ // compiling it, drops the lock and waits on compile_cv. Compiles themselves
3574+ // run unlocked.
3575+ struct CompileTask {
3576+ vk_pipeline pipeline;
3577+ size_t spv_size;
3578+ const void * spv_data;
3579+ std::string entrypoint;
3580+ uint32_t parameter_count;
3581+ std::array<uint32_t, 3> wg_denoms;
3582+ std::vector<uint32_t> specialization_constants;
3583+ bool disable_robustness;
3584+ bool require_full_subgroups;
3585+ uint32_t required_subgroup_size;
3586+ };
3587+
3588+ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
35713589 VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
35723590
3573- std::lock_guard<std::recursive_mutex> guard(device->mutex);
35743591 // some shaders have a minimum subgroup size
35753592 const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
35763593 const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
@@ -3600,6 +3617,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
36003617 l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
36013618
36023619 uint32_t l_align, m_align, s_align;
3620+
3621+ vk_pipeline wait_pipeline;
3622+ CompileTask claimed_task {};
3623+ bool has_claimed_task = false;
3624+
3625+ // The rest of the walk reads and writes shared device state, so hold the
3626+ // lock until we're done deciding what to compile.
3627+ std::unique_lock<std::mutex> compile_lock(device->compile_mutex);
3628+
36033629 if (device->coopmat2) {
36043630 // spec constants and tile sizes for non-quant matmul/matmul_id
36053631 l_warptile = { 256, 128, 256, 64, 1 };
@@ -3785,7 +3811,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
37853811 device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
37863812 }
37873813
3788- std::vector<std::future<void>> compiles;
37893814 auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
37903815 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
37913816 uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@@ -3819,23 +3844,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
38193844#endif
38203845 }
38213846
3822- if (!pipeline->needed || pipeline->compiled) {
3847+ // We only care about the pipeline this call asked for; the rest
3848+ // (including the 64-bit indexing variant) are handled by their
3849+ // own request_descriptor_sets / load_shaders calls.
3850+ if (pipeline.get() != requested.get()) {
38233851 continue;
38243852 }
3825- // TODO: We're no longer benefitting from the async compiles (shaders are
3826- // compiled individually, as needed) and this complexity can be removed.
3827- {
3828- // wait until fewer than N compiles are in progress
3829- uint32_t N = std::max(1u, std::thread::hardware_concurrency());
3830- std::unique_lock<std::mutex> guard(compile_count_mutex);
3831- while (compile_count >= N) {
3832- compile_count_cond.wait(guard);
3833- }
3834- compile_count++;
3853+
3854+ if (pipeline->compiled) {
3855+ continue;
38353856 }
38363857
3837- compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
3838- parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
3858+ wait_pipeline = pipeline;
3859+
3860+ if (!pipeline->compile_pending) {
3861+ pipeline->compile_pending = true;
3862+ claimed_task.pipeline = pipeline;
3863+ claimed_task.spv_size = spv_size;
3864+ claimed_task.spv_data = spv_data;
3865+ claimed_task.entrypoint = entrypoint;
3866+ claimed_task.parameter_count = parameter_count;
3867+ claimed_task.wg_denoms = wg_denoms;
3868+ claimed_task.specialization_constants = specialization_constants;
3869+ claimed_task.disable_robustness = disable_robustness;
3870+ claimed_task.require_full_subgroups = require_full_subgroups;
3871+ claimed_task.required_subgroup_size = required_subgroup_size;
3872+ has_claimed_task = true;
3873+ }
38393874 }
38403875 };
38413876
@@ -5332,8 +5367,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
53325367 }
53335368 }
53345369
5335- for (auto &c : compiles) {
5336- c.wait();
5370+ // Drop compile_mutex so other threads can walk while we compile.
5371+ compile_lock.unlock();
5372+
5373+ // Compile what we claimed; create_pipeline_func reacquires compile_mutex
5374+ // at the end to flip compile_pending/compiled and notify waiters.
5375+ if (has_claimed_task) {
5376+ auto & task = claimed_task;
5377+ ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data,
5378+ task.entrypoint, task.parameter_count, task.wg_denoms,
5379+ task.specialization_constants, task.disable_robustness,
5380+ task.require_full_subgroups, task.required_subgroup_size);
5381+ }
5382+
5383+ // Another thread may be compiling the pipeline we need; block on it here.
5384+ if (wait_pipeline) {
5385+ std::unique_lock<std::mutex> wait_lock(device->compile_mutex);
5386+ device->compile_cv.wait(wait_lock, [&] {
5387+ return wait_pipeline->compiled.load();
5388+ });
53375389 }
53385390}
53395391
@@ -9722,7 +9774,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
97229774 vk_pipeline pipeline = nullptr;
97239775
97249776 {
9725- std::lock_guard<std::recursive_mutex > guard(ctx->device->mutex );
9777+ std::lock_guard<std::mutex > guard(ctx->device->compile_mutex );
97269778 auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
97279779 auto it = pipelines.find(fa_pipeline_state);
97289780 if (it != pipelines.end()) {
@@ -9786,13 +9838,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
97869838
97879839 vk_pipeline pipeline_fa_mask_opt = nullptr;
97889840 if (use_mask_opt) {
9789- std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
9790- auto &pipelines = ctx->device->pipeline_fa_mask_opt;
9791- auto it = pipelines.find({Br, Bc});
9792- if (it != pipelines.end()) {
9793- pipeline_fa_mask_opt = it->second;
9794- } else {
9795- pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
9841+ {
9842+ std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
9843+ auto &pipelines = ctx->device->pipeline_fa_mask_opt;
9844+ auto it = pipelines.find({Br, Bc});
9845+ if (it != pipelines.end()) {
9846+ pipeline_fa_mask_opt = it->second;
9847+ } else {
9848+ pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
9849+ }
97969850 }
97979851 assert(pipeline_fa_mask_opt);
97989852 ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
@@ -10326,7 +10380,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1032610380 vk_pipeline pipeline = nullptr;
1032710381
1032810382 {
10329- std::lock_guard<std::recursive_mutex > guard(ctx->device->mutex );
10383+ std::lock_guard<std::mutex > guard(ctx->device->compile_mutex );
1033010384 auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
1033110385 if (it != ctx->device->pipeline_solve_tri_f32.end()) {
1033210386 pipeline = it->second;
@@ -10485,7 +10539,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1048510539 vk_pipeline pipeline = nullptr;
1048610540
1048710541 {
10488- std::lock_guard<std::recursive_mutex > guard(ctx->device->mutex );
10542+ std::lock_guard<std::mutex > guard(ctx->device->compile_mutex );
1048910543 auto it = pipelines->find(conv2d_pipeline_state);
1049010544 if (it != pipelines->end()) {
1049110545 pipeline = it->second;
0 commit comments