Skip to content

Commit 17ac310

Browse files
committed
Change as requested: for each SPIRV spec constant, match the user-provided spec constant. moving BuildSpecializationData into InitInternalObjects.
1 parent 71b0b7b commit 17ac310

2 files changed

Lines changed: 59 additions & 79 deletions

File tree

Graphics/GraphicsEngineVulkan/include/PipelineStateVkImpl.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ namespace Diligent
4848
{
4949

5050
class DeviceContextVkImpl;
51+
struct ShaderStageSpecializationData;
5152

5253
/// Pipeline state object implementation in Vulkan backend.
5354
class PipelineStateVkImpl final : public PipelineStateBase<EngineVkImplTraits>
@@ -131,7 +132,8 @@ class PipelineStateVkImpl final : public PipelineStateBase<EngineVkImplTraits>
131132
template <typename PSOCreateInfoType>
132133
TShaderStages InitInternalObjects(const PSOCreateInfoType& CreateInfo,
133134
std::vector<VkPipelineShaderStageCreateInfo>& vkShaderStages,
134-
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules) noexcept(false);
135+
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules,
136+
std::vector<ShaderStageSpecializationData>& SpecDataPerStage) noexcept(false);
135137

136138
void InitPipelineLayout(const PipelineStateCreateInfo& CreateInfo,
137139
TShaderStages& ShaderStages) noexcept(false);

Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp

Lines changed: 56 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ namespace Diligent
5252

5353
constexpr INTERFACE_ID PipelineStateVkImpl::IID_InternalImpl;
5454

