@@ -95,6 +95,148 @@ void InitPipelineShaderStages(const VulkanUtilities::LogicalDevice&
9595}
9696
9797
98+ // Per-shader-stage specialization constant data for Vulkan pipeline creation.
99+ // Holds VkSpecializationMapEntry array, contiguous data blob, and the
100+ // VkSpecializationInfo that references them. Lifetime must exceed the
101+ // vkCreate*Pipelines call.
102+ struct ShaderStageSpecializationData
103+ {
104+ std::vector<VkSpecializationMapEntry> MapEntries;
105+ std::vector<Uint8> DataBlob;
106+ VkSpecializationInfo Info{};
107+ };
108+
109+ // Matches user-provided SpecializationConstant entries to SPIR-V reflected
110+ // specialization constants by name, validates size compatibility, and builds
111+ // per-stage VkSpecializationInfo structures.
112+ //
113+ // Parameters:
114+ // ShaderStages - shader stages extracted from the PSO create info
115+ // NumSpecializationConstants - number of user-provided specialization constants
116+ // pSpecializationConstants - user-provided specialization constant array
117+ // PSODesc - pipeline state description (for error messages)
118+ // vkStages [in/out] - VkPipelineShaderStageCreateInfo array to patch
119+ // SpecDataPerStage [out] - per-stage specialization data (must outlive vkCreate*Pipelines)
120+ //
121+ // Throws on validation failure (name not found, size mismatch, duplicate SpecId binding).
122+ void BuildSpecializationData (const PipelineStateVkImpl::TShaderStages& ShaderStages,
123+ Uint32 NumSpecializationConstants,
124+ const SpecializationConstant* pSpecializationConstants,
125+ const PipelineStateDesc& PSODesc,
126+ std::vector<VkPipelineShaderStageCreateInfo>& vkStages,
127+ std::vector<ShaderStageSpecializationData>& SpecDataPerStage)
128+ {
129+ if (NumSpecializationConstants == 0 || pSpecializationConstants == nullptr )
130+ return ;
131+
132+ // vkStages has one entry per ShaderStageInfo::Item across all stages.
133+ // We build one ShaderStageSpecializationData per vkStages entry.
134+ SpecDataPerStage.resize (vkStages.size ());
135+
136+ Uint32 vkStageIdx = 0 ;
137+ for (const PipelineStateVkImpl::ShaderStageInfo& Stage : ShaderStages)
138+ {
139+ for (const PipelineStateVkImpl::ShaderStageInfo::Item& StageItem : Stage.Items )
140+ {
141+ VERIFY_EXPR (vkStageIdx < vkStages.size ());
142+
143+ const ShaderVkImpl* pShader = StageItem.pShader ;
144+ const SPIRVShaderResources* pResources = pShader->GetShaderResources ().get ();
145+ ShaderStageSpecializationData& StageData = SpecDataPerStage[vkStageIdx];
146+
147+ // Track SpecIds already bound in this stage to detect duplicates.
148+ std::unordered_map<uint32_t , const char *> BoundSpecIds;
149+
150+ Uint32 DataOffset = 0 ;
151+ for (Uint32 sc = 0 ; sc < NumSpecializationConstants; ++sc)
152+ {
153+ const SpecializationConstant& UserConst = pSpecializationConstants[sc];
154+
155+ // Check if this constant applies to the current shader stage.
156+ if ((UserConst.ShaderStages & Stage.Type ) == 0 )
157+ continue ;
158+
159+ // Search for the matching reflected specialization constant by name.
160+ const SPIRVSpecializationConstantAttribs* pReflected = nullptr ;
161+ for (Uint32 r = 0 ; r < pResources->GetNumSpecConstants (); ++r)
162+ {
163+ const auto & Reflected = pResources->GetSpecConstant (r);
164+ if (strcmp (Reflected.Name , UserConst.Name ) == 0 )
165+ {
166+ pReflected = &Reflected;
167+ break ;
168+ }
169+ }
170+
171+ if (pReflected == nullptr )
172+ {
173+ LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
174+ " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
175+ " ' is invalid: specialization constant '" , UserConst.Name ,
176+ " ' was not found in " , GetShaderTypeLiteralName (Stage.Type ),
177+ " shader '" , pShader->GetDesc ().Name , " '." );
178+ }
179+
180+ // Validate size compatibility.
181+ if (UserConst.Size != pReflected->Size )
182+ {
183+ LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
184+ " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
185+ " ' is invalid: specialization constant '" , UserConst.Name ,
186+ " ' in " , GetShaderTypeLiteralName (Stage.Type ),
187+ " shader '" , pShader->GetDesc ().Name ,
188+ " ' has size mismatch: user provided " , UserConst.Size ,
189+ " bytes, but the shader declares " ,
190+ GetShaderCodeBasicTypeString (pReflected->BasicType ),
191+ " (" , pReflected->Size , " bytes)." );
192+ }
193+
194+ // Check for duplicate SpecId binding within this stage.
195+ auto it = BoundSpecIds.find (pReflected->SpecId );
196+ if (it != BoundSpecIds.end ())
197+ {
198+ LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
199+ " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
200+ " ' is invalid: specialization constant '" , UserConst.Name ,
201+ " ' in " , GetShaderTypeLiteralName (Stage.Type ),
202+ " shader '" , pShader->GetDesc ().Name ,
203+ " ' maps to SpecId " , pReflected->SpecId ,
204+ " which is already bound by constant '" , it->second , " '." );
205+ }
206+ BoundSpecIds.emplace (pReflected->SpecId , UserConst.Name );
207+
208+ // Build the map entry.
209+ VkSpecializationMapEntry Entry{};
210+ Entry.constantID = pReflected->SpecId ;
211+ Entry.offset = DataOffset;
212+ Entry.size = UserConst.Size ;
213+ StageData.MapEntries .push_back (Entry);
214+
215+ // Append data to the blob.
216+ const Uint8* pSrcData = static_cast <const Uint8*>(UserConst.pData );
217+ StageData.DataBlob .insert (StageData.DataBlob .end (), pSrcData, pSrcData + UserConst.Size );
218+ DataOffset += UserConst.Size ;
219+ }
220+
221+ // Populate VkSpecializationInfo if any entries were matched.
222+ if (!StageData.MapEntries .empty ())
223+ {
224+ StageData.Info .mapEntryCount = static_cast <uint32_t >(StageData.MapEntries .size ());
225+ StageData.Info .pMapEntries = StageData.MapEntries .data ();
226+ StageData.Info .dataSize = StageData.DataBlob .size ();
227+ StageData.Info .pData = StageData.DataBlob .data ();
228+
229+ vkStages[vkStageIdx].pSpecializationInfo = &StageData.Info ;
230+ }
231+
232+ ++vkStageIdx;
233+ }
234+ }
235+
236+ VERIFY_EXPR (vkStageIdx == vkStages.size ());
237+ }
238+
239+
98240void CreateComputePipeline (RenderDeviceVkImpl* pDeviceVk,
99241 std::vector<VkPipelineShaderStageCreateInfo>& Stages,
100242 const PipelineLayoutVk& Layout,
@@ -983,7 +1125,12 @@ void PipelineStateVkImpl::InitializePipeline(const GraphicsPipelineStateCreateIn
9831125 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9841126 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
9851127
986- InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1128+ const TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1129+
1130+ // Build per-stage specialization data and patch vkShaderStages.
1131+ // SpecDataPerStage must outlive the vkCreate*Pipelines call below.
1132+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1133+ BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
9871134
9881135 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
9891136 CreateGraphicsPipeline (m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_pGraphicsPipelineData->Desc , m_Pipeline, GetRenderPassPtr (), vkSPOCache);
@@ -994,7 +1141,10 @@ void PipelineStateVkImpl::InitializePipeline(const ComputePipelineStateCreateInf
9941141 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
9951142 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
9961143
997- InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1144+ const TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1145+
1146+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1147+ BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
9981148
9991149 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
10001150 CreateComputePipeline (m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_Pipeline, vkSPOCache);
@@ -1007,7 +1157,11 @@ void PipelineStateVkImpl::InitializePipeline(const RayTracingPipelineStateCreate
10071157 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
10081158 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
10091159
1010- const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1160+ const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1161+
1162+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1163+ BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
1164+
10111165 const std::vector<VkRayTracingShaderGroupCreateInfoKHR> vkShaderGroups = BuildRTShaderGroupDescription (CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex , ShaderStages);
10121166 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
10131167
0 commit comments