Skip to content

Commit eaa2054

Browse files
committed
Vulkan: expose allocator lifecycle
When an embedder owns the Vulkan instance, device, and queue, it still needs a Halide allocator for shader modules, staging buffers, and other runtime allocations. Add acquire/release helpers for the opaque Vulkan allocator so embedders can store it with their context and return it from later acquire calls. The helpers reload Vulkan function pointers for the supplied context, validate that a reused allocator still matches the device, and release only Halide-owned allocator and shader-module state. Key the Vulkan compilation cache by allocator instead of VkDevice so external allocators have independent shader-module lifetimes even when they share the same device handle.
1 parent 95991c9 commit eaa2054

5 files changed

Lines changed: 177 additions & 27 deletions

File tree

src/runtime/HalideRuntimeVulkan.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,41 @@ extern int halide_vulkan_release_context(void *user_context,
105105
VkDevice device,
106106
VkQueue queue,
107107
VkDebugUtilsMessengerEXT messenger);
108+
109+
/** Ensure a Halide Vulkan memory allocator exists for an externally-managed
110+
* Vulkan context. Intended for embedders that override
111+
* halide_vulkan_acquire_context()/halide_vulkan_release_context().
112+
*
113+
* The embedder should store the returned allocator with the same object that
114+
* owns the external context, return it from later acquire-context calls for
115+
* that context, and release it when that external context is torn down.
116+
*
117+
* This call refreshes Halide's Vulkan dispatch tables for the supplied
118+
* instance/device. If `*allocator` is null, a new allocator bound to
119+
* `device`/`physical_device` is created and stored back. If `*allocator` is
120+
* non-null, it must already be bound to the supplied device.
121+
*/
122+
extern int halide_vulkan_acquire_memory_allocator(void *user_context,
123+
struct halide_vulkan_memory_allocator **allocator,
124+
VkInstance instance,
125+
VkDevice device,
126+
VkPhysicalDevice physical_device);
127+
128+
/** Destroy a Halide Vulkan memory allocator created for an externally-managed
129+
* Vulkan context after the embedder has ensured no in-flight Halide work is
130+
* using it. This only releases Halide-owned allocator and shader-module state;
131+
* it does not destroy the Vulkan instance, device, queue, or any
132+
* embedder-owned debug messenger.
133+
*
134+
* This call refreshes Halide's Vulkan dispatch tables for the supplied
135+
* instance/device. The supplied device and physical_device must match the
136+
* allocator's context.
137+
*/
138+
extern int halide_vulkan_release_memory_allocator(void *user_context,
139+
struct halide_vulkan_memory_allocator *allocator,
140+
VkInstance instance,
141+
VkDevice device,
142+
VkPhysicalDevice physical_device);
108143
// --
109144

110145
// Override the default allocation callbacks (default uses Vulkan runtime implementation)

src/runtime/gpu_context_common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class GPUCompilationCache {
127127
}
128128

