Skip to content

Commit 2385c1e

Browse files
hzqstTheMostDiligent
authored andcommitted
SPIRVShaderResources: add specialization constants support
1 parent faeb54f commit 2385c1e

4 files changed

Lines changed: 214 additions & 5 deletions

File tree

Graphics/ShaderTools/include/SPIRVShaderResources.hpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
// | Accel Structs |
4545
// | Push Constants | _ _ _ m_TotalResources
4646
// | Stage Inputs |
47+
// | Spec Constants |
4748
// | Resource Names |
4849

4950
#include <memory>
@@ -192,6 +193,28 @@ struct SPIRVShaderStageInputAttribs
192193
};
193194
static_assert(sizeof(SPIRVShaderStageInputAttribs) % sizeof(void*) == 0, "Size of SPIRVShaderStageInputAttribs struct must be multiple of sizeof(void*)");
194195

196+
// sizeof(SPIRVSpecializationConstantAttribs) == 24, msvc x64
197+
struct SPIRVSpecializationConstantAttribs
198+
{
199+
// clang-format off
200+
SPIRVSpecializationConstantAttribs(const char* _Name,
201+
uint32_t _SpecId,
202+
uint32_t _Size,
203+
SHADER_CODE_BASIC_TYPE _BasicType) :
204+
Name {_Name},
205+
SpecId {_SpecId},
206+
Size {_Size},
207+
BasicType{_BasicType}
208+
{}
209+
// clang-format on
210+
211+
const char* const Name;
212+
const uint32_t SpecId;
213+
const uint32_t Size; // Byte size of the scalar type
214+
const SHADER_CODE_BASIC_TYPE BasicType;
215+
};
216+
static_assert(sizeof(SPIRVSpecializationConstantAttribs) % sizeof(void*) == 0, "Size of SPIRVSpecializationConstantAttribs struct must be multiple of sizeof(void*)");
217+
195218
/// Diligent::SPIRVShaderResources class
196219
class SPIRVShaderResources
197220
{
@@ -271,8 +294,9 @@ class SPIRVShaderResources
271294
Uint32 GetNumAccelStructs ()const noexcept{ return GetNumResources(ResourceClass::AccelStruct); }
272295
Uint32 GetNumPushConstants()const noexcept{ return GetNumResources(ResourceClass::PushConstant); }
273296

274-
Uint32 GetTotalResources() const noexcept { return m_Offsets[static_cast<size_t>(ResourceClass::NumClasses)]; }
275-
Uint32 GetNumShaderStageInputs()const noexcept { return m_NumShaderStageInputs; }
297+
Uint32 GetTotalResources() const noexcept { return m_Offsets[static_cast<size_t>(ResourceClass::NumClasses)]; }
298+
Uint32 GetNumShaderStageInputs() const noexcept { return m_NumShaderStageInputs; }
299+
Uint32 GetNumSpecConstants() const noexcept { return m_NumSpecConstants; }
276300

277301
const SPIRVShaderResourceAttribs& GetUB (Uint32 n)const noexcept{ return GetResAttribs(ResourceClass::UniformBuffer, n); }
278302
const SPIRVShaderResourceAttribs& GetSB (Uint32 n)const noexcept{ return GetResAttribs(ResourceClass::StorageBuffer, n); }
@@ -294,6 +318,14 @@ class SPIRVShaderResources
294318
return reinterpret_cast<const SPIRVShaderStageInputAttribs*>(ResourceMemoryEnd)[n];
295319
}
296320

