@@ -94,6 +94,127 @@ void InitPipelineShaderStages(const VulkanUtilities::LogicalDevice&
9494 VERIFY_EXPR (ShaderModules.size () == Stages.size ());
9595}
9696
97+ // Per-shader-stage specialization constant data for Vulkan pipeline creation.
98+ // Holds VkSpecializationMapEntry array, contiguous data blob, and the
99+ // VkSpecializationInfo that references them. Lifetime must exceed the
100+ // vkCreate*Pipelines call.
101+ struct ShaderStageSpecializationData
102+ {
103+ std::vector<VkSpecializationMapEntry> MapEntries;
104+ std::vector<Uint8> DataBlob;
105+ VkSpecializationInfo Info{};
106+ };
107+
108+ // Iterates SPIR-V reflected specialization constants and matches them to
109+ // user-provided SpecializationConstant entries by name. If a reflected
110+ // constant has no matching user entry the constant is silently skipped,
111+ // which allows the user to supply a superset of constants shared across
112+ // multiple pipelines / stages.
113+ //
114+ // Parameters:
115+ // ShaderStages - shader stages extracted from the PSO create info
116+ // NumSpecializationConstants - number of user-provided specialization constants
117+ // pSpecializationConstants - user-provided specialization constant array
118+ // PSODesc - pipeline state description (for error messages)
119+ // vkStages [in/out] - VkPipelineShaderStageCreateInfo array to patch
120+ // SpecDataPerStage [out] - per-stage specialization data (must outlive vkCreate*Pipelines)
121+ //
122+ // Throws on size mismatch between user-provided and reflected constants.
123+ void BuildSpecializationData (const PipelineStateVkImpl::TShaderStages& ShaderStages,
124+ Uint32 NumSpecializationConstants,
125+ const SpecializationConstant* pSpecializationConstants,
126+ const PipelineStateDesc& PSODesc,
127+ std::vector<VkPipelineShaderStageCreateInfo>& vkStages,
128+ std::vector<ShaderStageSpecializationData>& SpecDataPerStage)
129+ {
130+ if (NumSpecializationConstants == 0 || pSpecializationConstants == nullptr )
131+ return ;
132+
133+ // vkStages has one entry per ShaderStageInfo::Item across all stages.
134+ // We build one ShaderStageSpecializationData per vkStages entry.
135+ SpecDataPerStage.resize (vkStages.size ());
136+
137+ Uint32 vkStageIdx = 0 ;
138+ for (const PipelineStateVkImpl::ShaderStageInfo& Stage : ShaderStages)
139+ {
140+ for (const PipelineStateVkImpl::ShaderStageInfo::Item& StageItem : Stage.Items )
141+ {
142+ VERIFY_EXPR (vkStageIdx < vkStages.size ());
143+
144+ const ShaderVkImpl* pShader = StageItem.pShader ;
145+ const SPIRVShaderResources* pResources = pShader->GetShaderResources ().get ();
146+ ShaderStageSpecializationData& StageData = SpecDataPerStage[vkStageIdx];
147+
148+ for (Uint32 r = 0 ; r < pResources->GetNumSpecConstants (); ++r)
149+ {
150+ const SPIRVSpecializationConstantAttribs& Reflected = pResources->GetSpecConstant (r);
151+
152+ // Search for a matching user-provided constant by name and stage flag.
153+ const SpecializationConstant* pUserConst = nullptr ;
154+ for (Uint32 sc = 0 ; sc < NumSpecializationConstants; ++sc)
155+ {
156+ const SpecializationConstant& Candidate = pSpecializationConstants[sc];
157+ if ((Candidate.ShaderStages & Stage.Type ) != 0 &&
158+ strcmp (Candidate.Name , Reflected.Name ) == 0 )
159+ {
160+ pUserConst = &Candidate;
161+ break ;
162+ }
163+ }
164+
165+ // No user constant for this reflected entry -- skip silently.
166+ if (pUserConst == nullptr )
167+ continue ;
168+
169+ // The user may provide more data than the shader needs (e.g. when
170+ // sharing a constant array across pipelines with different types).
171+ // Only reject the case where the user provides less data than required.
172+ if (pUserConst->Size < Reflected.Size )
173+ {
174+ LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
175+ " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
176+ " ' is invalid: specialization constant '" , pUserConst->Name ,
177+ " ' in " , GetShaderTypeLiteralName (Stage.Type ),
178+ " shader '" , pShader->GetDesc ().Name ,
179+ " ' has insufficient data: user provided " , pUserConst->Size ,
180+ " bytes, but the shader declares " ,
181+ GetShaderCodeBasicTypeString (Reflected.BasicType ),
182+ " (" , Reflected.Size , " bytes)." );
183+ }
184+
185+ // Use the reflected size -- it is the actual size the shader expects.
186+ const Uint32 ConstSize = Reflected.Size ;
187+
188+ // Build the map entry.
189+ VkSpecializationMapEntry Entry{};
190+ Entry.constantID = Reflected.SpecId ;
191+ Entry.offset = static_cast <uint32_t >(StageData.DataBlob .size ());
192+ Entry.size = ConstSize;
193+ StageData.MapEntries .push_back (Entry);
194+
195+ // Append data to the blob (only the bytes the shader needs).
196+ const Uint8* pSrcData = static_cast <const Uint8*>(pUserConst->pData );
197+ StageData.DataBlob .insert (StageData.DataBlob .end (), pSrcData, pSrcData + ConstSize);
198+ }
199+
200+ // Populate VkSpecializationInfo if any entries were matched.
201+ if (!StageData.MapEntries .empty ())
202+ {
203+ StageData.Info .mapEntryCount = static_cast <uint32_t >(StageData.MapEntries .size ());
204+ StageData.Info .pMapEntries = StageData.MapEntries .data ();
205+ StageData.Info .dataSize = StageData.DataBlob .size ();
206+ StageData.Info .pData = StageData.DataBlob .data ();
207+
208+ vkStages[vkStageIdx].pSpecializationInfo = &StageData.Info ;
209+ }
210+
211+ ++vkStageIdx;
212+ }
213+ }
214+
215+ VERIFY_EXPR (vkStageIdx == vkStages.size ());
216+ }
217+
97218
98219void CreateComputePipeline (RenderDeviceVkImpl* pDeviceVk,
99220 std::vector<VkPipelineShaderStageCreateInfo>& Stages,
@@ -955,7 +1076,8 @@ template <typename PSOCreateInfoType>
9551076PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects (
9561077 const PSOCreateInfoType& CreateInfo,
9571078 std::vector<VkPipelineShaderStageCreateInfo>& vkShaderStages,
958- std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules) noexcept (false )
1079+ std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules,
1080+ std::vector<ShaderStageSpecializationData>& SpecDataPerStage) noexcept (false )
9591081{
9601082 TShaderStages ShaderStages;
9611083 ExtractShaders<ShaderVkImpl>(CreateInfo, ShaderStages, /* WaitUntilShadersReady = */ true );
@@ -975,15 +1097,19 @@ PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
9751097 // Create shader modules and initialize shader stages
9761098 InitPipelineShaderStages (LogicalDevice, ShaderStages, ShaderModules, vkShaderStages);
9771099
1100+ // Build per-stage specialization data and patch vkShaderStages.
1101+ BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
1102+
9781103 return ShaderStages;
9791104}
9801105
9811106void PipelineStateVkImpl::InitializePipeline (const GraphicsPipelineStateCreateInfo& CreateInfo)
9821107{
9831108 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9841109 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1110+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
9851111
986- InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1112+ InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage );
9871113
9881114 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
9891115 CreateGraphicsPipeline (m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_pGraphicsPipelineData->Desc , m_Pipeline, GetRenderPassPtr (), vkSPOCache);
@@ -993,8 +1119,9 @@ void PipelineStateVkImpl::InitializePipeline(const ComputePipelineStateCreateInf
9931119{
9941120 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9951121 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1122+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
9961123
997- InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1124+ InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage );
9981125
9991126 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
10001127 CreateComputePipeline (m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_Pipeline, vkSPOCache);
@@ -1006,8 +1133,10 @@ void PipelineStateVkImpl::InitializePipeline(const RayTracingPipelineStateCreate
10061133
10071134 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
10081135 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1136+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1137+
1138+ const TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
10091139
1010- const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
10111140 const std::vector<VkRayTracingShaderGroupCreateInfoKHR> vkShaderGroups = BuildRTShaderGroupDescription (CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex , ShaderStages);
10121141 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
10131142
0 commit comments