Skip to content

Commit 1984dc6

Browse files
committed
Build SpecializationData with Name -> SpecId matching policy, populate Vulkan structs during PSO creation.
1 parent d0e0a6d commit 1984dc6

1 file changed

Lines changed: 157 additions & 3 deletions

File tree

Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp

Lines changed: 157 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,148 @@ void InitPipelineShaderStages(const VulkanUtilities::LogicalDevice&
9595
}
9696

9797

98+
// Per-shader-stage specialization constant data for Vulkan pipeline creation.
99+
// Holds VkSpecializationMapEntry array, contiguous data blob, and the
100+
// VkSpecializationInfo that references them. Lifetime must exceed the
101+
// vkCreate*Pipelines call.
102+
struct ShaderStageSpecializationData
103+
{
104+
std::vector<VkSpecializationMapEntry> MapEntries;
105+
std::vector<Uint8> DataBlob;
106+
VkSpecializationInfo Info{};
107+
};
108+
109+
// Matches user-provided SpecializationConstant entries to SPIR-V reflected
110+
// specialization constants by name, validates size compatibility, and builds
111+
// per-stage VkSpecializationInfo structures.
112+
//
113+
// Parameters:
114+
// ShaderStages - shader stages extracted from the PSO create info
115+
// NumSpecializationConstants - number of user-provided specialization constants
116+
// pSpecializationConstants - user-provided specialization constant array
117+
// PSODesc - pipeline state description (for error messages)
118+
// vkStages [in/out] - VkPipelineShaderStageCreateInfo array to patch
119+
// SpecDataPerStage [out] - per-stage specialization data (must outlive vkCreate*Pipelines)
120+
//
121+
// Throws on validation failure (name not found, size mismatch, duplicate SpecId binding).
122+
void BuildSpecializationData(const PipelineStateVkImpl::TShaderStages& ShaderStages,
123+
Uint32 NumSpecializationConstants,
124+
const SpecializationConstant* pSpecializationConstants,
125+
const PipelineStateDesc& PSODesc,
126+
std::vector<VkPipelineShaderStageCreateInfo>& vkStages,
127+
std::vector<ShaderStageSpecializationData>& SpecDataPerStage)
128+
{
129+
if (NumSpecializationConstants == 0 || pSpecializationConstants == nullptr)
130+
return;
131+
132+
// vkStages has one entry per ShaderStageInfo::Item across all stages.
133+
// We build one ShaderStageSpecializationData per vkStages entry.
134+
SpecDataPerStage.resize(vkStages.size());
135+
136+
Uint32 vkStageIdx = 0;
137+
for (const PipelineStateVkImpl::ShaderStageInfo& Stage : ShaderStages)
138+
{
139+
for (const PipelineStateVkImpl::ShaderStageInfo::Item& StageItem : Stage.Items)
140+
{
141+
VERIFY_EXPR(vkStageIdx < vkStages.size());
142+
143+
const ShaderVkImpl* pShader = StageItem.pShader;
144+
const SPIRVShaderResources* pResources = pShader->GetShaderResources().get();
145+
ShaderStageSpecializationData& StageData = SpecDataPerStage[vkStageIdx];
146+
147+
// Track SpecIds already bound in this stage to detect duplicates.
148+
std::unordered_map<uint32_t, const char*> BoundSpecIds;
149+
150+
Uint32 DataOffset = 0;
151+
for (Uint32 sc = 0; sc < NumSpecializationConstants; ++sc)
152+
{
153+
const SpecializationConstant& UserConst = pSpecializationConstants[sc];
154+
155+
// Check if this constant applies to the current shader stage.
156+
if ((UserConst.ShaderStages & Stage.Type) == 0)
157+
continue;
158+
159+
// Search for the matching reflected specialization constant by name.
160+
const SPIRVSpecializationConstantAttribs* pReflected = nullptr;
161+
for (Uint32 r = 0; r < pResources->GetNumSpecConstants(); ++r)
162+
{
163+
const auto& Reflected = pResources->GetSpecConstant(r);
164+
if (strcmp(Reflected.Name, UserConst.Name) == 0)
165+
{
166+
pReflected = &Reflected;
167+
break;
168+
}
169+
}
170+
171+
if (pReflected == nullptr)
172+
{
173+
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
174+
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
175+
"' is invalid: specialization constant '", UserConst.Name,
176+
"' was not found in ", GetShaderTypeLiteralName(Stage.Type),
177+
" shader '", pShader->GetDesc().Name, "'.");
178+
}
179+
180+
// Validate size compatibility.
181+
if (UserConst.Size != pReflected->Size)
182+
{
183+
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
184+
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
185+
"' is invalid: specialization constant '", UserConst.Name,
186+
"' in ", GetShaderTypeLiteralName(Stage.Type),
187+
" shader '", pShader->GetDesc().Name,
188+
"' has size mismatch: user provided ", UserConst.Size,
189+
" bytes, but the shader declares ",
190+
GetShaderCodeBasicTypeString(pReflected->BasicType),
191+
" (", pReflected->Size, " bytes).");
192+
}
193+
194+
// Check for duplicate SpecId binding within this stage.
195+
auto it = BoundSpecIds.find(pReflected->SpecId);
196+
if (it != BoundSpecIds.end())
197+
{
198+
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
199+
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
200+
"' is invalid: specialization constant '", UserConst.Name,
201+
"' in ", GetShaderTypeLiteralName(Stage.Type),
202+
" shader '", pShader->GetDesc().Name,
203+
"' maps to SpecId ", pReflected->SpecId,
204+
" which is already bound by constant '", it->second, "'.");
205+
}
206+
BoundSpecIds.emplace(pReflected->SpecId, UserConst.Name);
207+
208+
// Build the map entry.
209+
VkSpecializationMapEntry Entry{};
210+
Entry.constantID = pReflected->SpecId;
211+
Entry.offset = DataOffset;
212+
Entry.size = UserConst.Size;
213+
StageData.MapEntries.push_back(Entry);
214+
215+
// Append data to the blob.
216+
const Uint8* pSrcData = static_cast<const Uint8*>(UserConst.pData);
217+
StageData.DataBlob.insert(StageData.DataBlob.end(), pSrcData, pSrcData + UserConst.Size);
218+
DataOffset += UserConst.Size;
219+
}
220+
221+
// Populate VkSpecializationInfo if any entries were matched.
222+
if (!StageData.MapEntries.empty())
223+
{
224+
StageData.Info.mapEntryCount = static_cast<uint32_t>(StageData.MapEntries.size());
225+
StageData.Info.pMapEntries = StageData.MapEntries.data();
226+
StageData.Info.dataSize = StageData.DataBlob.size();
227+
StageData.Info.pData = StageData.DataBlob.data();
228+
229+
vkStages[vkStageIdx].pSpecializationInfo = &StageData.Info;
230+
}
231+
232+
++vkStageIdx;
233+
}
234+
}
235+
236+
VERIFY_EXPR(vkStageIdx == vkStages.size());
237+
}
238+
239+
98240
void CreateComputePipeline(RenderDeviceVkImpl* pDeviceVk,
99241
std::vector<VkPipelineShaderStageCreateInfo>& Stages,
100242
const PipelineLayoutVk& Layout,
@@ -983,7 +1125,12 @@ void PipelineStateVkImpl::InitializePipeline(const GraphicsPipelineStateCreateIn
9831125
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9841126
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
9851127

986-
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1128+
const TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1129+
1130+
// Build per-stage specialization data and patch vkShaderStages.
1131+
// SpecDataPerStage must outlive the vkCreate*Pipelines call below.
1132+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1133+
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
9871134

9881135
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
9891136
CreateGraphicsPipeline(m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_pGraphicsPipelineData->Desc, m_Pipeline, GetRenderPassPtr(), vkSPOCache);
@@ -994,7 +1141,10 @@ void PipelineStateVkImpl::InitializePipeline(const ComputePipelineStateCreateInf
9941141
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9951142
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
9961143

997-
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1144+
const TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1145+
1146+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1147+
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
9981148

9991149
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
10001150
CreateComputePipeline(m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_Pipeline, vkSPOCache);
@@ -1007,7 +1157,11 @@ void PipelineStateVkImpl::InitializePipeline(const RayTracingPipelineStateCreate
10071157
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
10081158
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
10091159

1010-
const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1160+
const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1161+
1162+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1163+
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
1164+
10111165
const std::vector<VkRayTracingShaderGroupCreateInfoKHR> vkShaderGroups = BuildRTShaderGroupDescription(CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex, ShaderStages);
10121166
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
10131167

0 commit comments

Comments
 (0)