321+
const SPIRVSpecializationConstantAttribs& GetSpecConstant(Uint32 n) const noexcept
322+
{
323+
VERIFY(n < m_NumSpecConstants, "Specialization constant index (", n, ") is out of range. Total spec constant count: ", m_NumSpecConstants);
324+
const SPIRVShaderResourceAttribs* ResourceMemoryEnd = reinterpret_cast<const SPIRVShaderResourceAttribs*>(m_MemoryBuffer.get()) + GetTotalResources();
325+
const SPIRVShaderStageInputAttribs* StageInputsEnd = reinterpret_cast<const SPIRVShaderStageInputAttribs*>(ResourceMemoryEnd) + m_NumShaderStageInputs;
326+
return reinterpret_cast<const SPIRVSpecializationConstantAttribs*>(StageInputsEnd)[n];
327+
}
328+
297329
const ShaderCodeBufferDesc* GetUniformBufferDesc(Uint32 Index) const
298330
{
299331
if (Index >= GetNumUBs())
@@ -452,6 +484,7 @@ class SPIRVShaderResources
452484
void Initialize(IMemoryAllocator& Allocator,
453485
const ResourceCounters& Counters,
454486
Uint32 NumShaderStageInputs,
487+
Uint32 NumSpecConstants,
455488
size_t ResourceNamesPoolSize,
456489
StringPool& ResourceNamesPool);
457490

@@ -486,8 +519,13 @@ class SPIRVShaderResources
486519
return const_cast<SPIRVShaderStageInputAttribs&>(const_cast<const SPIRVShaderResources*>(this)->GetShaderStageInputAttribs(n));
487520
}
488521

522+
SPIRVSpecializationConstantAttribs& GetSpecConstant(Uint32 n) noexcept
523+
{
524+
return const_cast<SPIRVSpecializationConstantAttribs&>(const_cast<const SPIRVShaderResources*>(this)->GetSpecConstant(n));
525+
}
526+
489527
// Memory buffer that holds all resources as continuous chunk of memory:
490-
// | UBs | SBs | StrgImgs | SmplImgs | ACs | SepSamplers | SepImgs | Stage Inputs | Resource Names |
528+
// | UBs | SBs | StrgImgs | SmplImgs | ACs | SepSamplers | SepImgs | Stage Inputs | Spec Constants | Resource Names |
491529
std::unique_ptr<void, STDDeleterRawMem<void>> m_MemoryBuffer;
492530
std::unique_ptr<void, STDDeleterRawMem<void>> m_UBReflectionBuffer;
493531

@@ -498,6 +536,7 @@ class SPIRVShaderResources
498536
std::array<OffsetType, static_cast<size_t>(ResourceClass::NumClasses) + 1> m_Offsets;
499537

500538
OffsetType m_NumShaderStageInputs = 0;
539+
OffsetType m_NumSpecConstants = 0;
501540

502541
SHADER_TYPE m_ShaderType = SHADER_TYPE_UNKNOWN;
503542

