Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions Graphics/ShaderTools/include/WGSLShaderResources.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2024 Diligent Graphics LLC
* Copyright 2024-2026 Diligent Graphics LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,8 +32,8 @@
// WGSLShaderResources class uses continuous chunk of memory to store all resources, as follows:
//
// m_MemoryBuffer
// | |
// | Uniform Buffers | Storage Buffers | Textures | Storage Textures | Samplers | Ext Textures | Resource Names |
// | |
// | Uniform Buffers | Storage Buffers | Textures | Storage Textures | Samplers | Ext Textures | Spec Constants | Resource Names |

#include <memory>
#include <string>
Expand Down Expand Up @@ -159,6 +159,31 @@ struct WGSLShaderResourceAttribs
static_assert(sizeof(WGSLShaderResourceAttribs) % sizeof(void*) == 0, "Size of WGSLShaderResourceAttribs struct must be a multiple of sizeof(void*)");


struct WGSLSpecializationConstantAttribs
{
// clang-format off

/* 0 */const char* const Name;
/* 8 */const Uint16 OverrideId;
/* 10 */const Uint8 Type; // SHADER_CODE_BASIC_TYPE
/* 11 */
/* 16 */ // End of structure

// clang-format on

WGSLSpecializationConstantAttribs(const char* _Name,
Uint16 _OverrideId,
SHADER_CODE_BASIC_TYPE _Type) noexcept :
Name{_Name},
OverrideId{_OverrideId},
Type{static_cast<Uint8>(_Type)}
{}

SHADER_CODE_BASIC_TYPE GetType() const { return static_cast<SHADER_CODE_BASIC_TYPE>(Type); }
};
static_assert(sizeof(WGSLSpecializationConstantAttribs) % sizeof(void*) == 0, "Size of WGSLSpecializationConstantAttribs struct must be a multiple of sizeof(void*)");


/// Diligent::WGSLShaderResources class
class WGSLShaderResources
{
Expand Down Expand Up @@ -191,6 +216,7 @@ class WGSLShaderResources
Uint32 GetNumSamplers ()const noexcept{ return (m_ExternalTextureOffset - m_SamplerOffset); }
Uint32 GetNumExtTextures ()const noexcept{ return (m_TotalResources - m_ExternalTextureOffset);}
Uint32 GetTotalResources ()const noexcept{ return m_TotalResources; }
Uint32 GetNumSpecConstants()const noexcept{ return m_NumSpecConstants; }

const WGSLShaderResourceAttribs& GetUB (Uint32 n) const noexcept { return GetResAttribs(n, GetNumUBs(), 0 ); }
const WGSLShaderResourceAttribs& GetSB (Uint32 n) const noexcept { return GetResAttribs(n, GetNumSBs(), m_StorageBufferOffset ); }
Expand All @@ -200,6 +226,8 @@ class WGSLShaderResources
const WGSLShaderResourceAttribs& GetExtTexture(Uint32 n) const noexcept { return GetResAttribs(n, GetNumExtTextures(), m_ExternalTextureOffset); }
const WGSLShaderResourceAttribs& GetResource (Uint32 n) const noexcept { return GetResAttribs(n, GetTotalResources(), 0 ); }

const WGSLSpecializationConstantAttribs& GetSpecConstant(Uint32 n) const noexcept;

// clang-format on

const ShaderCodeBufferDesc* GetUniformBufferDesc(Uint32 Index) const
Expand Down Expand Up @@ -308,6 +336,7 @@ class WGSLShaderResources
private:
void Initialize(IMemoryAllocator& Allocator,
const ResourceCounters& Counters,
Uint32 NumSpecConstants,
size_t ResourceNamesPoolSize,
StringPool& ResourceNamesPool);

Expand Down Expand Up @@ -335,11 +364,13 @@ class WGSLShaderResources
WGSLShaderResourceAttribs& GetExtTexture(Uint32 n) noexcept { return GetResAttribs(n, GetNumExtTextures(), m_ExternalTextureOffset); }
WGSLShaderResourceAttribs& GetResource (Uint32 n) noexcept { return GetResAttribs(n, GetTotalResources(), 0 ); }

WGSLSpecializationConstantAttribs& GetSpecConstant(Uint32 n) noexcept;

// clang-format on

private:
// Memory buffer that holds all resources as continuous chunk of memory:
// | UBs | SBs | Textures | StorageTex | Samplers | ExternalTex | Resource Names |
// | UBs | SBs | Textures | StorageTex | Samplers | ExternalTex | Spec Constants | Resource Names |
std::unique_ptr<void, STDDeleterRawMem<void>> m_MemoryBuffer;
std::unique_ptr<void, STDDeleterRawMem<void>> m_UBReflectionBuffer;

Expand All @@ -355,6 +386,7 @@ class WGSLShaderResources
OffsetType m_SamplerOffset = 0;
OffsetType m_ExternalTextureOffset = 0;
OffsetType m_TotalResources = 0;
OffsetType m_NumSpecConstants = 0;

SHADER_TYPE m_ShaderType = SHADER_TYPE_UNKNOWN;
};
Expand Down
76 changes: 74 additions & 2 deletions Graphics/ShaderTools/src/WGSLShaderResources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,24 @@ WEB_GPU_BINDING_TYPE GetWebGPUTextureBindingType(WGSLShaderResourceAttribs::Text
}
}

