Skip to content

Commit ff03bd3

Browse files
committed
Vulkan: add external allocator lifecycle
1 parent d58798a commit ff03bd3

6 files changed

Lines changed: 364 additions & 39 deletions

File tree

src/runtime/HalideRuntimeVulkan.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,45 @@ extern int halide_vulkan_release_context(void *user_context,
105105
VkDevice device,
106106
VkQueue queue,
107107
VkDebugUtilsMessengerEXT messenger);
108+
109+
typedef int (*halide_vulkan_acquire_context_t)(void *user_context,
110+
struct halide_vulkan_memory_allocator **allocator,
111+
VkInstance *instance,
112+
VkDevice *device,
113+
VkPhysicalDevice *physical_device,
114+
VkQueue *queue,
115+
uint32_t *queue_family_index,
116+
VkDebugUtilsMessengerEXT *messenger,
117+
bool create);
118+
typedef int (*halide_vulkan_release_context_t)(void *user_context,
119+
VkInstance instance,
120+
VkDevice device,
121+
VkQueue queue,
122+
VkDebugUtilsMessengerEXT messenger);
123+
124+
/** Override the Vulkan context acquisition callback. Returns the previous handler. */
125+
extern halide_vulkan_acquire_context_t halide_set_vulkan_acquire_context(halide_vulkan_acquire_context_t handler);
126+
127+
/** Override the Vulkan context release callback. Returns the previous handler. */
128+
extern halide_vulkan_release_context_t halide_set_vulkan_release_context(halide_vulkan_release_context_t handler);
129+
130+
/** Create or validate Halide allocator state for an external Vulkan context.
131+
* The embedder owns the Vulkan handles and the returned allocator lifetime.
132+
*/
133+
extern int halide_vulkan_acquire_memory_allocator(void *user_context,
134+
struct halide_vulkan_memory_allocator **allocator,
135+
VkInstance instance,
136+
VkDevice device,
137+
VkPhysicalDevice physical_device);
138+
139+
/** Release Halide allocator state for an external Vulkan context.
140+
* The embedder must ensure no Halide work is still using it.
141+
*/
142+
extern int halide_vulkan_release_memory_allocator(void *user_context,
143+
struct halide_vulkan_memory_allocator *allocator,
144+
VkInstance instance,
145+
VkDevice device,
146+
VkPhysicalDevice physical_device);
108147
// --
109148

110149
// Override the default allocation callbacks (default uses Vulkan runtime implementation)

src/runtime/gpu_context_common.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class GPUCompilationCache {
127127
}
128128