Graphics/ShaderTools/src/SPIRVShaderResources.cpp

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,9 +613,57 @@ SPIRVShaderResources::SPIRVShaderResources(IMemoryAllocator& Allocator,
613613
ResCounters.NumPushConstants = static_cast<Uint32>(resources.push_constant_buffers.size());
614614
static_assert(Uint32{SPIRVShaderResourceAttribs::ResourceType::NumResourceTypes} == 13, "Please set the new resource type counter here");
615615

616+
// Specialization constants reflection
617+
struct SpecConstInfo
618+
{
619+
std::string Name;
620+
uint32_t SpecId = 0;
621+
uint32_t Size = 0;
622+
SHADER_CODE_BASIC_TYPE BasicType = SHADER_CODE_BASIC_TYPE_UNKNOWN;
623+
};
624+
std::vector<SpecConstInfo> SpecConstants;
625+
Uint32 NumSpecConstants = 0;
626+
627+
{
628+
diligent_spirv_cross::SmallVector<diligent_spirv_cross::SpecializationConstant> spec_consts =
629+
Compiler.get_specialization_constants();
630+
for (const diligent_spirv_cross::SpecializationConstant& sc : spec_consts)
631+
{
632+
const diligent_spirv_cross::SPIRConstant& Constant = Compiler.get_constant(sc.id);
633+
const diligent_spirv_cross::SPIRType& Type = Compiler.get_type(Constant.constant_type);
634+
635+
// Only support scalar specialization constants
636+
if (Type.vecsize != 1 || Type.columns != 1)
637+
{
638+
LOG_WARNING_MESSAGE("Specialization constant '", Compiler.get_name(sc.id),
639+
"' (SpecId=", sc.constant_id, ") in shader '", CI.Name,
640+
"' is not a scalar type and will be skipped.");
641+
continue;
642+
}
643+
644+
SpecConstInfo Info;
645+
Info.Name = Compiler.get_name(sc.id);
646+
Info.SpecId = sc.constant_id;
647+
// OpTypeBool has width==1 in SPIRV-Cross; use 4 bytes (VkBool32) for bool specialization constants
648+
Info.Size = Type.basetype == diligent_spirv_cross::SPIRType::Boolean ? 4 : Type.width / 8;
649+
Info.BasicType = SpirvBaseTypeToShaderCodeBasicType(Type.basetype);
650+
651+
if (Info.Name.empty())
652+
{
653+
LOG_WARNING_MESSAGE("Specialization constant with SpecId=", sc.constant_id,
654+
" in shader '", CI.Name, "' has no name (OpName) and will be skipped.");
655+
continue;
656+
}
657+
658+
ResourceNamesPoolSize += Info.Name.length() + 1;
659+
SpecConstants.emplace_back(std::move(Info));
660+
}
661+
NumSpecConstants = static_cast<Uint32>(SpecConstants.size());
662+
}
663+
616664
// Resource names pool is only needed to facilitate string allocation.
617665
StringPool ResourceNamesPool;
618-
Initialize(Allocator, ResCounters, NumShaderStageInputs, ResourceNamesPoolSize, ResourceNamesPool);
666+
Initialize(Allocator, ResCounters, NumShaderStageInputs, NumSpecConstants, ResourceNamesPoolSize, ResourceNamesPool);
619667

620668
// Uniform buffer reflections
621669
std::vector<ShaderCodeBufferDescX> UBReflections;
@@ -842,6 +890,22 @@ SPIRVShaderResources::SPIRVShaderResources(IMemoryAllocator& Allocator,
842890
VERIFY_EXPR(CurrStageInput == GetNumShaderStageInputs());
843891
}
844892

893+
if (!SpecConstants.empty())
894+
{
895+
Uint32 CurrSpecConst = 0;
896+
for (const SpecConstInfo& SC : SpecConstants)
897+
{
898+
new (&GetSpecConstant(CurrSpecConst++)) SPIRVSpecializationConstantAttribs //
899+
{
900+
ResourceNamesPool.CopyString(SC.Name),
901+
SC.SpecId,
902+
SC.Size,
903+
SC.BasicType //
904+
};
905+
}
906+
VERIFY_EXPR(CurrSpecConst == GetNumSpecConstants());
907+
}
908+
845909
VERIFY(ResourceNamesPool.GetRemainingSize() == 0, "Names pool must be empty");
846910

847911
if (m_ShaderType == SHADER_TYPE_COMPUTE)
@@ -862,6 +926,7 @@ SPIRVShaderResources::SPIRVShaderResources(IMemoryAllocator& Allocator,
862926
void SPIRVShaderResources::Initialize(IMemoryAllocator& Allocator,
863927
const ResourceCounters& Counters,
864928
Uint32 NumShaderStageInputs,
929+
Uint32 NumSpecConstants,
865930
size_t ResourceNamesPoolSize,
866931
StringPool& ResourceNamesPool)
867932
{
@@ -890,12 +955,16 @@ void SPIRVShaderResources::Initialize(IMemoryAllocator& Allocator,
890955
VERIFY(NumShaderStageInputs <= MaxOffset, "Max offset exceeded");
891956
m_NumShaderStageInputs = static_cast<OffsetType>(NumShaderStageInputs);
892957

958+
VERIFY(NumSpecConstants <= MaxOffset, "Max offset exceeded");
959+
m_NumSpecConstants = static_cast<OffsetType>(NumSpecConstants);
960+
893961
size_t AlignedResourceNamesPoolSize = AlignUp(ResourceNamesPoolSize, sizeof(void*));
894962

895963
static_assert(sizeof(SPIRVShaderResourceAttribs) % sizeof(void*) == 0, "Size of SPIRVShaderResourceAttribs struct must be multiple of sizeof(void*)");
896964
// clang-format off
897965
size_t MemorySize = GetTotalResources() * sizeof(SPIRVShaderResourceAttribs) +
898966
m_NumShaderStageInputs * sizeof(SPIRVShaderStageInputAttribs) +
967+
m_NumSpecConstants * sizeof(SPIRVSpecializationConstantAttribs) +
899968
AlignedResourceNamesPoolSize * sizeof(char);
900969

901970
VERIFY_EXPR(GetNumUBs() == Counters.NumUBs);
@@ -917,7 +986,8 @@ void SPIRVShaderResources::Initialize(IMemoryAllocator& Allocator,
917986
m_MemoryBuffer = std::unique_ptr<void, STDDeleterRawMem<void>>(pRawMem, Allocator);
918987
char* NamesPool = reinterpret_cast<char*>(m_MemoryBuffer.get()) +
919988
GetTotalResources() * sizeof(SPIRVShaderResourceAttribs) +
920-
m_NumShaderStageInputs * sizeof(SPIRVShaderStageInputAttribs);
989+
m_NumShaderStageInputs * sizeof(SPIRVShaderStageInputAttribs) +
990+
m_NumSpecConstants * sizeof(SPIRVSpecializationConstantAttribs);
921991
ResourceNamesPool.AssignMemory(NamesPool, ResourceNamesPoolSize);
922992
}
923993
}
@@ -951,6 +1021,9 @@ SPIRVShaderResources::~SPIRVShaderResources()
9511021
for (Uint32 n = 0; n < GetNumShaderStageInputs(); ++n)
9521022
GetShaderStageInputAttribs(n).~SPIRVShaderStageInputAttribs();
9531023

1024+
for (Uint32 n = 0; n < GetNumSpecConstants(); ++n)
1025+
GetSpecConstant(n).~SPIRVSpecializationConstantAttribs();
1026+
9541027
for (Uint32 n = 0; n < GetNumAccelStructs(); ++n)
9551028
GetAccelStruct(n).~SPIRVShaderResourceAttribs();
9561029

@@ -1191,6 +1264,18 @@ std::string SPIRVShaderResources::DumpResources() const
11911264
);
11921265
VERIFY_EXPR(ResNum == GetTotalResources());
11931266

1267+
if (GetNumSpecConstants() > 0)
1268+
{
1269+
ss << std::endl
1270+
<< "Specialization constants (" << GetNumSpecConstants() << "):";
1271+
for (Uint32 n = 0; n < GetNumSpecConstants(); ++n)
1272+
{
1273+
const auto& SC = GetSpecConstant(n);
1274+
ss << std::endl
1275+
<< " '" << SC.Name << "' SpecId=" << SC.SpecId << " Size=" << SC.Size;
1276+
}
1277+
}
1278+
11941279
return ss.str();
11951280
}
11961281

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Test shader for specialization constants reflection
2+
[[vk::constant_id(0)]] const bool g_EnableFeature = true;
3+
[[vk::constant_id(1)]] const int g_IntParam = 42;
4+
[[vk::constant_id(2)]] const uint g_UintParam = 7;
5+
[[vk::constant_id(3)]] const float g_FloatParam = 1.0f;
6+
7+
float4 main() : SV_Target
8+
{
9+
float4 result = float4(0, 0, 0, 1);
10+
11+
if (g_EnableFeature)
12+
{
13+
result.x = float(g_IntParam);
14+
result.y = float(g_UintParam);
15+
result.z = g_FloatParam;
16+
}
17+
18+
return result;
19+
}

Tests/DiligentCoreTest/src/ShaderTools/SPIRVShaderResourcesTest.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,4 +589,70 @@ TEST_F(SPIRVShaderResourcesTest, MixedResources_DXC)
589589
TestMixedResources(SHADER_COMPILER_DXC);
590590
}
591591

