@@ -52,6 +52,17 @@ namespace Diligent
5252
5353constexpr INTERFACE_ID PipelineStateVkImpl::IID_InternalImpl;
5454
55+ // Per-shader-stage specialization constant data for Vulkan pipeline creation.
56+ // Holds VkSpecializationMapEntry array, contiguous data blob, and the
57+ // VkSpecializationInfo that references them. Lifetime must exceed the
58+ // vkCreate*Pipelines call.
59+ struct ShaderStageSpecializationData
60+ {
61+ std::vector<VkSpecializationMapEntry> MapEntries;
62+ std::vector<Uint8> DataBlob;
63+ VkSpecializationInfo Info{};
64+ };
65+
5566namespace
5667{
5768
@@ -95,20 +106,11 @@ void InitPipelineShaderStages(const VulkanUtilities::LogicalDevice&
95106}
96107
97108
98- // Per-shader-stage specialization constant data for Vulkan pipeline creation.
99- // Holds VkSpecializationMapEntry array, contiguous data blob, and the
100- // VkSpecializationInfo that references them. Lifetime must exceed the
101- // vkCreate*Pipelines call.
102- struct ShaderStageSpecializationData
103- {
104- std::vector<VkSpecializationMapEntry> MapEntries;
105- std::vector<Uint8> DataBlob;
106- VkSpecializationInfo Info{};
107- };
108-
109- // Matches user-provided SpecializationConstant entries to SPIR-V reflected
110- // specialization constants by name, validates size compatibility, and builds
111- // per-stage VkSpecializationInfo structures.
109+ // Iterates SPIR-V reflected specialization constants and matches them to
110+ // user-provided SpecializationConstant entries by name. If a reflected
111+ // constant has no matching user entry the constant is silently skipped,
112+ // which allows the user to supply a superset of constants shared across
113+ // multiple pipelines / stages.
112114//
113115// Parameters:
114116// ShaderStages - shader stages extracted from the PSO create info
@@ -118,7 +120,7 @@ struct ShaderStageSpecializationData
118120// vkStages [in/out] - VkPipelineShaderStageCreateInfo array to patch
119121// SpecDataPerStage [out] - per-stage specialization data (must outlive vkCreate*Pipelines)
120122//
121- // Throws on validation failure (name not found, size mismatch, duplicate SpecId binding) .
123+ // Throws on size mismatch between user-provided and reflected constants .
122124void BuildSpecializationData (const PipelineStateVkImpl::TShaderStages& ShaderStages,
123125 Uint32 NumSpecializationConstants,
124126 const SpecializationConstant* pSpecializationConstants,
@@ -144,78 +146,58 @@ void BuildSpecializationData(const PipelineStateVkImpl::TShaderStages& Shade
144146 const SPIRVShaderResources* pResources = pShader->GetShaderResources ().get ();
145147 ShaderStageSpecializationData& StageData = SpecDataPerStage[vkStageIdx];
146148
147- // Track SpecIds already bound in this stage to detect duplicates.
148- std::unordered_map<uint32_t , const char *> BoundSpecIds;
149-
150149 Uint32 DataOffset = 0 ;
151- for (Uint32 sc = 0 ; sc < NumSpecializationConstants ; ++sc )
150+ for (Uint32 r = 0 ; r < pResources-> GetNumSpecConstants () ; ++r )
152151 {
153- const SpecializationConstant& UserConst = pSpecializationConstants[sc] ;
152+ const SPIRVSpecializationConstantAttribs& Reflected = pResources-> GetSpecConstant (r) ;
154153
155- // Check if this constant applies to the current shader stage.
156- if ((UserConst.ShaderStages & Stage.Type ) == 0 )
157- continue ;
158-
159- // Search for the matching reflected specialization constant by name.
160- const SPIRVSpecializationConstantAttribs* pReflected = nullptr ;
161- for (Uint32 r = 0 ; r < pResources->GetNumSpecConstants (); ++r)
154+ // Search for a matching user-provided constant by name and stage flag.
155+ const SpecializationConstant* pUserConst = nullptr ;
156+ for (Uint32 sc = 0 ; sc < NumSpecializationConstants; ++sc)
162157 {
163- const auto & Reflected = pResources->GetSpecConstant (r);
164- if (strcmp (Reflected.Name , UserConst.Name ) == 0 )
158+ const SpecializationConstant& Candidate = pSpecializationConstants[sc];
159+ if ((Candidate.ShaderStages & Stage.Type ) != 0 &&
160+ strcmp (Candidate.Name , Reflected.Name ) == 0 )
165161 {
166- pReflected = &Reflected ;
162+ pUserConst = &Candidate ;
167163 break ;
168164 }
169165 }
170166
171- if (pReflected == nullptr )
172- {
173- LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
174- " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
175- " ' is invalid: specialization constant '" , UserConst.Name ,
176- " ' was not found in " , GetShaderTypeLiteralName (Stage.Type ),
177- " shader '" , pShader->GetDesc ().Name , " '." );
178- }
167+ // No user constant for this reflected entry -- skip silently.
168+ if (pUserConst == nullptr )
169+ continue ;
179170
180- // Validate size compatibility.
181- if (UserConst.Size != pReflected->Size )
171+ // The user may provide more data than the shader needs (e.g. when
172+ // sharing a constant array across pipelines with different types).
173+ // Only reject the case where the user provides less data than required.
174+ if (pUserConst->Size < Reflected.Size )
182175 {
183176 LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
184177 " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
185- " ' is invalid: specialization constant '" , UserConst. Name ,
178+ " ' is invalid: specialization constant '" , pUserConst-> Name ,
186179 " ' in " , GetShaderTypeLiteralName (Stage.Type ),
187180 " shader '" , pShader->GetDesc ().Name ,
188- " ' has size mismatch : user provided " , UserConst. Size ,
181+ " ' has insufficient data : user provided " , pUserConst-> Size ,
189182 " bytes, but the shader declares " ,
190- GetShaderCodeBasicTypeString (pReflected-> BasicType ),
191- " (" , pReflected-> Size , " bytes)." );
183+ GetShaderCodeBasicTypeString (Reflected. BasicType ),
184+ " (" , Reflected. Size , " bytes)." );
192185 }
193186
194- // Check for duplicate SpecId binding within this stage.
195- auto it = BoundSpecIds.find (pReflected->SpecId );
196- if (it != BoundSpecIds.end ())
197- {
198- LOG_ERROR_AND_THROW (" Description of " , GetPipelineTypeString (PSODesc.PipelineType ),
199- " PSO '" , (PSODesc.Name != nullptr ? PSODesc.Name : " " ),
200- " ' is invalid: specialization constant '" , UserConst.Name ,
201- " ' in " , GetShaderTypeLiteralName (Stage.Type ),
202- " shader '" , pShader->GetDesc ().Name ,
203- " ' maps to SpecId " , pReflected->SpecId ,
204- " which is already bound by constant '" , it->second , " '." );
205- }
206- BoundSpecIds.emplace (pReflected->SpecId , UserConst.Name );
187+ // Use the reflected size -- it is the actual size the shader expects.
188+ const Uint32 ConstSize = Reflected.Size ;
207189
208190 // Build the map entry.
209191 VkSpecializationMapEntry Entry{};
210- Entry.constantID = pReflected-> SpecId ;
192+ Entry.constantID = Reflected. SpecId ;
211193 Entry.offset = DataOffset;
212- Entry.size = UserConst. Size ;
194+ Entry.size = ConstSize ;
213195 StageData.MapEntries .push_back (Entry);
214196
215- // Append data to the blob.
216- const Uint8* pSrcData = static_cast <const Uint8*>(UserConst. pData );
217- StageData.DataBlob .insert (StageData.DataBlob .end (), pSrcData, pSrcData + UserConst. Size );
218- DataOffset += UserConst. Size ;
197+ // Append data to the blob (only the bytes the shader needs) .
198+ const Uint8* pSrcData = static_cast <const Uint8*>(pUserConst-> pData );
199+ StageData.DataBlob .insert (StageData.DataBlob .end (), pSrcData, pSrcData + ConstSize );
200+ DataOffset += ConstSize ;
219201 }
220202
221203 // Populate VkSpecializationInfo if any entries were matched.
@@ -1097,7 +1079,8 @@ template <typename PSOCreateInfoType>
10971079PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects (
10981080 const PSOCreateInfoType& CreateInfo,
10991081 std::vector<VkPipelineShaderStageCreateInfo>& vkShaderStages,
1100- std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules) noexcept (false )
1082+ std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules,
1083+ std::vector<ShaderStageSpecializationData>& SpecDataPerStage) noexcept (false )
11011084{
11021085 TShaderStages ShaderStages;
11031086 ExtractShaders<ShaderVkImpl>(CreateInfo, ShaderStages, /* WaitUntilShadersReady = */ true );
@@ -1117,20 +1100,19 @@ PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
11171100 // Create shader modules and initialize shader stages
11181101 InitPipelineShaderStages (LogicalDevice, ShaderStages, ShaderModules, vkShaderStages);
11191102
1103+ // Build per-stage specialization data and patch vkShaderStages.
1104+ BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
1105+
11201106 return ShaderStages;
11211107}
11221108
11231109void PipelineStateVkImpl::InitializePipeline (const GraphicsPipelineStateCreateInfo& CreateInfo)
11241110{
11251111 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
11261112 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1113+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
11271114
1128- const TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1129-
1130- // Build per-stage specialization data and patch vkShaderStages.
1131- // SpecDataPerStage must outlive the vkCreate*Pipelines call below.
1132- std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1133- BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
1115+ InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
11341116
11351117 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
11361118 CreateGraphicsPipeline (m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_pGraphicsPipelineData->Desc , m_Pipeline, GetRenderPassPtr (), vkSPOCache);
@@ -1140,11 +1122,9 @@ void PipelineStateVkImpl::InitializePipeline(const ComputePipelineStateCreateInf
11401122{
11411123 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
11421124 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1125+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
11431126
1144- const TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1145-
1146- std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1147- BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
1127+ InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
11481128
11491129 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
11501130 CreateComputePipeline (m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_Pipeline, vkSPOCache);
@@ -1156,11 +1136,9 @@ void PipelineStateVkImpl::InitializePipeline(const RayTracingPipelineStateCreate
11561136
11571137 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
11581138 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1139+ std::vector<ShaderStageSpecializationData> SpecDataPerStage;
11591140
1160- const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules);
1161-
1162- std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1163- BuildSpecializationData (ShaderStages, CreateInfo.NumSpecializationConstants , CreateInfo.pSpecializationConstants , m_Desc, vkShaderStages, SpecDataPerStage);
1141+ const TShaderStages ShaderStages = InitInternalObjects (CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
11641142
11651143 const std::vector<VkRayTracingShaderGroupCreateInfoKHR> vkShaderGroups = BuildRTShaderGroupDescription (CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex , ShaderStages);
11661144 const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache )->GetVkPipelineCache () : VK_NULL_HANDLE;
0 commit comments