Skip to content

Commit 7161216

Browse files
committed
Enable full subgroup together with coop matrices
1 parent 5b8de05 commit 7161216

4 files changed

Lines changed: 10 additions & 3 deletions

File tree

internal/RendererVK.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1948,14 +1948,16 @@ bool Ray::Vk::Renderer::InitUNetFilterPipelines(
19481948
internal_shaders_output_convolution_32_3_img_coop_8x8x16_CF32_comp_spv),
19491949
eShaderType::Comp, false}};
19501950
const int subgroup_size = use_coop_matrix_ ? 32 : -1;
1951+
const bool require_full_subgroup = use_coop_matrix_;
19511952
parallel_for(0, int(shaders_to_init.size()), [&](const int i) {
19521953
std::get<6>(shaders_to_init[i]) =
19531954
std::get<0>(shaders_to_init[i])
19541955
.Init(std::get<3>(shaders_to_init[i]), ctx_.get(), Inflate(std::get<4>(shaders_to_init[i])),
19551956
std::get<5>(shaders_to_init[i]), log);
19561957
std::get<1>(shaders_to_init[i]) =
19571958
Program{std::get<3>(shaders_to_init[i]), ctx_.get(), &std::get<0>(shaders_to_init[i]), log};
1958-
std::get<2>(shaders_to_init[i]).Init(ctx_.get(), &std::get<1>(shaders_to_init[i]), log, subgroup_size);
1959+
std::get<2>(shaders_to_init[i])
1960+
.Init(ctx_.get(), &std::get<1>(shaders_to_init[i]), log, subgroup_size, require_full_subgroup);
19591961
});
19601962

19611963
bool result = true;

internal/Vk/ContextVK.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
255255
api_.vkGetPhysicalDeviceFeatures2KHR(physical_device_, &feat2);
256256

257257
subgroup_size_control_supported_ &= (subgroup_size_control_features.subgroupSizeControl == VK_TRUE);
258+
subgroup_size_control_supported_ &= (subgroup_size_control_features.computeFullSubgroups == VK_TRUE);
258259
}
259260

260261
if (!InitCommandBuffers(api_, command_pool_, temp_command_pool_, draw_cmd_bufs_, render_finished_semaphores_,
@@ -893,6 +894,7 @@ bool Ray::Vk::Context::InitVkDevice(const Api &api, VkDevice &device, VkPhysical
893894
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features = {
894895
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT};
895896
subgroup_size_control_features.subgroupSizeControl = VK_TRUE;
897+
subgroup_size_control_features.computeFullSubgroups = VK_TRUE;
896898
if (enable_subgroup_size_control) {
897899
(*pp_next) = &subgroup_size_control_features;
898900
pp_next = &subgroup_size_control_features.pNext;

internal/Vk/PipelineVK.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ bool Ray::Vk::Pipeline::Init(Context *ctx, const RastState &rast_state, Program
336336
subpass_index, log);
337337
}
338338

339-
bool Ray::Vk::Pipeline::Init(Context *ctx, Program *prog, ILog *log, const int subgroup_size) {
339+
bool Ray::Vk::Pipeline::Init(Context *ctx, Program *prog, ILog *log, const int subgroup_size, const bool require_full_subgroup) {
340340
Destroy();
341341

342342
ePipelineType type = ePipelineType::Undefined;
@@ -413,6 +413,9 @@ bool Ray::Vk::Pipeline::Init(Context *ctx, Program *prog, ILog *log, const int s
413413
stage_info.module = prog->shader(eShaderType(i))->module();
414414
stage_info.pName = "main";
415415
stage_info.pSpecializationInfo = nullptr;
416+
if (ctx->subgroup_size_control_supported() && require_full_subgroup) {
417+
stage_info.flags |= VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT;
418+
}
416419

417420
if (ctx->subgroup_size_control_supported() && subgroup_size != -1) {
418421
stage_info.pNext = &subgroup_size_info;

internal/Vk/PipelineVK.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class Pipeline {
9292
bool Init(Context *ctx, const RastState &rast_state, Program *prog, const VertexInput *vtx_input,
9393
Span<const RenderTarget> color_attachments, RenderTarget depth_attachment, uint32_t subpass_index,
9494
ILog *log);
95-
bool Init(Context *ctx, Program *prog, ILog *log, int subgroup_size = -1);
95+
bool Init(Context *ctx, Program *prog, ILog *log, int subgroup_size = -1, bool require_full_subgroup = false);
9696
};
9797
} // namespace Vk
9898
} // namespace Ray

0 commit comments

Comments
 (0)