592+
593+
struct SPIRVSpecConstRefAttribs
594+
{
595+
const char* const Name;
596+
const uint32_t SpecId;
597+
const uint32_t Size;
598+
const SHADER_CODE_BASIC_TYPE BasicType;
599+
};
600+
601+
void TestSpecializationConstants(SHADER_COMPILER Compiler)
602+
{
603+
std::vector<unsigned int> SPIRV;
604+
ASSERT_NO_FATAL_FAILURE(CompileSPIRV("SpecializationConstants.psh", Compiler, SHADER_TYPE_PIXEL, SHADER_SOURCE_LANGUAGE_HLSL, SPIRV));
605+
606+
if (::testing::Test::IsSkipped())
607+
return;
608+
609+
SPIRVShaderResources::CreateInfo ResCI;
610+
ResCI.ShaderType = SHADER_TYPE_PIXEL;
611+
ResCI.Name = "SpecConstants test";
612+
SPIRVShaderResources Resources{
613+
GetRawAllocator(),
614+
SPIRV,
615+
ResCI,
616+
};
617+
618+
LOG_INFO_MESSAGE("SPIRV Resources:\n", Resources.DumpResources());
619+
620+
const std::vector<SPIRVSpecConstRefAttribs> RefSpecConstants = {
621+
{"g_EnableFeature", 0, 4, SHADER_CODE_BASIC_TYPE_BOOL},
622+
{"g_IntParam", 1, 4, SHADER_CODE_BASIC_TYPE_INT},
623+
{"g_UintParam", 2, 4, SHADER_CODE_BASIC_TYPE_UINT},
624+
{"g_FloatParam", 3, 4, SHADER_CODE_BASIC_TYPE_FLOAT},
625+
};
626+
627+
const SPIRVShaderResources& ConstResources = Resources;
628+
EXPECT_EQ(ConstResources.GetNumSpecConstants(), static_cast<Uint32>(RefSpecConstants.size()));
629+
630+
// Build a map from name to reference for order-independent matching
631+
std::unordered_map<std::string, const SPIRVSpecConstRefAttribs*> RefMap;
632+
for (const SPIRVSpecConstRefAttribs& Ref : RefSpecConstants)
633+
RefMap[Ref.Name] = &Ref;
634+
635+
for (Uint32 i = 0; i < ConstResources.GetNumSpecConstants(); ++i)
636+
{
637+
const SPIRVSpecializationConstantAttribs& SC = ConstResources.GetSpecConstant(i);
638+
const auto it = RefMap.find(SC.Name);
639+
ASSERT_NE(it, RefMap.end()) << "Specialization constant '" << SC.Name << "' is not found in the reference list";
640+
641+
const auto* pRef = it->second;
642+
EXPECT_EQ(SC.SpecId, pRef->SpecId) << SC.Name;
643+
EXPECT_EQ(SC.Size, pRef->Size) << SC.Name;
644+
EXPECT_EQ(SC.BasicType, pRef->BasicType) << SC.Name;
645+
}
646+
}
647+
648+
TEST_F(SPIRVShaderResourcesTest, SpecializationConstants_GLSLang)
649+
{
650+
TestSpecializationConstants(SHADER_COMPILER_GLSLANG);
651+
}
652+
653+
TEST_F(SPIRVShaderResourcesTest, SpecializationConstants_DXC)
654+
{
655+
TestSpecializationConstants(SHADER_COMPILER_DXC);
656+
}
657+
592658
} // namespace

0 commit comments

Comments
 (0)