SHADER_CODE_BASIC_TYPE TintOverrideTypeToShaderCodeBasicType(tint::inspector::Override::Type OverrideType)
{
using TintOverrideType = tint::inspector::Override::Type;
switch (OverrideType)
{
// clang-format off
case TintOverrideType::kBool: return SHADER_CODE_BASIC_TYPE_BOOL;
case TintOverrideType::kFloat32: return SHADER_CODE_BASIC_TYPE_FLOAT;
case TintOverrideType::kUint32: return SHADER_CODE_BASIC_TYPE_UINT;
case TintOverrideType::kInt32: return SHADER_CODE_BASIC_TYPE_INT;
case TintOverrideType::kFloat16: return SHADER_CODE_BASIC_TYPE_FLOAT16;
// clang-format on
default:
UNEXPECTED("Unexpected override type");
return SHADER_CODE_BASIC_TYPE_UNKNOWN;
}
}

} // namespace

WGSLShaderResourceAttribs::WGSLShaderResourceAttribs(const char* _Name,
Expand Down Expand Up @@ -901,9 +919,17 @@ WGSLShaderResources::WGSLShaderResources(IMemoryAllocator& Allocator,
ResourceNamesPoolSize += strlen(ShaderName) + 1;
ResourceNamesPoolSize += strlen(EntryPoint) + 1;

// Count override constants (specialization constants)
const auto& Overrides = EntryPoints[EntryPointIdx].overrides;
Uint32 NumSpecConstants = static_cast<Uint32>(Overrides.size());
for (const tint::inspector::Override& Override : Overrides)
{
ResourceNamesPoolSize += Override.name.length() + 1;
}

// Resource names pool is only needed to facilitate string allocation.
StringPool ResourceNamesPool;
Initialize(Allocator, ResCounters, ResourceNamesPoolSize, ResourceNamesPool);
Initialize(Allocator, ResCounters, NumSpecConstants, ResourceNamesPoolSize, ResourceNamesPool);

// Uniform buffer reflections
std::vector<ShaderCodeBufferDescX> UBReflections;
Expand Down Expand Up @@ -979,6 +1005,18 @@ WGSLShaderResources::WGSLShaderResources(IMemoryAllocator& Allocator,
VERIFY_EXPR(CurrRes.NumSamplers == GetNumSamplers());
VERIFY_EXPR(CurrRes.NumExtTextures == GetNumExtTextures());

// Construct specialization constant attribs
{
Uint32 SpecConstIdx = 0;
for (const tint::inspector::Override& Override : Overrides)
{
const char* SCName = ResourceNamesPool.CopyString(Override.name);
SHADER_CODE_BASIC_TYPE SCType = TintOverrideTypeToShaderCodeBasicType(Override.type);
new (&GetSpecConstant(SpecConstIdx++)) WGSLSpecializationConstantAttribs{SCName, Override.id.value, SCType};
}
VERIFY_EXPR(SpecConstIdx == GetNumSpecConstants());
}

if (CombinedSamplerSuffix != nullptr)
{
m_CombinedSamplerSuffix = ResourceNamesPool.CopyString(CombinedSamplerSuffix);
Expand All @@ -1002,6 +1040,7 @@ WGSLShaderResources::WGSLShaderResources(IMemoryAllocator& Allocator,

void WGSLShaderResources::Initialize(IMemoryAllocator& Allocator,
const ResourceCounters& Counters,
Uint32 NumSpecConstants,
size_t ResourceNamesPoolSize,
StringPool& ResourceNamesPool)
{
Expand All @@ -1023,13 +1062,16 @@ void WGSLShaderResources::Initialize(IMemoryAllocator& Allocator,
m_SamplerOffset = AdvanceOffset(Counters.NumSamplers);
m_ExternalTextureOffset = AdvanceOffset(Counters.NumExtTextures);
m_TotalResources = AdvanceOffset(0);
m_NumSpecConstants = static_cast<OffsetType>(NumSpecConstants);
static_assert(Uint32{WGSLShaderResourceAttribs::ResourceType::NumResourceTypes} == 13, "Please update the new resource type offset");

size_t AlignedResourceNamesPoolSize = AlignUp(ResourceNamesPoolSize, sizeof(void*));

static_assert(sizeof(WGSLShaderResourceAttribs) % sizeof(void*) == 0, "Size of WGSLShaderResourceAttribs struct must be a multiple of sizeof(void*)");
static_assert(sizeof(WGSLSpecializationConstantAttribs) % sizeof(void*) == 0, "Size of WGSLSpecializationConstantAttribs struct must be a multiple of sizeof(void*)");
// clang-format off
size_t MemorySize = m_TotalResources * sizeof(WGSLShaderResourceAttribs) +
m_NumSpecConstants * sizeof(WGSLSpecializationConstantAttribs) +
AlignedResourceNamesPoolSize * sizeof(char);

VERIFY_EXPR(GetNumUBs() == Counters.NumUBs);
Expand All @@ -1046,7 +1088,8 @@ void WGSLShaderResources::Initialize(IMemoryAllocator& Allocator,
void* pRawMem = Allocator.Allocate(MemorySize, "Memory for shader resources", __FILE__, __LINE__);
m_MemoryBuffer = std::unique_ptr<void, STDDeleterRawMem<void>>(pRawMem, Allocator);
char* NamesPool = reinterpret_cast<char*>(m_MemoryBuffer.get()) +
m_TotalResources * sizeof(WGSLShaderResourceAttribs);
m_TotalResources * sizeof(WGSLShaderResourceAttribs) +
m_NumSpecConstants * sizeof(WGSLSpecializationConstantAttribs);
ResourceNamesPool.AssignMemory(NamesPool, ResourceNamesPoolSize);
}
}
Expand All @@ -1055,6 +1098,21 @@ WGSLShaderResources::~WGSLShaderResources()
{
for (Uint32 n = 0; n < GetTotalResources(); ++n)
GetResource(n).~WGSLShaderResourceAttribs();

for (Uint32 n = 0; n < GetNumSpecConstants(); ++n)
GetSpecConstant(n).~WGSLSpecializationConstantAttribs();
}

const WGSLSpecializationConstantAttribs& WGSLShaderResources::GetSpecConstant(Uint32 n) const noexcept
{
VERIFY(n < m_NumSpecConstants, "Specialization constant index (", n, ") is out of range. Total spec constant count: ", m_NumSpecConstants);
const WGSLShaderResourceAttribs* ResourceMemoryEnd = reinterpret_cast<const WGSLShaderResourceAttribs*>(m_MemoryBuffer.get()) + GetTotalResources();
return reinterpret_cast<const WGSLSpecializationConstantAttribs*>(ResourceMemoryEnd)[n];
}

WGSLSpecializationConstantAttribs& WGSLShaderResources::GetSpecConstant(Uint32 n) noexcept
{
return const_cast<WGSLSpecializationConstantAttribs&>(const_cast<const WGSLShaderResources*>(this)->GetSpecConstant(n));
}

std::string WGSLShaderResources::DumpResources()
Expand Down Expand Up @@ -1173,6 +1231,20 @@ std::string WGSLShaderResources::DumpResources()
);
VERIFY_EXPR(ResNum == GetTotalResources());

if (GetNumSpecConstants() > 0)
{
ss << std::endl
<< "Spec Constants (" << GetNumSpecConstants() << "):";
for (Uint32 n = 0; n < GetNumSpecConstants(); ++n)
{
const auto& SC = GetSpecConstant(n);
ss << std::endl
<< std::setw(3) << n << " Spec Constant "
<< std::setw(32) << ('\'' + std::string{SC.Name} + '\'')
<< " id=" << SC.OverrideId;
}
}

return ss.str();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2024 Diligent Graphics LLC
* Copyright 2024-2026 Diligent Graphics LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -258,4 +258,69 @@ TEST(WGSLShaderResources, RWStructBufferArrays)
});
}

TEST(WGSLShaderResources, SpecializationConstants)
{
// WGSL source with override (specialization) constants of various types.
static constexpr char WGSL[] = R"(
override sc_float: f32 = 1.0;
override sc_int: i32 = 0;
override sc_uint: u32 = 0;
override sc_bool: bool = false;

@group(0) @binding(0) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(1)
fn main() {
output[0] = sc_float;
output[1] = f32(sc_int);
output[2] = f32(sc_uint);
if sc_bool {
output[3] = 1.0;
}
}
)";

WGSLShaderResources Resources{
GetRawAllocator(),
WGSL,
SHADER_SOURCE_LANGUAGE_WGSL,
"SpecConst test",
nullptr, // CombinedSamplerSuffix
"main", // EntryPoint
nullptr, // ArrayIndexSuffix
false, // LoadUniformBufferReflection
nullptr // ppTintOutput
};
LOG_INFO_MESSAGE("WGSL Resources:\n", Resources.DumpResources());

// One storage buffer resource
EXPECT_EQ(Resources.GetTotalResources(), 1u);
EXPECT_EQ(Resources.GetNumSBs(), 1u);

const std::vector<WGSLSpecializationConstantAttribs> RefSpecConstants = {
{"sc_float", 0, SHADER_CODE_BASIC_TYPE_FLOAT},
{"sc_int", 0, SHADER_CODE_BASIC_TYPE_INT},
{"sc_uint", 0, SHADER_CODE_BASIC_TYPE_UINT},
{"sc_bool", 0, SHADER_CODE_BASIC_TYPE_BOOL},
};

const WGSLShaderResources& ConstResources = Resources;
EXPECT_EQ(ConstResources.GetNumSpecConstants(), static_cast<Uint32>(RefSpecConstants.size()));

// Build a map from name to reference for order-independent matching
std::unordered_map<std::string, const WGSLSpecializationConstantAttribs*> RefMap;
for (const WGSLSpecializationConstantAttribs& Ref : RefSpecConstants)
RefMap[Ref.Name] = &Ref;

for (Uint32 i = 0; i < ConstResources.GetNumSpecConstants(); ++i)
{
const WGSLSpecializationConstantAttribs& SC = ConstResources.GetSpecConstant(i);
const auto it = RefMap.find(SC.Name);
ASSERT_NE(it, RefMap.end()) << "Specialization constant '" << SC.Name << "' is not found in the reference list";

const WGSLSpecializationConstantAttribs* pRef = it->second;
EXPECT_EQ(SC.GetType(), pRef->GetType()) << SC.Name;
}
}

} // namespace
Loading