@@ -13,6 +13,42 @@ using namespace Halide::Runtime::Internal::Vulkan;
1313
1414// --------------------------------------------------------------------------
1515
16+ namespace Halide {
17+ namespace Runtime {
18+ namespace Internal {
19+ namespace Vulkan {
20+
21+ ALWAYS_INLINE int vk_load_external_context_functions (void *user_context, VkInstance instance, VkDevice device) {
22+ if (vkGetInstanceProcAddr == nullptr ) {
23+ vk_load_vulkan_loader_functions (user_context);
24+ if (vkGetInstanceProcAddr == nullptr ) {
25+ error (user_context) << " Vulkan: Failed to resolve loader functions for external context!\n " ;
26+ return halide_error_code_symbol_not_found;
27+ }
28+ }
29+
30+ vk_load_vulkan_instance_functions (user_context, instance);
31+ if (vkGetPhysicalDeviceProperties == nullptr || vkGetDeviceProcAddr == nullptr ) {
32+ error (user_context) << " Vulkan: Failed to resolve instance functions for external context!\n " ;
33+ return halide_error_code_symbol_not_found;
34+ }
35+
36+ vk_load_vulkan_device_functions (user_context, device);
37+ if (vkCreateBuffer == nullptr || vkAllocateMemory == nullptr ) {
38+ error (user_context) << " Vulkan: Failed to resolve device functions for external context!\n " ;
39+ return halide_error_code_symbol_not_found;
40+ }
41+
42+ return halide_error_code_success;
43+ }
44+
45+ } // namespace Vulkan
46+ } // namespace Internal
47+ } // namespace Runtime
48+ } // namespace Halide
49+
50+ // --------------------------------------------------------------------------
51+
1652extern " C" {
1753
1854// --------------------------------------------------------------------------
@@ -79,6 +115,82 @@ WEAK int halide_vulkan_release_context(void *user_context, VkInstance instance,
79115 return halide_error_code_success;
80116}
81117
118+ WEAK int halide_vulkan_acquire_memory_allocator (void *user_context,
119+ halide_vulkan_memory_allocator **allocator,
120+ VkInstance instance,
121+ VkDevice device,
122+ VkPhysicalDevice physical_device) {
123+ if (allocator == nullptr ) {
124+ error (user_context) << " Vulkan: allocator output pointer is null!\n " ;
125+ return halide_error_code_buffer_argument_is_null;
126+ }
127+ if (instance == VK_NULL_HANDLE || device == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE) {
128+ error (user_context) << " Vulkan: invalid external context handles for allocator acquisition!\n " ;
129+ return halide_error_code_device_interface_no_device;
130+ }
131+
132+ int error_code = vk_load_external_context_functions (user_context, instance, device);
133+ if (error_code != halide_error_code_success) {
134+ return error_code;
135+ }
136+
137+ VulkanMemoryAllocator *runtime_allocator =
138+ reinterpret_cast <VulkanMemoryAllocator *>(*allocator);
139+ if (runtime_allocator != nullptr ) {
140+ if (runtime_allocator->current_device () != device ||
141+ runtime_allocator->current_physical_device () != physical_device) {
142+ error (user_context) << " Vulkan: external allocator does not match supplied device handles!\n " ;
143+ return halide_error_code_internal_error;
144+ }
145+ return halide_error_code_success;
146+ }
147+
148+ const VkAllocationCallbacks *alloc_callbacks =
149+ halide_vulkan_get_allocation_callbacks (user_context);
150+ runtime_allocator =
151+ vk_create_memory_allocator (user_context, device, physical_device, alloc_callbacks);
152+ if (runtime_allocator == nullptr ) {
153+ error (user_context) << " Vulkan: Failed to create memory allocator for external context!\n " ;
154+ return halide_error_code_out_of_memory;
155+ }
156+
157+ *allocator = reinterpret_cast <halide_vulkan_memory_allocator *>(runtime_allocator);
158+ return halide_error_code_success;
159+ }
160+
161+ WEAK int halide_vulkan_release_memory_allocator (void *user_context,
162+ halide_vulkan_memory_allocator *allocator,
163+ VkInstance instance,
164+ VkDevice device,
165+ VkPhysicalDevice physical_device) {
166+ VulkanMemoryAllocator *runtime_allocator =
167+ reinterpret_cast <VulkanMemoryAllocator *>(allocator);
168+ if (runtime_allocator == nullptr ) {
169+ return halide_error_code_success;
170+ }
171+ if (instance == VK_NULL_HANDLE || device == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE) {
172+ error (user_context) << " Vulkan: invalid external context handles for allocator release!\n " ;
173+ return halide_error_code_device_interface_no_device;
174+ }
175+ if (runtime_allocator->current_device () != device ||
176+ runtime_allocator->current_physical_device () != physical_device) {
177+ error (user_context) << " Vulkan: external allocator does not match supplied device handles during release!\n " ;
178+ return halide_error_code_internal_error;
179+ }
180+
181+ int error_code = vk_load_external_context_functions (user_context, instance, device);
182+ if (error_code != halide_error_code_success) {
183+ return error_code;
184+ }
185+ if (vkDestroyShaderModule == nullptr || vkFreeMemory == nullptr ) {
186+ error (user_context) << " Vulkan: Failed to resolve device functions for external allocator release!\n " ;
187+ return halide_error_code_symbol_not_found;
188+ }
189+
190+ vk_destroy_shader_modules (user_context, runtime_allocator);
191+ return vk_destroy_memory_allocator (user_context, runtime_allocator);
192+ }
193+
82194WEAK bool halide_vulkan_is_initialized () {
83195 halide_mutex_lock (&thread_lock);
84196 bool is_initialized = (cached_instance != nullptr ) && (cached_device != nullptr );
@@ -159,7 +271,7 @@ WEAK int halide_vulkan_initialize_kernels(void *user_context, void **state_ptr,
159271 debug (user_context) << " halide_vulkan_initialize_kernels got compilation_cache mutex.\n " ;
160272
161273 VulkanCompilationCacheEntry *cache_entry = nullptr ;
162- if (!compilation_cache.kernel_state_setup (user_context, state_ptr, ctx.device , cache_entry,
274+ if (!compilation_cache.kernel_state_setup (user_context, state_ptr, ctx.allocator , cache_entry,
163275 Halide::Runtime::Internal::Vulkan::vk_compile_kernel_module,
164276 user_context, ctx.allocator , src, size)) {
165277 error (user_context) << " Vulkan: Failed to setup compilation cache!\n " ;
@@ -185,7 +297,7 @@ WEAK void halide_vulkan_finalize_kernels(void *user_context, void *state_ptr) {
185297
186298 VulkanContext ctx (user_context);
187299 if (ctx.error == halide_error_code_success) {
188- compilation_cache.release_hold (user_context, ctx.device , state_ptr);
300+ compilation_cache.release_hold (user_context, ctx.allocator , state_ptr);
189301 }
190302
191303#ifdef DEBUG_RUNTIME
@@ -1151,7 +1263,7 @@ WEAK int halide_vulkan_run(void *user_context,
11511263
11521264 // 1. Get the shader module cache entry
11531265 VulkanCompilationCacheEntry *cache_entry = nullptr ;
1154- bool found = compilation_cache.lookup (ctx.device , state_ptr, cache_entry);
1266+ bool found = compilation_cache.lookup (ctx.allocator , state_ptr, cache_entry);
11551267 if (!found || (cache_entry == nullptr )) {
11561268 error (user_context) << " Vulkan: Failed to locate shader module! Unable to proceed!\n " ;
11571269 return halide_error_code_internal_error;
0 commit comments