Skip to content

Commit 8dcb8ac

Browse files
hzqstTheMostDiligent
authored andcommitted
PipelineStateVk: build specialization constant data
1 parent d0e0a6d commit 8dcb8ac

2 files changed

Lines changed: 140 additions & 5 deletions

File tree

Graphics/GraphicsEngineVulkan/include/PipelineStateVkImpl.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ namespace Diligent
4949

5050
class DeviceContextVkImpl;
5151

52+
namespace
53+
{
54+
struct ShaderStageSpecializationData;
55+
}
56+
5257
/// Pipeline state object implementation in Vulkan backend.
5358
class PipelineStateVkImpl final : public PipelineStateBase<EngineVkImplTraits>
5459
{
@@ -131,7 +136,8 @@ class PipelineStateVkImpl final : public PipelineStateBase<EngineVkImplTraits>
131136
template <typename PSOCreateInfoType>
132137
TShaderStages InitInternalObjects(const PSOCreateInfoType& CreateInfo,
133138
std::vector<VkPipelineShaderStageCreateInfo>& vkShaderStages,
134-
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules) noexcept(false);
139+
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules,
140+
std::vector<ShaderStageSpecializationData>& SpecDataPerStage) noexcept(false);
135141

136142
void InitPipelineLayout(const PipelineStateCreateInfo& CreateInfo,
137143
TShaderStages& ShaderStages) noexcept(false);

Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,127 @@ void InitPipelineShaderStages(const VulkanUtilities::LogicalDevice&
9494
VERIFY_EXPR(ShaderModules.size() == Stages.size());
9595
}
9696

97+
// Per-shader-stage specialization constant data for Vulkan pipeline creation.
98+
// Holds VkSpecializationMapEntry array, contiguous data blob, and the
99+
// VkSpecializationInfo that references them. Lifetime must exceed the
100+
// vkCreate*Pipelines call.
101+
struct ShaderStageSpecializationData
102+
{
103+
std::vector<VkSpecializationMapEntry> MapEntries;
104+
std::vector<Uint8> DataBlob;
105+
VkSpecializationInfo Info{};
106+
};
107+
108+
// Iterates SPIR-V reflected specialization constants and matches them to
109+
// user-provided SpecializationConstant entries by name. If a reflected
110+
// constant has no matching user entry the constant is silently skipped,
111+
// which allows the user to supply a superset of constants shared across
112+
// multiple pipelines / stages.
113+
//
114+
// Parameters:
115+
// ShaderStages - shader stages extracted from the PSO create info
116+
// NumSpecializationConstants - number of user-provided specialization constants
117+
// pSpecializationConstants - user-provided specialization constant array
118+
// PSODesc - pipeline state description (for error messages)
119+
// vkStages [in/out] - VkPipelineShaderStageCreateInfo array to patch
120+
// SpecDataPerStage [out] - per-stage specialization data (must outlive vkCreate*Pipelines)
121+
//
122+
// Throws on size mismatch between user-provided and reflected constants.
123+
void BuildSpecializationData(const PipelineStateVkImpl::TShaderStages& ShaderStages,
124+
Uint32 NumSpecializationConstants,
125+
const SpecializationConstant* pSpecializationConstants,
126+
const PipelineStateDesc& PSODesc,
127+
std::vector<VkPipelineShaderStageCreateInfo>& vkStages,
128+
std::vector<ShaderStageSpecializationData>& SpecDataPerStage)
129+
{
130+
if (NumSpecializationConstants == 0 || pSpecializationConstants == nullptr)
131+
return;
132+
133+
// vkStages has one entry per ShaderStageInfo::Item across all stages.
134+
// We build one ShaderStageSpecializationData per vkStages entry.
135+
SpecDataPerStage.resize(vkStages.size());
136+
137+
Uint32 vkStageIdx = 0;
138+
for (const PipelineStateVkImpl::ShaderStageInfo& Stage : ShaderStages)
139+
{
140+
for (const PipelineStateVkImpl::ShaderStageInfo::Item& StageItem : Stage.Items)
141+
{
142+
VERIFY_EXPR(vkStageIdx < vkStages.size());
143+
144+
const ShaderVkImpl* pShader = StageItem.pShader;
145+
const SPIRVShaderResources* pResources = pShader->GetShaderResources().get();
146+
ShaderStageSpecializationData& StageData = SpecDataPerStage[vkStageIdx];
147+
148+
for (Uint32 r = 0; r < pResources->GetNumSpecConstants(); ++r)
149+
{
150+
const SPIRVSpecializationConstantAttribs& Reflected = pResources->GetSpecConstant(r);
151+
152+
// Search for a matching user-provided constant by name and stage flag.
153+
const SpecializationConstant* pUserConst = nullptr;
154+
for (Uint32 sc = 0; sc < NumSpecializationConstants; ++sc)
155+
{
156+
const SpecializationConstant& Candidate = pSpecializationConstants[sc];
157+
if ((Candidate.ShaderStages & Stage.Type) != 0 &&
158+
strcmp(Candidate.Name, Reflected.Name) == 0)
159+
{
160+
pUserConst = &Candidate;
161+
break;
162+
}
163+
}
164+
165+
// No user constant for this reflected entry -- skip silently.
166+
if (pUserConst == nullptr)
167+
continue;
168+
169+
// The user may provide more data than the shader needs (e.g. when
170+
// sharing a constant array across pipelines with different types).
171+
// Only reject the case where the user provides less data than required.
172+
if (pUserConst->Size < Reflected.Size)
173+
{
174+
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
175+
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
176+
"' is invalid: specialization constant '", pUserConst->Name,
177+
"' in ", GetShaderTypeLiteralName(Stage.Type),
178+
" shader '", pShader->GetDesc().Name,
179+
"' has insufficient data: user provided ", pUserConst->Size,
180+
" bytes, but the shader declares ",
181+
GetShaderCodeBasicTypeString(Reflected.BasicType),
182+
" (", Reflected.Size, " bytes).");
183+
}
184+
185+
// Use the reflected size -- it is the actual size the shader expects.
186+
const Uint32 ConstSize = Reflected.Size;
187+
188+
// Build the map entry.
189+
VkSpecializationMapEntry Entry{};
190+
Entry.constantID = Reflected.SpecId;
191+
Entry.offset = static_cast<uint32_t>(StageData.DataBlob.size());
192+
Entry.size = ConstSize;
193+
StageData.MapEntries.push_back(Entry);
194+
195+
// Append data to the blob (only the bytes the shader needs).
196+
const Uint8* pSrcData = static_cast<const Uint8*>(pUserConst->pData);
197+
StageData.DataBlob.insert(StageData.DataBlob.end(), pSrcData, pSrcData + ConstSize);
198+
}
199+
200+
// Populate VkSpecializationInfo if any entries were matched.
201+
if (!StageData.MapEntries.empty())
202+
{
203+
StageData.Info.mapEntryCount = static_cast<uint32_t>(StageData.MapEntries.size());
204+
StageData.Info.pMapEntries = StageData.MapEntries.data();
205+
StageData.Info.dataSize = StageData.DataBlob.size();
206+
StageData.Info.pData = StageData.DataBlob.data();
207+
208+
vkStages[vkStageIdx].pSpecializationInfo = &StageData.Info;
209+
}
210+
211+
++vkStageIdx;
212+
}
213+
}
214+
215+
VERIFY_EXPR(vkStageIdx == vkStages.size());
216+
}
217+
97218