55+
// Per-shader-stage specialization constant data for Vulkan pipeline creation.
56+
// Holds VkSpecializationMapEntry array, contiguous data blob, and the
57+
// VkSpecializationInfo that references them. Lifetime must exceed the
58+
// vkCreate*Pipelines call.
59+
struct ShaderStageSpecializationData
60+
{
61+
std::vector<VkSpecializationMapEntry> MapEntries;
62+
std::vector<Uint8> DataBlob;
63+
VkSpecializationInfo Info{};
64+
};
65+
5566
namespace
5667
{
5768

@@ -95,20 +106,11 @@ void InitPipelineShaderStages(const VulkanUtilities::LogicalDevice&
95106
}
96107

97108

98-
// Per-shader-stage specialization constant data for Vulkan pipeline creation.
99-
// Holds VkSpecializationMapEntry array, contiguous data blob, and the
100-
// VkSpecializationInfo that references them. Lifetime must exceed the
101-
// vkCreate*Pipelines call.
102-
struct ShaderStageSpecializationData
103-
{
104-
std::vector<VkSpecializationMapEntry> MapEntries;
105-
std::vector<Uint8> DataBlob;
106-
VkSpecializationInfo Info{};
107-
};
108-
109-
// Matches user-provided SpecializationConstant entries to SPIR-V reflected
110-
// specialization constants by name, validates size compatibility, and builds
111-
// per-stage VkSpecializationInfo structures.
109+
// Iterates SPIR-V reflected specialization constants and matches them to
110+
// user-provided SpecializationConstant entries by name. If a reflected
111+
// constant has no matching user entry the constant is silently skipped,
112+
// which allows the user to supply a superset of constants shared across
113+
// multiple pipelines / stages.
112114
//
113115
// Parameters:
114116
// ShaderStages - shader stages extracted from the PSO create info
@@ -118,7 +120,7 @@ struct ShaderStageSpecializationData
118120
// vkStages [in/out] - VkPipelineShaderStageCreateInfo array to patch
119121
// SpecDataPerStage [out] - per-stage specialization data (must outlive vkCreate*Pipelines)
120122
//
121-
// Throws on validation failure (name not found, size mismatch, duplicate SpecId binding).
123+
// Throws on size mismatch between user-provided and reflected constants.
122124
void BuildSpecializationData(const PipelineStateVkImpl::TShaderStages& ShaderStages,
123125
Uint32 NumSpecializationConstants,
124126
const SpecializationConstant* pSpecializationConstants,
@@ -144,78 +146,58 @@ void BuildSpecializationData(const PipelineStateVkImpl::TShaderStages& Shade
144146
const SPIRVShaderResources* pResources = pShader->GetShaderResources().get();
145147
ShaderStageSpecializationData& StageData = SpecDataPerStage[vkStageIdx];
146148

147-
// Track SpecIds already bound in this stage to detect duplicates.
148-
std::unordered_map<uint32_t, const char*> BoundSpecIds;
149-
150149
Uint32 DataOffset = 0;
151-
for (Uint32 sc = 0; sc < NumSpecializationConstants; ++sc)
150+
for (Uint32 r = 0; r < pResources->GetNumSpecConstants(); ++r)
152151
{
153-
const SpecializationConstant& UserConst = pSpecializationConstants[sc];
152+
const SPIRVSpecializationConstantAttribs& Reflected = pResources->GetSpecConstant(r);
154153

155-
// Check if this constant applies to the current shader stage.
156-
if ((UserConst.ShaderStages & Stage.Type) == 0)
157-
continue;
158-
159-
// Search for the matching reflected specialization constant by name.
160-
const SPIRVSpecializationConstantAttribs* pReflected = nullptr;
161-
for (Uint32 r = 0; r < pResources->GetNumSpecConstants(); ++r)
154+
// Search for a matching user-provided constant by name and stage flag.
155+
const SpecializationConstant* pUserConst = nullptr;
156+
for (Uint32 sc = 0; sc < NumSpecializationConstants; ++sc)
162157
{
163-
const auto& Reflected = pResources->GetSpecConstant(r);
164-
if (strcmp(Reflected.Name, UserConst.Name) == 0)
158+
const SpecializationConstant& Candidate = pSpecializationConstants[sc];
159+
if ((Candidate.ShaderStages & Stage.Type) != 0 &&
160+
strcmp(Candidate.Name, Reflected.Name) == 0)
165161
{
166-
pReflected = &Reflected;
162+
pUserConst = &Candidate;
167163
break;
168164
}
169165
}
170166

171-
if (pReflected == nullptr)
172-
{
173-
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
174-
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
175-
"' is invalid: specialization constant '", UserConst.Name,
176-
"' was not found in ", GetShaderTypeLiteralName(Stage.Type),
177-
" shader '", pShader->GetDesc().Name, "'.");
178-
}
167+
// No user constant for this reflected entry -- skip silently.
168+
if (pUserConst == nullptr)
169+
continue;
179170

180-
// Validate size compatibility.
181-
if (UserConst.Size != pReflected->Size)
171+
// The user may provide more data than the shader needs (e.g. when
172+
// sharing a constant array across pipelines with different types).
173+
// Only reject the case where the user provides less data than required.
174+
if (pUserConst->Size < Reflected.Size)
182175
{
183176
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
184177
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
185-
"' is invalid: specialization constant '", UserConst.Name,
178+
"' is invalid: specialization constant '", pUserConst->Name,
186179
"' in ", GetShaderTypeLiteralName(Stage.Type),
187180
" shader '", pShader->GetDesc().Name,
188-
"' has size mismatch: user provided ", UserConst.Size,
181+
"' has insufficient data: user provided ", pUserConst->Size,
189182
" bytes, but the shader declares ",
190-
GetShaderCodeBasicTypeString(pReflected->BasicType),
191-
" (", pReflected->Size, " bytes).");
183+
GetShaderCodeBasicTypeString(Reflected.BasicType),
184+
" (", Reflected.Size, " bytes).");
192185
}
193186

194-
// Check for duplicate SpecId binding within this stage.
195-
auto it = BoundSpecIds.find(pReflected->SpecId);
196-
if (it != BoundSpecIds.end())
197-
{
198-
LOG_ERROR_AND_THROW("Description of ", GetPipelineTypeString(PSODesc.PipelineType),
199-
" PSO '", (PSODesc.Name != nullptr ? PSODesc.Name : ""),
200-
"' is invalid: specialization constant '", UserConst.Name,
201-
"' in ", GetShaderTypeLiteralName(Stage.Type),
202-
" shader '", pShader->GetDesc().Name,
203-
"' maps to SpecId ", pReflected->SpecId,
204-
" which is already bound by constant '", it->second, "'.");
205-
}
206-
BoundSpecIds.emplace(pReflected->SpecId, UserConst.Name);
187+
// Use the reflected size -- it is the actual size the shader expects.
188+
const Uint32 ConstSize = Reflected.Size;
207189

208190
// Build the map entry.
209191
VkSpecializationMapEntry Entry{};
210-
Entry.constantID = pReflected->SpecId;
192+
Entry.constantID = Reflected.SpecId;
211193
Entry.offset = DataOffset;
212-
Entry.size = UserConst.Size;
194+
Entry.size = ConstSize;
213195
StageData.MapEntries.push_back(Entry);
214196

215-
// Append data to the blob.
216-
const Uint8* pSrcData = static_cast<const Uint8*>(UserConst.pData);
217-
StageData.DataBlob.insert(StageData.DataBlob.end(), pSrcData, pSrcData + UserConst.Size);
218-
DataOffset += UserConst.Size;
197+
// Append data to the blob (only the bytes the shader needs).
198+
const Uint8* pSrcData = static_cast<const Uint8*>(pUserConst->pData);
199+
StageData.DataBlob.insert(StageData.DataBlob.end(), pSrcData, pSrcData + ConstSize);
200+
DataOffset += ConstSize;
219201
}
220202

