@@ -613,9 +613,57 @@ SPIRVShaderResources::SPIRVShaderResources(IMemoryAllocator& Allocator,
613613 ResCounters.NumPushConstants = static_cast <Uint32>(resources.push_constant_buffers .size ());
614614 static_assert (Uint32{SPIRVShaderResourceAttribs::ResourceType::NumResourceTypes} == 13 , " Please set the new resource type counter here" );
615615
616+ // Specialization constants reflection
617+ struct SpecConstInfo
618+ {
619+ std::string Name;
620+ uint32_t SpecId = 0 ;
621+ uint32_t Size = 0 ;
622+ SHADER_CODE_BASIC_TYPE BasicType = SHADER_CODE_BASIC_TYPE_UNKNOWN;
623+ };
624+ std::vector<SpecConstInfo> SpecConstants;
625+ Uint32 NumSpecConstants = 0 ;
626+
627+ {
628+ diligent_spirv_cross::SmallVector<diligent_spirv_cross::SpecializationConstant> spec_consts =
629+ Compiler.get_specialization_constants ();
630+ for (const diligent_spirv_cross::SpecializationConstant& sc : spec_consts)
631+ {
632+ const diligent_spirv_cross::SPIRConstant& Constant = Compiler.get_constant (sc.id );
633+ const diligent_spirv_cross::SPIRType& Type = Compiler.get_type (Constant.constant_type );
634+
635+ // Only support scalar specialization constants
636+ if (Type.vecsize != 1 || Type.columns != 1 )
637+ {
638+ LOG_WARNING_MESSAGE (" Specialization constant '" , Compiler.get_name (sc.id ),
639+ " ' (SpecId=" , sc.constant_id , " ) in shader '" , CI.Name ,
640+ " ' is not a scalar type and will be skipped." );
641+ continue ;
642+ }
643+
644+ SpecConstInfo Info;
645+ Info.Name = Compiler.get_name (sc.id );
646+ Info.SpecId = sc.constant_id ;
647+ // OpTypeBool has width==1 in SPIRV-Cross; use 4 bytes (VkBool32) for bool specialization constants
648+ Info.Size = Type.basetype == diligent_spirv_cross::SPIRType::Boolean ? 4 : Type.width / 8 ;
649+ Info.BasicType = SpirvBaseTypeToShaderCodeBasicType (Type.basetype );
650+
651+ if (Info.Name .empty ())
652+ {
653+ LOG_WARNING_MESSAGE (" Specialization constant with SpecId=" , sc.constant_id ,
654+ " in shader '" , CI.Name , " ' has no name (OpName) and will be skipped." );
655+ continue ;
656+ }
657+
658+ ResourceNamesPoolSize += Info.Name .length () + 1 ;
659+ SpecConstants.emplace_back (std::move (Info));
660+ }
661+ NumSpecConstants = static_cast <Uint32>(SpecConstants.size ());
662+ }
663+
616664 // Resource names pool is only needed to facilitate string allocation.
617665 StringPool ResourceNamesPool;
618- Initialize (Allocator, ResCounters, NumShaderStageInputs, ResourceNamesPoolSize, ResourceNamesPool);
666+ Initialize (Allocator, ResCounters, NumShaderStageInputs, NumSpecConstants, ResourceNamesPoolSize, ResourceNamesPool);
619667
620668 // Uniform buffer reflections
621669 std::vector<ShaderCodeBufferDescX> UBReflections;
@@ -842,6 +890,22 @@ SPIRVShaderResources::SPIRVShaderResources(IMemoryAllocator& Allocator,
842890 VERIFY_EXPR (CurrStageInput == GetNumShaderStageInputs ());
843891 }
844892
893+ if (!SpecConstants.empty ())
894+ {
895+ Uint32 CurrSpecConst = 0 ;
896+ for (const SpecConstInfo& SC : SpecConstants)
897+ {
898+ new (&GetSpecConstant (CurrSpecConst++)) SPIRVSpecializationConstantAttribs //
899+ {
900+ ResourceNamesPool.CopyString (SC.Name ),
901+ SC.SpecId ,
902+ SC.Size ,
903+ SC.BasicType //
904+ };
905+ }
906+ VERIFY_EXPR (CurrSpecConst == GetNumSpecConstants ());
907+ }
908+
845909 VERIFY (ResourceNamesPool.GetRemainingSize () == 0 , " Names pool must be empty" );
846910
847911 if (m_ShaderType == SHADER_TYPE_COMPUTE)
@@ -862,6 +926,7 @@ SPIRVShaderResources::SPIRVShaderResources(IMemoryAllocator& Allocator,
862926void SPIRVShaderResources::Initialize (IMemoryAllocator& Allocator,
863927 const ResourceCounters& Counters,
864928 Uint32 NumShaderStageInputs,
929+ Uint32 NumSpecConstants,
865930 size_t ResourceNamesPoolSize,
866931 StringPool& ResourceNamesPool)
867932{
@@ -890,12 +955,16 @@ void SPIRVShaderResources::Initialize(IMemoryAllocator& Allocator,
890955 VERIFY (NumShaderStageInputs <= MaxOffset, " Max offset exceeded" );
891956 m_NumShaderStageInputs = static_cast <OffsetType>(NumShaderStageInputs);
892957
958+ VERIFY (NumSpecConstants <= MaxOffset, " Max offset exceeded" );
959+ m_NumSpecConstants = static_cast <OffsetType>(NumSpecConstants);
960+
893961 size_t AlignedResourceNamesPoolSize = AlignUp (ResourceNamesPoolSize, sizeof (void *));
894962
895963 static_assert (sizeof (SPIRVShaderResourceAttribs) % sizeof (void *) == 0 , " Size of SPIRVShaderResourceAttribs struct must be multiple of sizeof(void*)" );
896964 // clang-format off
897965 size_t MemorySize = GetTotalResources () * sizeof (SPIRVShaderResourceAttribs) +
898966 m_NumShaderStageInputs * sizeof (SPIRVShaderStageInputAttribs) +
967+ m_NumSpecConstants * sizeof (SPIRVSpecializationConstantAttribs) +
899968 AlignedResourceNamesPoolSize * sizeof (char );
900969
901970 VERIFY_EXPR (GetNumUBs () == Counters.NumUBs );
@@ -917,7 +986,8 @@ void SPIRVShaderResources::Initialize(IMemoryAllocator& Allocator,
917986 m_MemoryBuffer = std::unique_ptr<void , STDDeleterRawMem<void >>(pRawMem, Allocator);
918987 char * NamesPool = reinterpret_cast <char *>(m_MemoryBuffer.get ()) +
919988 GetTotalResources () * sizeof (SPIRVShaderResourceAttribs) +
920- m_NumShaderStageInputs * sizeof (SPIRVShaderStageInputAttribs);
989+ m_NumShaderStageInputs * sizeof (SPIRVShaderStageInputAttribs) +
990+ m_NumSpecConstants * sizeof (SPIRVSpecializationConstantAttribs);
921991 ResourceNamesPool.AssignMemory (NamesPool, ResourceNamesPoolSize);
922992 }
923993}
@@ -951,6 +1021,9 @@ SPIRVShaderResources::~SPIRVShaderResources()
9511021 for (Uint32 n = 0 ; n < GetNumShaderStageInputs (); ++n)
9521022 GetShaderStageInputAttribs (n).~SPIRVShaderStageInputAttribs ();
9531023
1024+ for (Uint32 n = 0 ; n < GetNumSpecConstants (); ++n)
1025+ GetSpecConstant (n).~SPIRVSpecializationConstantAttribs ();
1026+
9541027 for (Uint32 n = 0 ; n < GetNumAccelStructs (); ++n)
9551028 GetAccelStruct (n).~SPIRVShaderResourceAttribs ();
9561029
@@ -1191,6 +1264,18 @@ std::string SPIRVShaderResources::DumpResources() const
11911264 );
11921265 VERIFY_EXPR (ResNum == GetTotalResources ());
11931266
1267+ if (GetNumSpecConstants () > 0 )
1268+ {
1269+ ss << std::endl
1270+ << " Specialization constants (" << GetNumSpecConstants () << " ):" ;
1271+ for (Uint32 n = 0 ; n < GetNumSpecConstants (); ++n)
1272+ {
1273+ const auto & SC = GetSpecConstant (n);
1274+ ss << std::endl
1275+ << " '" << SC.Name << " ' SpecId=" << SC.SpecId << " Size=" << SC.Size ;
1276+ }
1277+ }
1278+
11941279 return ss.str ();
11951280}
11961281
0 commit comments