Skip to content

Commit 0e39a6b

Browse files
committed
Add specialization constants support for WGSLShaderResource.
1 parent 3f6694e commit 0e39a6b

2 files changed

Lines changed: 108 additions & 6 deletions

File tree

Graphics/ShaderTools/include/WGSLShaderResources.hpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2024 Diligent Graphics LLC
2+
* Copyright 2024-2026 Diligent Graphics LLC
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,8 +32,8 @@
3232
// WGSLShaderResources class uses continuous chunk of memory to store all resources, as follows:
3333
//
3434
// m_MemoryBuffer
35-
// | |
36-
// | Uniform Buffers | Storage Buffers | Textures | Storage Textures | Samplers | Ext Textures | Resource Names |
35+
// | |
36+
// | Uniform Buffers | Storage Buffers | Textures | Storage Textures | Samplers | Ext Textures | Spec Constants | Resource Names |
3737

3838
#include <memory>
3939
#include <string>
@@ -159,6 +159,34 @@ struct WGSLShaderResourceAttribs
159159
static_assert(sizeof(WGSLShaderResourceAttribs) % sizeof(void*) == 0, "Size of WGSLShaderResourceAttribs struct must be a multiple of sizeof(void*)");
160160

161161

162+
struct WGSLSpecializationConstantAttribs
163+
{
164+
// clang-format off
165+
166+
/* 0 */const char* const Name;
167+
/* 8 */const Uint16 OverrideId;
168+
/* 10 */const Uint8 Type; // SHADER_CODE_BASIC_TYPE
169+
/* 11 */const Uint8 Padding;
170+
/* 12 */const Uint32 Reserved;
171+
/* 16 */ // End of structure
172+
173+
// clang-format on
174+
175+
WGSLSpecializationConstantAttribs(const char* _Name,
176+
Uint16 _OverrideId,
177+
SHADER_CODE_BASIC_TYPE _Type) noexcept :
178+
Name{_Name},
179+
OverrideId{_OverrideId},
180+
Type{static_cast<Uint8>(_Type)},
181+
Padding{0},
182+
Reserved{0}
183+
{}
184+
185+
SHADER_CODE_BASIC_TYPE GetType() const { return static_cast<SHADER_CODE_BASIC_TYPE>(Type); }
186+
};
187+
static_assert(sizeof(WGSLSpecializationConstantAttribs) % sizeof(void*) == 0, "Size of WGSLSpecializationConstantAttribs struct must be a multiple of sizeof(void*)");
188+
189+
162190
/// Diligent::WGSLShaderResources class
163191
class WGSLShaderResources
164192
{
@@ -191,6 +219,7 @@ class WGSLShaderResources
191219
Uint32 GetNumSamplers ()const noexcept{ return (m_ExternalTextureOffset - m_SamplerOffset); }
192220
Uint32 GetNumExtTextures ()const noexcept{ return (m_TotalResources - m_ExternalTextureOffset);}
193221
Uint32 GetTotalResources ()const noexcept{ return m_TotalResources; }
222+
Uint32 GetNumSpecConstants()const noexcept{ return m_NumSpecConstants; }
194223

195224
const WGSLShaderResourceAttribs& GetUB (Uint32 n) const noexcept { return GetResAttribs(n, GetNumUBs(), 0 ); }
196225
const WGSLShaderResourceAttribs& GetSB (Uint32 n) const noexcept { return GetResAttribs(n, GetNumSBs(), m_StorageBufferOffset ); }
@@ -200,6 +229,8 @@ class WGSLShaderResources
200229
const WGSLShaderResourceAttribs& GetExtTexture(Uint32 n) const noexcept { return GetResAttribs(n, GetNumExtTextures(), m_ExternalTextureOffset); }
201230
const WGSLShaderResourceAttribs& GetResource (Uint32 n) const noexcept { return GetResAttribs(n, GetTotalResources(), 0 ); }
202231

232+
const WGSLSpecializationConstantAttribs& GetSpecConstant(Uint32 n) const noexcept;
233+
203234
// clang-format on
204235

205236
const ShaderCodeBufferDesc* GetUniformBufferDesc(Uint32 Index) const
@@ -308,6 +339,7 @@ class WGSLShaderResources
308339
private:
309340
void Initialize(IMemoryAllocator& Allocator,
310341
const ResourceCounters& Counters,
342+
Uint32 NumSpecConstants,
311343
size_t ResourceNamesPoolSize,
312344
StringPool& ResourceNamesPool);
313345

@@ -335,11 +367,13 @@ class WGSLShaderResources
335367
WGSLShaderResourceAttribs& GetExtTexture(Uint32 n) noexcept { return GetResAttribs(n, GetNumExtTextures(), m_ExternalTextureOffset); }
336368
WGSLShaderResourceAttribs& GetResource (Uint32 n) noexcept { return GetResAttribs(n, GetTotalResources(), 0 ); }
337369

370+
WGSLSpecializationConstantAttribs& GetSpecConstant(Uint32 n) noexcept;
371+
338372
// clang-format on
339373

340374
private:
341375
// Memory buffer that holds all resources as continuous chunk of memory:
342-
// | UBs | SBs | Textures | StorageTex | Samplers | ExternalTex | Resource Names |
376+
// | UBs | SBs | Textures | StorageTex | Samplers | ExternalTex | Spec Constants | Resource Names |
343377
std::unique_ptr<void, STDDeleterRawMem<void>> m_MemoryBuffer;
344378
std::unique_ptr<void, STDDeleterRawMem<void>> m_UBReflectionBuffer;
345379

@@ -355,6 +389,7 @@ class WGSLShaderResources
355389
OffsetType m_SamplerOffset = 0;
356390
OffsetType m_ExternalTextureOffset = 0;
357391
OffsetType m_TotalResources = 0;
392+
OffsetType m_NumSpecConstants = 0;
358393

359394
SHADER_TYPE m_ShaderType = SHADER_TYPE_UNKNOWN;
360395
};