221203
// Populate VkSpecializationInfo if any entries were matched.
@@ -1097,7 +1079,8 @@ template <typename PSOCreateInfoType>
10971079
PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
10981080
const PSOCreateInfoType& CreateInfo,
10991081
std::vector<VkPipelineShaderStageCreateInfo>& vkShaderStages,
1100-
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules) noexcept(false)
1082+
std::vector<VulkanUtilities::ShaderModuleWrapper>& ShaderModules,
1083+
std::vector<ShaderStageSpecializationData>& SpecDataPerStage) noexcept(false)
11011084
{
11021085
TShaderStages ShaderStages;
11031086
ExtractShaders<ShaderVkImpl>(CreateInfo, ShaderStages, /*WaitUntilShadersReady = */ true);
@@ -1117,20 +1100,19 @@ PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
11171100
// Create shader modules and initialize shader stages
11181101
InitPipelineShaderStages(LogicalDevice, ShaderStages, ShaderModules, vkShaderStages);
11191102

1103+
// Build per-stage specialization data and patch vkShaderStages.
1104+
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
1105+
11201106
return ShaderStages;
11211107
}
11221108

11231109
void PipelineStateVkImpl::InitializePipeline(const GraphicsPipelineStateCreateInfo& CreateInfo)
11241110
{
11251111
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
11261112
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1113+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
11271114

1128-
const TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1129-
1130-
// Build per-stage specialization data and patch vkShaderStages.
1131-
// SpecDataPerStage must outlive the vkCreate*Pipelines call below.
1132-
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1133-
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
1115+
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
11341116

11351117
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
11361118
CreateGraphicsPipeline(m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_pGraphicsPipelineData->Desc, m_Pipeline, GetRenderPassPtr(), vkSPOCache);
@@ -1140,11 +1122,9 @@ void PipelineStateVkImpl::InitializePipeline(const ComputePipelineStateCreateInf
11401122
{
11411123
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
11421124
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1125+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
11431126

1144-
const TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1145-
1146-
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1147-
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
1127+
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
11481128

11491129
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;
11501130
CreateComputePipeline(m_pDevice, vkShaderStages, m_PipelineLayout, m_Desc, m_Pipeline, vkSPOCache);
@@ -1156,11 +1136,9 @@ void PipelineStateVkImpl::InitializePipeline(const RayTracingPipelineStateCreate
11561136

11571137
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
11581138
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
1139+
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
11591140

1160-
const PipelineStateVkImpl::TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules);
1161-
1162-
std::vector<ShaderStageSpecializationData> SpecDataPerStage;
1163-
BuildSpecializationData(ShaderStages, CreateInfo.NumSpecializationConstants, CreateInfo.pSpecializationConstants, m_Desc, vkShaderStages, SpecDataPerStage);
1141+
const TShaderStages ShaderStages = InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules, SpecDataPerStage);
11641142

11651143
const std::vector<VkRayTracingShaderGroupCreateInfoKHR> vkShaderGroups = BuildRTShaderGroupDescription(CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex, ShaderStages);
11661144
const VkPipelineCache vkSPOCache = CreateInfo.pPSOCache != nullptr ? ClassPtrCast<PipelineStateCacheVkImpl>(CreateInfo.pPSOCache)->GetVkPipelineCache() : VK_NULL_HANDLE;

0 commit comments

Comments
 (0)