@@ -215,7 +215,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
215215
216216 CheckVkPhysicalDeviceFeatures (api_, physical_device_, device_properties_, mem_properties_, graphics_family_index_,
217217 raytracing_supported_, ray_query_supported_, fp16_supported_, int64_supported_,
218- int64_atomics_supported_, coop_matrix_supported_ , pageable_memory_supported_);
218+ int64_atomics_supported_, coop_matrix_size_ , pageable_memory_supported_);
219219
220220 // mask out unsupported stages
221221 if (!raytracing_supported_) {
@@ -227,7 +227,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
227227
228228 if (!external_ && !InitVkDevice (api_, device_, physical_device_, graphics_family_index_, raytracing_supported_,
229229 ray_query_supported_, fp16_supported_, int64_supported_, int64_atomics_supported_,
230- coop_matrix_supported_ , pageable_memory_supported_, log)) {
230+ coop_matrix_size_[ 0 ] != - 1 , pageable_memory_supported_, log)) {
231231 return false ;
232232 }
233233
@@ -565,7 +565,7 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
565565 uint32_t &out_graphics_family_index,
566566 bool &out_raytracing_supported, bool &out_ray_query_supported,
567567 bool &out_shader_fp16_supported, bool &out_shader_int64_supported,
568- bool &out_int64_atomics_supported, bool &out_coop_matrix_supported ,
568+ bool &out_int64_atomics_supported, int out_coop_matrix_size[ 3 ] ,
569569 bool &out_pageable_memory_supported) {
570570 api.vkGetPhysicalDeviceProperties (physical_device, &out_device_properties);
571571 api.vkGetPhysicalDeviceMemoryProperties (physical_device, &out_mem_properties);
@@ -591,16 +591,19 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
591591
592592 bool acc_struct_supported = false , raytracing_supported = false , ray_query_supported = false ,
593593 shader_fp16_supported = false , shader_int64_supported = false , storage_fp16_supported = false ,
594- coop_matrix_supported = false , shader_buf_int64_atomics_supported = false , memory_priority_supported = false ,
594+ shader_buf_int64_atomics_supported = false , memory_priority_supported = false ,
595595 pageable_memory_supported = false ;
596596
597+ int coop_matrix_size[3 ] = {-1 , -1 , -1 };
598+
597599 { // check for features support
598600 uint32_t extension_count;
599601 api.vkEnumerateDeviceExtensionProperties (physical_device, nullptr , &extension_count, nullptr );
600602
601603 SmallVector<VkExtensionProperties, 16 > available_extensions (extension_count);
602604 api.vkEnumerateDeviceExtensionProperties (physical_device, nullptr , &extension_count, &available_extensions[0 ]);
603605
606+ bool coop_matrix_supported = false ;
604607 for (uint32_t j = 0 ; j < extension_count; j++) {
605608 const VkExtensionProperties &ext = available_extensions[j];
606609
@@ -671,7 +674,10 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
671674 for (const VkCooperativeMatrixPropertiesKHR &p : coop_matrix_props) {
672675 if (p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
673676 p.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
674- p.MSize == 16 && p.NSize == 8 && p.KSize == 8 && p.scope == VK_SCOPE_SUBGROUP_KHR ) {
677+ p.MSize == 16 && p.NSize == 16 && p.KSize == 16 && p.scope == VK_SCOPE_SUBGROUP_KHR ) {
678+ coop_matrix_size[0 ] = 16 ;
679+ coop_matrix_size[1 ] = 16 ;
680+ coop_matrix_size[2 ] = 16 ;
675681 found = true ;
676682 break ;
677683 }
@@ -685,14 +691,15 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
685691 out_shader_fp16_supported = (shader_fp16_supported && storage_fp16_supported);
686692 out_shader_int64_supported = shader_int64_supported;
687693 out_int64_atomics_supported = shader_buf_int64_atomics_supported;
688- out_coop_matrix_supported = coop_matrix_supported ;
694+ memcpy (out_coop_matrix_size, coop_matrix_size, 3 * sizeof ( int )) ;
689695 out_pageable_memory_supported = (memory_priority_supported && pageable_memory_supported);
690696}
691697
692698bool Ray::Vk::Context::InitVkDevice (const Api &api, VkDevice &device, VkPhysicalDevice physical_device,
693- uint32_t graphics_family_index, bool enable_raytracing, bool enable_ray_query,
694- bool enable_fp16, bool enable_int64, bool enable_int64_atomics,
695- bool enable_coop_matrix, bool enable_pageable_memory, ILog *log) {
699+ const uint32_t graphics_family_index, const bool enable_raytracing,
700+ const bool enable_ray_query, const bool enable_fp16, const bool enable_int64,
701+ const bool enable_int64_atomics, const bool enable_coop_matrix,
702+ const bool enable_pageable_memory, ILog *log) {
696703 VkDeviceQueueCreateInfo queue_create_infos[2 ] = {{}, {}};
697704 const float queue_priorities[] = {1 .0f };
698705
0 commit comments