Graphics/ShaderTools/src/WGSLShaderResources.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,9 +901,17 @@ WGSLShaderResources::WGSLShaderResources(IMemoryAllocator& Allocator,
901901
ResourceNamesPoolSize += strlen(ShaderName) + 1;
902902
ResourceNamesPoolSize += strlen(EntryPoint) + 1;
903903

904+
// Count override constants (specialization constants)
905+
const auto& Overrides = EntryPoints[EntryPointIdx].overrides;
906+
Uint32 NumSpecConstants = static_cast<Uint32>(Overrides.size());
907+
for (const auto& Override : Overrides)
908+
{
909+
ResourceNamesPoolSize += Override.name.length() + 1;
910+
}
911+
904912
// Resource names pool is only needed to facilitate string allocation.
905913
StringPool ResourceNamesPool;
906-
Initialize(Allocator, ResCounters, ResourceNamesPoolSize, ResourceNamesPool);
914+
Initialize(Allocator, ResCounters, NumSpecConstants, ResourceNamesPoolSize, ResourceNamesPool);
907915

908916
// Uniform buffer reflections
909917
std::vector<ShaderCodeBufferDescX> UBReflections;
@@ -979,6 +987,31 @@ WGSLShaderResources::WGSLShaderResources(IMemoryAllocator& Allocator,
979987
VERIFY_EXPR(CurrRes.NumSamplers == GetNumSamplers());
980988
VERIFY_EXPR(CurrRes.NumExtTextures == GetNumExtTextures());
981989

990+
// Construct specialization constant attribs
991+
{
992+
using TintOverrideType = tint::inspector::Override::Type;
993+
Uint32 SpecConstIdx = 0;
994+
for (const auto& Override : Overrides)
995+
{
996+
const char* SCName = ResourceNamesPool.CopyString(Override.name);
997+
SHADER_CODE_BASIC_TYPE SCType = SHADER_CODE_BASIC_TYPE_UNKNOWN;
998+
switch (Override.type)
999+
{
1000+
// clang-format off
1001+
case TintOverrideType::kBool: SCType = SHADER_CODE_BASIC_TYPE_BOOL; break;
1002+
case TintOverrideType::kFloat32: SCType = SHADER_CODE_BASIC_TYPE_FLOAT; break;
1003+
case TintOverrideType::kUint32: SCType = SHADER_CODE_BASIC_TYPE_UINT; break;
1004+
case TintOverrideType::kInt32: SCType = SHADER_CODE_BASIC_TYPE_INT; break;
1005+
case TintOverrideType::kFloat16: SCType = SHADER_CODE_BASIC_TYPE_FLOAT16; break;
1006+
// clang-format on
1007+
default:
1008+
UNEXPECTED("Unexpected override type");
1009+
}
1010+
new (&GetSpecConstant(SpecConstIdx++)) WGSLSpecializationConstantAttribs{SCName, Override.id.value, SCType};
1011+
}
1012+
VERIFY_EXPR(SpecConstIdx == GetNumSpecConstants());
1013+
}
1014+
9821015
if (CombinedSamplerSuffix != nullptr)
9831016
{
9841017
m_CombinedSamplerSuffix = ResourceNamesPool.CopyString(CombinedSamplerSuffix);
@@ -1002,6 +1035,7 @@ WGSLShaderResources::WGSLShaderResources(IMemoryAllocator& Allocator,
10021035

10031036
void WGSLShaderResources::Initialize(IMemoryAllocator& Allocator,
10041037
const ResourceCounters& Counters,
1038+
Uint32 NumSpecConstants,
10051039
size_t ResourceNamesPoolSize,
10061040
StringPool& ResourceNamesPool)
10071041
{
@@ -1023,13 +1057,16 @@ void WGSLShaderResources::Initialize(IMemoryAllocator& Allocator,
10231057
m_SamplerOffset = AdvanceOffset(Counters.NumSamplers);
10241058
m_ExternalTextureOffset = AdvanceOffset(Counters.NumExtTextures);
10251059
m_TotalResources = AdvanceOffset(0);
1060+
m_NumSpecConstants = static_cast<OffsetType>(NumSpecConstants);
10261061
static_assert(Uint32{WGSLShaderResourceAttribs::ResourceType::NumResourceTypes} == 13, "Please update the new resource type offset");
10271062

10281063
size_t AlignedResourceNamesPoolSize = AlignUp(ResourceNamesPoolSize, sizeof(void*));
10291064

10301065
static_assert(sizeof(WGSLShaderResourceAttribs) % sizeof(void*) == 0, "Size of WGSLShaderResourceAttribs struct must be a multiple of sizeof(void*)");
1066+
static_assert(sizeof(WGSLSpecializationConstantAttribs) % sizeof(void*) == 0, "Size of WGSLSpecializationConstantAttribs struct must be a multiple of sizeof(void*)");
10311067
// clang-format off
10321068
size_t MemorySize = m_TotalResources * sizeof(WGSLShaderResourceAttribs) +
1069+
m_NumSpecConstants * sizeof(WGSLSpecializationConstantAttribs) +
10331070
AlignedResourceNamesPoolSize * sizeof(char);
10341071

10351072
VERIFY_EXPR(GetNumUBs() == Counters.NumUBs);
@@ -1046,7 +1083,8 @@ void WGSLShaderResources::Initialize(IMemoryAllocator& Allocator,
10461083
void* pRawMem = Allocator.Allocate(MemorySize, "Memory for shader resources", __FILE__, __LINE__);
10471084
m_MemoryBuffer = std::unique_ptr<void, STDDeleterRawMem<void>>(pRawMem, Allocator);
10481085
char* NamesPool = reinterpret_cast<char*>(m_MemoryBuffer.get()) +
1049-
m_TotalResources * sizeof(WGSLShaderResourceAttribs);
1086+
m_TotalResources * sizeof(WGSLShaderResourceAttribs) +
1087+
m_NumSpecConstants * sizeof(WGSLSpecializationConstantAttribs);
10501088
ResourceNamesPool.AssignMemory(NamesPool, ResourceNamesPoolSize);
10511089
}
10521090
}
@@ -1055,6 +1093,21 @@ WGSLShaderResources::~WGSLShaderResources()
10551093
{
10561094
for (Uint32 n = 0; n < GetTotalResources(); ++n)
10571095
GetResource(n).~WGSLShaderResourceAttribs();
1096+
1097+
for (Uint32 n = 0; n < GetNumSpecConstants(); ++n)
1098+
GetSpecConstant(n).~WGSLSpecializationConstantAttribs();
1099+
}
1100+
1101+
const WGSLSpecializationConstantAttribs& WGSLShaderResources::GetSpecConstant(Uint32 n) const noexcept
1102+
{
1103+
VERIFY(n < m_NumSpecConstants, "Specialization constant index (", n, ") is out of range. Total spec constant count: ", m_NumSpecConstants);
1104+
const WGSLShaderResourceAttribs* ResourceMemoryEnd = reinterpret_cast<const WGSLShaderResourceAttribs*>(m_MemoryBuffer.get()) + GetTotalResources();
1105+
return reinterpret_cast<const WGSLSpecializationConstantAttribs*>(ResourceMemoryEnd)[n];
1106+
}
1107+
1108+
WGSLSpecializationConstantAttribs& WGSLShaderResources::GetSpecConstant(Uint32 n) noexcept
1109+
{
1110+
return const_cast<WGSLSpecializationConstantAttribs&>(const_cast<const WGSLShaderResources*>(this)->GetSpecConstant(n));
10581111
}
10591112

10601113
std::string WGSLShaderResources::DumpResources()
@@ -1173,6 +1226,20 @@ std::string WGSLShaderResources::DumpResources()
11731226
);
11741227
VERIFY_EXPR(ResNum == GetTotalResources());
11751228

1229+
if (GetNumSpecConstants() > 0)
1230+
{
1231+
ss << std::endl
1232+
<< "Spec Constants (" << GetNumSpecConstants() << "):";
1233+
for (Uint32 n = 0; n < GetNumSpecConstants(); ++n)
1234+
{
1235+
const auto& SC = GetSpecConstant(n);
1236+
ss << std::endl
1237+
<< std::setw(3) << n << " Spec Constant "
1238+
<< std::setw(32) << ('\'' + std::string{SC.Name} + '\'')
1239+
<< " id=" << SC.OverrideId;
1240+
}
1241+
}
1242+
11761243
return ss.str();
11771244
}
11781245

0 commit comments

Comments
 (0)