129129
for (int i = 0; i < (1 << log2_compilations_size); i++) {
130-
if (compilations[i].kernel_id > kInvalidId &&
130+
if (compilations[i].kernel_id > kDeletedId &&
131131
(all || (compilations[i].context == context)) &&
132132
compilations[i].use_count == 0) {
133133
debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
@@ -168,6 +168,31 @@ class GPUCompilationCache {
168168
release_context_already_locked(user_context, false, context, f);
169169
}
170170

171+
template<typename ShouldDeleteModuleT, typename FreeModuleT>
172+
void delete_context_if(void *user_context, ContextT context,
173+
ShouldDeleteModuleT &should_delete, FreeModuleT &f) {
174+
ScopedMutexLock lock_guard(&mutex);
175+
176+
if (count == 0) {
177+
return;
178+
}
179+
180+
for (int i = 0; i < (1 << log2_compilations_size); i++) {
181+
if (compilations[i].kernel_id > kDeletedId &&
182+
compilations[i].context == context &&
183+
compilations[i].use_count == 0 &&
184+
should_delete(compilations[i].module_state)) {
185+
debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
186+
<< " id " << compilations[i].kernel_id
187+
<< " context " << compilations[i].context << "\n";
188+
f(compilations[i].module_state);
189+
compilations[i].module_state = nullptr;
190+
compilations[i].kernel_id = kDeletedId;
191+
count--;
192+
}
193+
}
194+
}
195+
171196
template<typename FreeModuleT>
172197
void release_all(void *user_context, FreeModuleT &f) {
173198
ScopedMutexLock lock_guard(&mutex);

src/runtime/runtime_api.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,14 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
213213
(void *)&halide_d3d12compute_release_context,
214214
(void *)&halide_d3d12compute_run,
215215
(void *)&halide_vulkan_acquire_context,
216+
(void *)&halide_vulkan_acquire_memory_allocator,
216217
(void *)&halide_vulkan_device_interface,
217218
(void *)&halide_vulkan_initialize_kernels,
219+
(void *)&halide_vulkan_release_memory_allocator,
218220
(void *)&halide_vulkan_release_context,
219221
(void *)&halide_vulkan_run,
222+
(void *)&halide_set_vulkan_acquire_context,
223+
(void *)&halide_set_vulkan_release_context,
220224
(void *)&halide_webgpu_device_interface,
221225
(void *)&halide_webgpu_initialize_kernels,
222226
(void *)&halide_webgpu_finalize_kernels,

src/runtime/vulkan.cpp

Lines changed: 156 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,34 @@ using namespace Halide::Runtime::Internal::Vulkan;
1313

1414
// --------------------------------------------------------------------------
1515

16-
extern "C" {
16+
namespace Halide {
17+
namespace Runtime {
18+
namespace Internal {
19+
namespace Vulkan {
1720

18-
// --------------------------------------------------------------------------
21+
ALWAYS_INLINE int vk_load_vulkan_interface(void *user_context, VkInstance instance, VkDevice device) {
22+
if (vkGetInstanceProcAddr == nullptr) {
23+
vk_load_vulkan_loader_functions(user_context);
24+
if (vkGetInstanceProcAddr == nullptr) {
25+
error(user_context) << "Vulkan: Failed to resolve loader functions!\n";
26+
return halide_error_code_symbol_not_found;
27+
}
28+
}
29+
30+
vk_load_vulkan_instance_functions(user_context, instance);
31+
if (vkGetPhysicalDeviceProperties == nullptr || vkGetDeviceProcAddr == nullptr) {
32+
error(user_context) << "Vulkan: Failed to resolve instance functions!\n";
33+
return halide_error_code_symbol_not_found;
34+
}
35+
36+
vk_load_vulkan_device_functions(user_context, device);
37+
if (vkCreateBuffer == nullptr || vkAllocateMemory == nullptr) {
38+
error(user_context) << "Vulkan: Failed to resolve device functions!\n";
39+
return halide_error_code_symbol_not_found;
40+
}
41+
42+
return halide_error_code_success;
43+
}
1944

2045
// The default implementation of halide_acquire_vulkan_context uses
2146
// the global pointers above, and serializes access with a spin lock.
@@ -29,15 +54,15 @@ extern "C" {
2954
// call to halide_release_vulkan_context. halide_acquire_vulkan_context
3055
// should block while a previous call (if any) has not yet been
3156
// released via halide_release_vulkan_context.
32-
WEAK int halide_vulkan_acquire_context(void *user_context,
33-
halide_vulkan_memory_allocator **allocator,
34-
VkInstance *instance,
35-
VkDevice *device,
36-
VkPhysicalDevice *physical_device,
37-
VkQueue *queue,
38-
uint32_t *queue_family_index,
39-
VkDebugUtilsMessengerEXT *messenger,
40-
bool create) {
57+
WEAK int default_vulkan_acquire_context(void *user_context,
58+
halide_vulkan_memory_allocator **allocator,
59+
VkInstance *instance,
60+
VkDevice *device,
61+
VkPhysicalDevice *physical_device,
62+
VkQueue *queue,
63+
uint32_t *queue_family_index,
64+
VkDebugUtilsMessengerEXT *messenger,
65+
bool create) {
4166
#ifdef DEBUG_RUNTIME
4267
halide_start_clock(user_context);
4368
#endif
@@ -74,11 +99,130 @@ WEAK int halide_vulkan_acquire_context(void *user_context,
7499
return halide_error_code_success;
75100
}
76101

77-
WEAK int halide_vulkan_release_context(void *user_context, VkInstance instance, VkDevice device, VkQueue queue, VkDebugUtilsMessengerEXT messenger) {
102+
WEAK int default_vulkan_release_context(void *user_context, VkInstance instance, VkDevice device, VkQueue queue, VkDebugUtilsMessengerEXT messenger) {
78103
halide_mutex_unlock(&thread_lock);
79104
return halide_error_code_success;
80105
}
81106

107+
WEAK halide_vulkan_acquire_context_t vulkan_acquire_context_handler =
108+
default_vulkan_acquire_context;
109+
WEAK halide_vulkan_release_context_t vulkan_release_context_handler =
110+
default_vulkan_release_context;
111+
112+
} // namespace Vulkan
113+
} // namespace Internal
114+
} // namespace Runtime
115+
} // namespace Halide
116+
117+
// --------------------------------------------------------------------------
118+
119+
extern "C" {
120+
121+
// --------------------------------------------------------------------------
122+
123+
WEAK int halide_vulkan_acquire_context(void *user_context,
124+
halide_vulkan_memory_allocator **allocator,
125+
VkInstance *instance,
126+
VkDevice *device,
127+
VkPhysicalDevice *physical_device,
128+
VkQueue *queue,
129+
uint32_t *queue_family_index,
130+
VkDebugUtilsMessengerEXT *messenger,
131+
bool create) {
132+
return vulkan_acquire_context_handler(user_context, allocator, instance, device,
133+
physical_device, queue, queue_family_index,
134+
messenger, create);
135+
}
136+
137+
WEAK int halide_vulkan_release_context(void *user_context, VkInstance instance, VkDevice device, VkQueue queue, VkDebugUtilsMessengerEXT messenger) {
138+
return vulkan_release_context_handler(user_context, instance, device, queue, messenger);
139+
}
140+
141+
WEAK halide_vulkan_acquire_context_t halide_set_vulkan_acquire_context(halide_vulkan_acquire_context_t handler) {
142+
halide_vulkan_acquire_context_t result = vulkan_acquire_context_handler;
143+
vulkan_acquire_context_handler = handler ? handler : default_vulkan_acquire_context;
144+
return result;
145+
}
146+
147+
WEAK halide_vulkan_release_context_t halide_set_vulkan_release_context(halide_vulkan_release_context_t handler) {
148+
halide_vulkan_release_context_t result = vulkan_release_context_handler;
149+
vulkan_release_context_handler = handler ? handler : default_vulkan_release_context;
150+
return result;
151+
}
152+
153+
WEAK int halide_vulkan_acquire_memory_allocator(void *user_context,
154+
halide_vulkan_memory_allocator **allocator,
155+
VkInstance instance,
156+
VkDevice device,
157+
VkPhysicalDevice physical_device) {
158+
if (allocator == nullptr) {
159+
error(user_context) << "Vulkan: allocator output pointer is null!\n";
160+
return halide_error_code_buffer_argument_is_null;
161+
}
162+
if (instance == VK_NULL_HANDLE || device == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE) {
163+
error(user_context) << "Vulkan: invalid external context handles for allocator acquisition!\n";
164+
return halide_error_code_device_interface_no_device;
165+
}
166+
167+
VulkanMemoryAllocator *runtime_allocator =
168+
reinterpret_cast<VulkanMemoryAllocator *>(*allocator);
169+
if (runtime_allocator != nullptr) {
170+
if (runtime_allocator->current_device() != device ||
171+
runtime_allocator->current_physical_device() != physical_device) {
172+
error(user_context) << "Vulkan: external allocator does not match supplied device handles!\n";
173+
return halide_error_code_internal_error;
174+
}
175+
return halide_error_code_success;
176+
}
177+
178+
const VkAllocationCallbacks *alloc_callbacks =
179+
halide_vulkan_get_allocation_callbacks(user_context);
180+
181+
int error_code = vk_load_vulkan_interface(user_context, instance, device);
182+
if (error_code != halide_error_code_success) {
183+
return error_code;
184+
}
185+
186+
runtime_allocator =
187+
vk_create_memory_allocator(user_context, device, physical_device, alloc_callbacks);
188+
if (runtime_allocator == nullptr) {
189+
error(user_context) << "Vulkan: Failed to create memory allocator for external context!\n";
190+
return halide_error_code_out_of_memory;
191+
}
192+
193+
*allocator = reinterpret_cast<halide_vulkan_memory_allocator *>(runtime_allocator);
194+
return halide_error_code_success;
195+
}
196+
197+
WEAK int halide_vulkan_release_memory_allocator(void *user_context,
198+
halide_vulkan_memory_allocator *allocator,
199+
VkInstance instance,
200+
VkDevice device,
201+
VkPhysicalDevice physical_device) {
202+
VulkanMemoryAllocator *runtime_allocator =
203+
reinterpret_cast<VulkanMemoryAllocator *>(allocator);
204+
if (runtime_allocator == nullptr) {
205+
return halide_error_code_success;
206+
}
207+
if (instance == VK_NULL_HANDLE || device == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE) {
208+
error(user_context) << "Vulkan: invalid external context handles for allocator release!\n";
209+
return halide_error_code_device_interface_no_device;
210+
}
211+
if (runtime_allocator->current_device() != device ||
212+
runtime_allocator->current_physical_device() != physical_device) {
213+
error(user_context) << "Vulkan: external allocator does not match supplied device handles during release!\n";
214+
return halide_error_code_internal_error;
215+
}
216+
217+
if (vkDestroyShaderModule == nullptr || vkFreeMemory == nullptr) {
218+
error(user_context) << "Vulkan: Failed to resolve device functions for external allocator release!\n";
219+
return halide_error_code_symbol_not_found;
220+
}
221+
222+
vk_destroy_shader_modules(user_context, runtime_allocator);
223+
return vk_destroy_memory_allocator(user_context, runtime_allocator);
224+
}
225+
82226
WEAK bool halide_vulkan_is_initialized() {
83227
halide_mutex_lock(&thread_lock);
84228
bool is_initialized = (cached_instance != nullptr) && (cached_device != nullptr);

0 commit comments

Comments
 (0)