Skip to content

Commit f463ccc

Browse files
committed
Switch to 16x16x16 cooperative matrix
~30% faster + compatible with AMD
1 parent 83ab0f7 commit f463ccc

52 files changed

Lines changed: 421 additions & 295 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

internal/RendererGPU.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -531,28 +531,28 @@ Ray::NS::Renderer::InitUNetFilter(const bool alias_memory,
531531
Buffer temp_upload_buf;
532532

533533
if (use_fp16_) {
534-
const int total_count = SetupUNetWeights<uint16_t>(8, nullptr, nullptr);
534+
const int total_count = SetupUNetWeights<uint16_t>(16, nullptr, nullptr);
535535

536536
temp_upload_buf =
537537
Buffer{"UNet Weights CBN Upload", ctx_.get(), eBufType::Upload, uint32_t(total_count * sizeof(uint16_t))};
538538
unet_weights_ =
539539
Buffer{"UNet Weights CBN", ctx_.get(), eBufType::Storage, uint32_t(total_count * sizeof(uint16_t))};
540540

541541
uint16_t *out_weights = (uint16_t *)temp_upload_buf.Map();
542-
SetupUNetWeights(8, &unet_offsets_, out_weights);
542+
SetupUNetWeights(16, &unet_offsets_, out_weights);
543543
temp_upload_buf.Unmap();
544544

545545
CopyBufferToBuffer(temp_upload_buf, 0, unet_weights_, 0, sizeof(uint16_t) * total_count, cmd_buf);
546546
} else {
547-
const int total_count = SetupUNetWeights<float>(8, nullptr, nullptr);
547+
const int total_count = SetupUNetWeights<float>(16, nullptr, nullptr);
548548

549549
temp_upload_buf =
550550
Buffer{"UNet Weights CBN Upload", ctx_.get(), eBufType::Upload, uint32_t(total_count * sizeof(float))};
551551
unet_weights_ =
552552
Buffer{"UNet Weights CBN", ctx_.get(), eBufType::Storage, uint32_t(total_count * sizeof(float))};
553553

554554
float *out_weights = (float *)temp_upload_buf.Map();
555-
SetupUNetWeights(8, &unet_offsets_, out_weights);
555+
SetupUNetWeights(16, &unet_offsets_, out_weights);
556556
temp_upload_buf.Unmap();
557557

558558
CopyBufferToBuffer(temp_upload_buf, 0, unet_weights_, 0, sizeof(float) * total_count, cmd_buf);

internal/RendererVK.cpp

Lines changed: 48 additions & 48 deletions
Large diffs are not rendered by default.

internal/Vk/ContextVK.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
215215

216216
CheckVkPhysicalDeviceFeatures(api_, physical_device_, device_properties_, mem_properties_, graphics_family_index_,
217217
raytracing_supported_, ray_query_supported_, fp16_supported_, int64_supported_,
218-
int64_atomics_supported_, coop_matrix_supported_, pageable_memory_supported_);
218+
int64_atomics_supported_, coop_matrix_size_, pageable_memory_supported_);
219219

220220
// mask out unsupported stages
221221
if (!raytracing_supported_) {
@@ -227,7 +227,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
227227

228228
if (!external_ && !InitVkDevice(api_, device_, physical_device_, graphics_family_index_, raytracing_supported_,
229229
ray_query_supported_, fp16_supported_, int64_supported_, int64_atomics_supported_,
230-
coop_matrix_supported_, pageable_memory_supported_, log)) {
230+
coop_matrix_size_[0] != -1, pageable_memory_supported_, log)) {
231231
return false;
232232
}
233233

@@ -565,7 +565,7 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
565565
uint32_t &out_graphics_family_index,
566566
bool &out_raytracing_supported, bool &out_ray_query_supported,
567567
bool &out_shader_fp16_supported, bool &out_shader_int64_supported,
568-
bool &out_int64_atomics_supported, bool &out_coop_matrix_supported,
568+
bool &out_int64_atomics_supported, int out_coop_matrix_size[3],
569569
bool &out_pageable_memory_supported) {
570570
api.vkGetPhysicalDeviceProperties(physical_device, &out_device_properties);
571571
api.vkGetPhysicalDeviceMemoryProperties(physical_device, &out_mem_properties);
@@ -591,16 +591,19 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
591591

592592
bool acc_struct_supported = false, raytracing_supported = false, ray_query_supported = false,
593593
shader_fp16_supported = false, shader_int64_supported = false, storage_fp16_supported = false,
594-
coop_matrix_supported = false, shader_buf_int64_atomics_supported = false, memory_priority_supported = false,
594+
shader_buf_int64_atomics_supported = false, memory_priority_supported = false,
595595
pageable_memory_supported = false;
596596

597+
int coop_matrix_size[3] = {-1, -1, -1};
598+
597599
{ // check for features support
598600
uint32_t extension_count;
599601
api.vkEnumerateDeviceExtensionProperties(physical_device, nullptr, &extension_count, nullptr);
600602

601603
SmallVector<VkExtensionProperties, 16> available_extensions(extension_count);
602604
api.vkEnumerateDeviceExtensionProperties(physical_device, nullptr, &extension_count, &available_extensions[0]);
603605

606+
bool coop_matrix_supported = false;
604607
for (uint32_t j = 0; j < extension_count; j++) {
605608
const VkExtensionProperties &ext = available_extensions[j];
606609

@@ -671,7 +674,10 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
671674
for (const VkCooperativeMatrixPropertiesKHR &p : coop_matrix_props) {
672675
if (p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
673676
p.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
674-
p.MSize == 16 && p.NSize == 8 && p.KSize == 8 && p.scope == VK_SCOPE_SUBGROUP_KHR) {
677+
p.MSize == 16 && p.NSize == 16 && p.KSize == 16 && p.scope == VK_SCOPE_SUBGROUP_KHR) {
678+
coop_matrix_size[0] = 16;
679+
coop_matrix_size[1] = 16;
680+
coop_matrix_size[2] = 16;
675681
found = true;
676682
break;
677683
}
@@ -685,14 +691,15 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
685691
out_shader_fp16_supported = (shader_fp16_supported && storage_fp16_supported);
686692
out_shader_int64_supported = shader_int64_supported;
687693
out_int64_atomics_supported = shader_buf_int64_atomics_supported;
688-
out_coop_matrix_supported = coop_matrix_supported;
694+
memcpy(out_coop_matrix_size, coop_matrix_size, 3 * sizeof(int));
689695
out_pageable_memory_supported = (memory_priority_supported && pageable_memory_supported);
690696
}
691697

692698
bool Ray::Vk::Context::InitVkDevice(const Api &api, VkDevice &device, VkPhysicalDevice physical_device,
693-
uint32_t graphics_family_index, bool enable_raytracing, bool enable_ray_query,
694-
bool enable_fp16, bool enable_int64, bool enable_int64_atomics,
695-
bool enable_coop_matrix, bool enable_pageable_memory, ILog *log) {
699+
const uint32_t graphics_family_index, const bool enable_raytracing,
700+
const bool enable_ray_query, const bool enable_fp16, const bool enable_int64,
701+
const bool enable_int64_atomics, const bool enable_coop_matrix,
702+
const bool enable_pageable_memory, ILog *log) {
696703
VkDeviceQueueCreateInfo queue_create_infos[2] = {{}, {}};
697704
const float queue_priorities[] = {1.0f};
698705

internal/Vk/ContextVK.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class Context {
4646

4747
bool subgroup_supported_ = false;
4848

49-
bool coop_matrix_supported_ = false;
49+
int coop_matrix_size_[3] = {-1, -1, -1};
5050

5151
bool pageable_memory_supported_ = false;
5252

@@ -94,7 +94,7 @@ class Context {
9494
bool int64_supported() const { return int64_supported_; }
9595
bool int64_atomics_supported() const { return int64_atomics_supported_; }
9696
bool subgroup_supported() const { return subgroup_supported_; }
97-
bool coop_matrix_supported() const { return coop_matrix_supported_; }
97+
const int *coop_matrix_size() const { return coop_matrix_size_; }
9898

9999
uint32_t supported_stages_mask() const { return supported_stages_mask_; };
100100
bool image_blit_supported() const { return true; }
@@ -148,15 +148,15 @@ class Context {
148148
private:
149149
static bool InitVkInstance(const Api &api, VkInstance &instance, const char *enabled_layers[],
150150
int enabled_layers_count, int validation_level, ILog *log);
151-
static bool ChooseVkPhysicalDevice(const Api &api, VkPhysicalDevice &physical_device, std::string_view preferred_device,
152-
VkInstance instance, ILog *log);
151+
static bool ChooseVkPhysicalDevice(const Api &api, VkPhysicalDevice &physical_device,
152+
std::string_view preferred_device, VkInstance instance, ILog *log);
153153
static void CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalDevice &physical_device,
154154
VkPhysicalDeviceProperties &device_properties,
155155
VkPhysicalDeviceMemoryProperties &mem_properties,
156156
uint32_t &graphics_family_index, bool &out_raytracing_supported,
157157
bool &out_ray_query_supported, bool &out_shader_fp16_supported,
158158
bool &out_shader_int64_supported, bool &out_int64_atomics_supported,
159-
bool &out_coop_matrix_supported, bool &out_pageable_memory_supported);
159+
int out_coop_matrix_size[3], bool &out_pageable_memory_supported);
160160
static bool InitVkDevice(const Api &api, VkDevice &device, VkPhysicalDevice physical_device,
161161
uint32_t graphics_family_index, bool enable_raytracing, bool enable_ray_query,
162162
bool enable_fp16, bool enable_int64, bool enable_int64_atomics, bool enable_coop_matrix,

0 commit comments

Comments
 (0)