129129
for (int i = 0; i < (1 << log2_compilations_size); i++) {
130-
if (compilations[i].kernel_id > kInvalidId &&
130+
if (compilations[i].kernel_id > kDeletedId &&
131131
(all || (compilations[i].context == context)) &&
132132
compilations[i].use_count == 0) {
133133
debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state

src/runtime/runtime_api.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,10 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
213213
(void *)&halide_d3d12compute_release_context,
214214
(void *)&halide_d3d12compute_run,
215215
(void *)&halide_vulkan_acquire_context,
216+
(void *)&halide_vulkan_acquire_memory_allocator,
216217
(void *)&halide_vulkan_device_interface,
217218
(void *)&halide_vulkan_initialize_kernels,
219+
(void *)&halide_vulkan_release_memory_allocator,
218220
(void *)&halide_vulkan_release_context,
219221
(void *)&halide_vulkan_run,
220222
(void *)&halide_webgpu_device_interface,

src/runtime/vulkan.cpp

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,42 @@ using namespace Halide::Runtime::Internal::Vulkan;
1313

1414
// --------------------------------------------------------------------------
1515

16+
namespace Halide {
17+
namespace Runtime {
18+
namespace Internal {
19+
namespace Vulkan {
20+
21+
ALWAYS_INLINE int vk_load_external_context_functions(void *user_context, VkInstance instance, VkDevice device) {
22+
if (vkGetInstanceProcAddr == nullptr) {
23+
vk_load_vulkan_loader_functions(user_context);
24+
if (vkGetInstanceProcAddr == nullptr) {
25+
error(user_context) << "Vulkan: Failed to resolve loader functions for external context!\n";
26+
return halide_error_code_symbol_not_found;
27+
}
28+
}
29+
30+
vk_load_vulkan_instance_functions(user_context, instance);
31+
if (vkGetPhysicalDeviceProperties == nullptr || vkGetDeviceProcAddr == nullptr) {
32+
error(user_context) << "Vulkan: Failed to resolve instance functions for external context!\n";
33+
return halide_error_code_symbol_not_found;
34+
}
35+
36+
vk_load_vulkan_device_functions(user_context, device);
37+
if (vkCreateBuffer == nullptr || vkAllocateMemory == nullptr) {
38+
error(user_context) << "Vulkan: Failed to resolve device functions for external context!\n";
39+
return halide_error_code_symbol_not_found;
40+
}
41+
42+
return halide_error_code_success;
43+
}
44+
45+
} // namespace Vulkan
46+
} // namespace Internal
47+
} // namespace Runtime
48+
} // namespace Halide
49+
50+
// --------------------------------------------------------------------------
51+
1652
extern "C" {
1753

1854
// --------------------------------------------------------------------------
@@ -79,6 +115,82 @@ WEAK int halide_vulkan_release_context(void *user_context, VkInstance instance,
79115
return halide_error_code_success;
80116
}
81117

118+
WEAK int halide_vulkan_acquire_memory_allocator(void *user_context,
119+
halide_vulkan_memory_allocator **allocator,
120+
VkInstance instance,
121+
VkDevice device,
122+
VkPhysicalDevice physical_device) {
123+
if (allocator == nullptr) {
124+
error(user_context) << "Vulkan: allocator output pointer is null!\n";
125+
return halide_error_code_buffer_argument_is_null;
126+
}
127+
if (instance == VK_NULL_HANDLE || device == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE) {
128+
error(user_context) << "Vulkan: invalid external context handles for allocator acquisition!\n";
129+
return halide_error_code_device_interface_no_device;
130+
}
131+
132+
int error_code = vk_load_external_context_functions(user_context, instance, device);
133+
if (error_code != halide_error_code_success) {
134+
return error_code;
135+
}
136+
137+
VulkanMemoryAllocator *runtime_allocator =
138+
reinterpret_cast<VulkanMemoryAllocator *>(*allocator);
139+
if (runtime_allocator != nullptr) {
140+
if (runtime_allocator->current_device() != device ||
141+
runtime_allocator->current_physical_device() != physical_device) {
142+
error(user_context) << "Vulkan: external allocator does not match supplied device handles!\n";
143+
return halide_error_code_internal_error;
144+
}
145+
return halide_error_code_success;
146+
}
147+
148+
const VkAllocationCallbacks *alloc_callbacks =
149+
halide_vulkan_get_allocation_callbacks(user_context);
150+
runtime_allocator =
151+
vk_create_memory_allocator(user_context, device, physical_device, alloc_callbacks);
152+
if (runtime_allocator == nullptr) {
153+
error(user_context) << "Vulkan: Failed to create memory allocator for external context!\n";
154+
return halide_error_code_out_of_memory;
155+
}
156+
157+
*allocator = reinterpret_cast<halide_vulkan_memory_allocator *>(runtime_allocator);
158+
return halide_error_code_success;
159+
}
160+
161+
WEAK int halide_vulkan_release_memory_allocator(void *user_context,
162+
halide_vulkan_memory_allocator *allocator,
163+
VkInstance instance,
164+
VkDevice device,
165+
VkPhysicalDevice physical_device) {
166+
VulkanMemoryAllocator *runtime_allocator =
167+
reinterpret_cast<VulkanMemoryAllocator *>(allocator);
168+
if (runtime_allocator == nullptr) {
169+
return halide_error_code_success;
170+
}
171+
if (instance == VK_NULL_HANDLE || device == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE) {
172+
error(user_context) << "Vulkan: invalid external context handles for allocator release!\n";
173+
return halide_error_code_device_interface_no_device;
174+
}
175+
if (runtime_allocator->current_device() != device ||
176+
runtime_allocator->current_physical_device() != physical_device) {
177+
error(user_context) << "Vulkan: external allocator does not match supplied device handles during release!\n";
178+
return halide_error_code_internal_error;
179+
}
180+
181+
int error_code = vk_load_external_context_functions(user_context, instance, device);
182+
if (error_code != halide_error_code_success) {
183+
return error_code;
184+
}
185+
if (vkDestroyShaderModule == nullptr || vkFreeMemory == nullptr) {
186+
error(user_context) << "Vulkan: Failed to resolve device functions for external allocator release!\n";
187+
return halide_error_code_symbol_not_found;
188+
}
189+
190+
vk_destroy_shader_modules(user_context, runtime_allocator);
191+
return vk_destroy_memory_allocator(user_context, runtime_allocator);
192+
}
193+
82194
WEAK bool halide_vulkan_is_initialized() {
83195
halide_mutex_lock(&thread_lock);
84196
bool is_initialized = (cached_instance != nullptr) && (cached_device != nullptr);
@@ -159,7 +271,7 @@ WEAK int halide_vulkan_initialize_kernels(void *user_context, void **state_ptr,
159271
debug(user_context) << "halide_vulkan_initialize_kernels got compilation_cache mutex.\n";
160272

161273
VulkanCompilationCacheEntry *cache_entry = nullptr;
162-
if (!compilation_cache.kernel_state_setup(user_context, state_ptr, ctx.device, cache_entry,
274+
if (!compilation_cache.kernel_state_setup(user_context, state_ptr, ctx.allocator, cache_entry,
163275
Halide::Runtime::Internal::Vulkan::vk_compile_kernel_module,
164276
user_context, ctx.allocator, src, size)) {
165277
error(user_context) << "Vulkan: Failed to setup compilation cache!\n";
@@ -185,7 +297,7 @@ WEAK void halide_vulkan_finalize_kernels(void *user_context, void *state_ptr) {
185297

186298
VulkanContext ctx(user_context);
187299
if (ctx.error == halide_error_code_success) {
188-
compilation_cache.release_hold(user_context, ctx.device, state_ptr);
300+
compilation_cache.release_hold(user_context, ctx.allocator, state_ptr);
189301
}
190302

191303
#ifdef DEBUG_RUNTIME
@@ -1151,7 +1263,7 @@ WEAK int halide_vulkan_run(void *user_context,
11511263

11521264
// 1. Get the shader module cache entry
11531265
VulkanCompilationCacheEntry *cache_entry = nullptr;
1154-
bool found = compilation_cache.lookup(ctx.device, state_ptr, cache_entry);
1266+
bool found = compilation_cache.lookup(ctx.allocator, state_ptr, cache_entry);
11551267
if (!found || (cache_entry == nullptr)) {
11561268
error(user_context) << "Vulkan: Failed to locate shader module! Unable to proceed!\n";
11571269
return halide_error_code_internal_error;

src/runtime/vulkan_resources.h

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ struct VulkanCompilationCacheEntry {
7272
uint32_t module_count = 0;
7373
};
7474

75-
WEAK Halide::Internal::GPUCompilationCache<VkDevice, VulkanCompilationCacheEntry *> compilation_cache;
75+
WEAK Halide::Internal::GPUCompilationCache<VulkanMemoryAllocator *, VulkanCompilationCacheEntry *> compilation_cache;
7676

7777
// --------------------------------------------------------------------------
7878

@@ -1665,29 +1665,20 @@ void vk_destroy_compiled_shader_module(VulkanCompiledShaderModule *shader_module
16651665
return;
16661666
}
16671667

1668-
if (shader_module->descriptor_set_layouts) {
1669-
for (uint32_t n = 0; n < shader_module->shader_count; n++) {
1670-
debug(user_context) << " destroying descriptor set layout [" << n << "] " << shader_module->descriptor_set_layouts[n] << "\n";
1671-
vk_destroy_descriptor_set_layout(user_context, allocator, shader_module->descriptor_set_layouts[n]);
1672-
shader_module->descriptor_set_layouts[n] = VK_NULL_HANDLE;
1673-
}
1674-
debug(user_context) << " destroying descriptor set layout " << (void *)shader_module->descriptor_set_layouts << "\n";
1675-
vk_host_free(user_context, shader_module->descriptor_set_layouts, allocator->callbacks());
1676-
shader_module->descriptor_set_layouts = nullptr;
1677-
}
1678-
if (shader_module->pipeline_layout) {
1679-
debug(user_context) << " destroying pipeline layout " << (void *)shader_module->pipeline_layout << "\n";
1680-
vk_destroy_pipeline_layout(user_context, allocator, shader_module->pipeline_layout);
1681-
shader_module->pipeline_layout = VK_NULL_HANDLE;
1682-
}
1683-
16841668
if (shader_module->shader_bindings) {
16851669
#ifdef DEBUG_RUNTIME
16861670
debug(user_context)
16871671
<< " destroying shader bindings ("
16881672
<< "shader_module: " << shader_module << ", "
16891673
<< "shader_bindings: " << shader_module->shader_bindings << ")\n";
16901674
#endif
1675+
for (uint32_t n = 0; n < shader_module->shader_count; n++) {
1676+
if (shader_module->shader_bindings[n].compute_pipeline) {
1677+
debug(user_context) << " destroying shader binding compute pipeline [" << n << "]\n";
1678+
vk_destroy_compute_pipeline(user_context, allocator, shader_module->shader_bindings[n].compute_pipeline);
1679+
shader_module->shader_bindings[n].compute_pipeline = VK_NULL_HANDLE;
1680+
}
1681+
}
16911682
for (uint32_t n = 0; n < shader_module->shader_count; n++) {
16921683
debug(user_context) << " destroying shader binding [" << n << "] ";
16931684
if (shader_module->shader_bindings[n].entry_point_name) {
@@ -1717,15 +1708,25 @@ void vk_destroy_compiled_shader_module(VulkanCompiledShaderModule *shader_module
17171708
vk_host_free(user_context, shader_module->shader_bindings[n].shared_memory_allocations, allocator->callbacks());
17181709
shader_module->shader_bindings[n].shared_memory_allocations = nullptr;
17191710
}
1720-
if (shader_module->shader_bindings[n].compute_pipeline) {
1721-
debug(user_context) << " destroying shader binding compute pipeline [" << n << "]\n";
1722-
vk_destroy_compute_pipeline(user_context, allocator, shader_module->shader_bindings[n].compute_pipeline);
1723-
shader_module->shader_bindings[n].compute_pipeline = VK_NULL_HANDLE;
1724-
}
17251711
}
17261712
vk_host_free(user_context, shader_module->shader_bindings, allocator->callbacks());
17271713
shader_module->shader_bindings = nullptr;
17281714
}
1715+
if (shader_module->pipeline_layout) {
1716+
debug(user_context) << " destroying pipeline layout " << (void *)shader_module->pipeline_layout << "\n";
1717+
vk_destroy_pipeline_layout(user_context, allocator, shader_module->pipeline_layout);
1718+
shader_module->pipeline_layout = VK_NULL_HANDLE;
1719+
}
1720+
if (shader_module->descriptor_set_layouts) {
1721+
for (uint32_t n = 0; n < shader_module->shader_count; n++) {
1722+
debug(user_context) << " destroying descriptor set layout [" << n << "] " << shader_module->descriptor_set_layouts[n] << "\n";
1723+
vk_destroy_descriptor_set_layout(user_context, allocator, shader_module->descriptor_set_layouts[n]);
1724+
shader_module->descriptor_set_layouts[n] = VK_NULL_HANDLE;
1725+
}
1726+
debug(user_context) << " destroying descriptor set layout " << (void *)shader_module->descriptor_set_layouts << "\n";
1727+
vk_host_free(user_context, shader_module->descriptor_set_layouts, allocator->callbacks());
1728+
shader_module->descriptor_set_layouts = nullptr;
1729+
}
17291730
if (shader_module->shader_module) {
17301731
debug(user_context) << " destroying shader module " << (void *)shader_module->shader_module << "\n";
17311732
vkDestroyShaderModule(allocator->current_device(), shader_module->shader_module, allocator->callbacks());
@@ -1778,7 +1779,7 @@ int vk_destroy_shader_modules(void *user_context, VulkanMemoryAllocator *allocat
17781779
uint64_t t_before = halide_current_time_ns(user_context);
17791780
#endif
17801781
if (allocator != nullptr) {
1781-
compilation_cache.delete_context(user_context, allocator->current_device(), vk_destroy_compilation_cache_entry);
1782+
compilation_cache.delete_context(user_context, allocator, vk_destroy_compilation_cache_entry);
17821783
}
17831784

17841785
#ifdef DEBUG_RUNTIME

0 commit comments

Comments
 (0)