98219
void CreateComputePipeline(RenderDeviceVkImpl* pDeviceVk,
99220
std::vector<VkPipelineShaderStageCreateInfo>& Stages,
@@ -955,7 +1076,8 @@ template <typename PSOCreateInfoType>
9551076
PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
9561077
const PSOCreateInfoType& CreateInfo,
9571078
std::vector<VkPipelineShaderStageCreateInfo>& vkShaderStages,
958-
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules) noexcept(false)
1079+
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules,
1080+
std::vector<ShaderStageSpecializationData>& SpecDataPerStage) noexcept(false)
9591081
{
9601082
TShaderStages ShaderStages;
9611083
ExtractShaders<ShaderVkImpl>(CreateInfo, ShaderStages, /*WaitUntilShadersReady = */ true);
@@ -975,15 +1097,19 @@ PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
9751097
// Create shader modules and initialize shader stages
9761098
InitPipelineShaderStages(LogicalDevice, ShaderStages, ShaderModules, vkShaderStages);
9771099

1100+
// Build per-stage specialization data and patch vkShaderStages.
1101+
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
1102+
9781103
return ShaderStages;
9791104
}
9801105

9811106
void PipelineStateVkImpl::InitializePipeline(const GraphicsPipelineStateCreateInfo& CreateInfo)
9821107
{
9831108
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9841109
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1110+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
9851111

986-
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1112+
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
9871113

9881114
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
9891115
CreateGraphicsPipeline(m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_pGraphicsPipelineData->Desc, m_Pipeline, GetRenderPassPtr(), vkSPOCache);
@@ -993,8 +1119,9 @@ void PipelineStateVkImpl::InitializePipeline(const ComputePipelineStateCreateInf
9931119
{
9941120
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9951121
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1122+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
9961123

997-
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1124+
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
9981125

9991126
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
10001127
CreateComputePipeline(m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_Pipeline, vkSPOCache);
@@ -1006,8 +1133,10 @@ void PipelineStateVkImpl::InitializePipeline(const RayTracingPipelineStateCreate
10061133

10071134
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
10081135
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1136+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1137+
1138+
const TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
10091139

1010-
const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
10111140
const std::vector<VkRayTracingShaderGroupCreateInfoKHR> vkShaderGroups = BuildRTShaderGroupDescription(CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex, ShaderStages);
10121141
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
10131142

0 commit comments

Comments
 (0)