Skip to content

Commit 214553d

Browse files
author
Jack Elliott
committed
Initial implementation of GroupSharedLimit
1 parent 795cf94 commit 214553d

6 files changed

Lines changed: 70 additions & 11 deletions

File tree

include/dxc/DXIL/DxilModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ class DxilModule {
254254
void SetNumThreads(unsigned x, unsigned y, unsigned z);
255255
unsigned GetNumThreads(unsigned idx) const;
256256

257+
unsigned GetGroupSharedLimit() const;
258+
257259
// Compute shader
258260
DxilWaveSize &GetWaveSize();
259261
const DxilWaveSize &GetWaveSize() const;

include/dxc/DxilContainer/DxilPipelineStateValidation.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 {
175175
uint32_t EntryFunctionName;
176176
};
177177

178+
struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 {
179+
uint32_t GroupSharedMemoryLimit;
180+
};
181+
178182
enum class PSVResourceType {
179183
Invalid = 0,
180184

@@ -474,7 +478,7 @@ class PSVSignatureElement {
474478
const uint32_t *SemanticIndexes) const;
475479
};
476480

477-
#define MAX_PSV_VERSION 3
481+
#define MAX_PSV_VERSION 4
478482

479483
struct PSVInitInfo {
480484
PSVInitInfo(uint32_t psvVersion) : PSVVersion(psvVersion) {}
@@ -491,7 +495,7 @@ struct PSVInitInfo {
491495
uint8_t SigPatchConstOrPrimVectors = 0;
492496
uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0};
493497

494-
static_assert(MAX_PSV_VERSION == 3, "otherwise this needs updating.");
498+
static_assert(MAX_PSV_VERSION == 4, "otherwise this needs updating.");
495499
uint32_t RuntimeInfoSize() const {
496500
switch (PSVVersion) {
497501
case 0:
@@ -500,10 +504,12 @@ struct PSVInitInfo {
500504
return sizeof(PSVRuntimeInfo1);
501505
case 2:
502506
return sizeof(PSVRuntimeInfo2);
507+
case 3:
508+
return sizeof(PSVRuntimeInfo3);
503509
default:
504510
break;
505511
}
506-
return sizeof(PSVRuntimeInfo3);
512+
return sizeof(PSVRuntimeInfo4);
507513
}
508514
uint32_t ResourceBindInfoSize() const {
509515
if (PSVVersion < 2)
@@ -519,6 +525,7 @@ class DxilPipelineStateValidation {
519525
PSVRuntimeInfo1 *m_pPSVRuntimeInfo1 = nullptr;
520526
PSVRuntimeInfo2 *m_pPSVRuntimeInfo2 = nullptr;
521527
PSVRuntimeInfo3 *m_pPSVRuntimeInfo3 = nullptr;
528+
PSVRuntimeInfo4 *m_pPSVRuntimeInfo4 = nullptr;
522529
uint32_t m_uResourceCount = 0;
523530
uint32_t m_uPSVResourceBindInfoSize = 0;
524531
void *m_pPSVResourceBindInfo = nullptr;
@@ -634,6 +641,8 @@ class DxilPipelineStateValidation {
634641

635642
PSVRuntimeInfo3 *GetPSVRuntimeInfo3() const { return m_pPSVRuntimeInfo3; }
636643

644+
PSVRuntimeInfo4 *GetPSVRuntimeInfo4() const { return m_pPSVRuntimeInfo4; }
645+
637646
uint32_t GetBindCount() const { return m_uResourceCount; }
638647

639648
template <typename _T>
@@ -949,6 +958,8 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize,
949958
m_uPSVRuntimeInfoSize); // failure ok
950959
AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0,
951960
m_uPSVRuntimeInfoSize); // failure ok
961+
AssignDerived(&m_pPSVRuntimeInfo4, m_pPSVRuntimeInfo0,
962+
m_uPSVRuntimeInfoSize); // failure ok
952963

953964
// In RWMode::CalcSize, use temp runtime info to hold needed values from
954965
// initInfo
@@ -1137,11 +1148,13 @@ void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM);
11371148
void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM);
11381149
void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM);
11391150
void SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM);
1151+
void SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM);
11401152

11411153
void PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
11421154
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
1143-
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
1144-
const char *EntryName, const char *Comment);
1155+
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
1156+
uint8_t ShaderKind, const char *EntryName,
1157+
const char *Comment);
11451158

11461159
} // namespace hlsl
11471160

lib/DXIL/DxilModule.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,16 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const {
412412
return props.numThreads[idx];
413413
}
414414

415+
unsigned DxilModule::GetGroupSharedLimit() const {
416+
DXASSERT(m_DxilEntryPropsMap.size() == 1 &&
417+
(m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()),
418+
"only works for CS/MS/AS profiles");
419+
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
420+
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
421+
return props.groupSharedLimitBytes;
422+
}
423+
424+
415425
DxilWaveSize &DxilModule::GetWaveSize() {
416426
return const_cast<DxilWaveSize &>(
417427
static_cast<const DxilModule *>(this)->GetWaveSize());

lib/DxilContainer/DxilContainerAssembler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,8 @@ class DxilPSVWriter : public DxilPartWriter {
798798
PSVRuntimeInfo1 *pInfo1 = m_PSV.GetPSVRuntimeInfo1();
799799
PSVRuntimeInfo2 *pInfo2 = m_PSV.GetPSVRuntimeInfo2();
800800
PSVRuntimeInfo3 *pInfo3 = m_PSV.GetPSVRuntimeInfo3();
801+
PSVRuntimeInfo4 *pInfo4 = m_PSV.GetPSVRuntimeInfo4();
802+
801803
if (pInfo)
802804
hlsl::SetShaderProps(pInfo, m_Module);
803805
if (pInfo1)
@@ -806,6 +808,8 @@ class DxilPSVWriter : public DxilPartWriter {
806808
hlsl::SetShaderProps(pInfo2, m_Module);
807809
if (pInfo3)
808810
pInfo3->EntryFunctionName = EntryFunctionName;
811+
if (pInfo4)
812+
hlsl::SetShaderProps(pInfo4, m_Module);
809813

810814
// Set resource binding information
811815
UINT uResIndex = 0;

lib/DxilContainer/DxilPipelineStateValidation.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,21 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) {
305305
}
306306
}
307307

308+
void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) {
309+
assert(pInfo4);
310+
const ShaderModel* SM = DM.GetShaderModel();
311+
switch (SM->GetKind())
312+
{
313+
case ShaderModel::Kind::Compute:
314+
case ShaderModel::Kind::Mesh:
315+
case ShaderModel::Kind::Amplification:
316+
pInfo4->GroupSharedMemoryLimit = DM.GetGroupSharedLimit();
317+
break;
318+
default:
319+
break;
320+
}
321+
}
322+
308323
void PSVResourceBindInfo0::Print(raw_ostream &OS) const {
309324
OS << "PSVResourceBindInfo:\n";
310325
OS << " Space: " << Space << "\n";
@@ -584,8 +599,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName,
584599

585600
void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
586601
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
587-
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
588-
const char *EntryName, const char *Comment) {
602+
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
603+
uint8_t ShaderKind, const char *EntryName,
604+
const char *Comment) {
589605
if (pInfo1 && pInfo1->ShaderStage != ShaderKind)
590606
ShaderKind = pInfo1->ShaderStage;
591607
OS << Comment << "PSVRuntimeInfo:\n";
@@ -808,13 +824,21 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
808824
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
809825
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
810826
}
827+
if (pInfo4) {
828+
OS << Comment << " GroupSharedMemoryLimit="
829+
<< pInfo4->GroupSharedMemoryLimit << "\n";
830+
}
811831
break;
812832
case PSVShaderKind::Amplification:
813833
OS << Comment << " Amplification Shader\n";
814834
if (pInfo2) {
815835
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
816836
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
817837
}
838+
if (pInfo4) {
839+
OS << Comment << " GroupSharedMemoryLimit="
840+
<< pInfo4->GroupSharedMemoryLimit << "\n";
841+
}
818842
break;
819843
case PSVShaderKind::Mesh:
820844
OS << Comment << " Mesh Shader\n";
@@ -841,6 +865,10 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
841865
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
842866
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
843867
}
868+
if (pInfo4) {
869+
OS << Comment << " GroupSharedMemoryLimit="
870+
<< pInfo4->GroupSharedMemoryLimit << "\n";
871+
}
844872
break;
845873
case PSVShaderKind::Library:
846874
case PSVShaderKind::Invalid:
@@ -887,9 +915,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo(
887915
PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1;
888916
PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2;
889917
PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3;
918+
PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4;
890919

891920
hlsl::PrintPSVRuntimeInfo(
892-
OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind,
921+
OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind,
893922
m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "",
894923
Comment);
895924
}

lib/DxilValidation/DxilContainerValidation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,12 +413,13 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
413413
PSVRuntimeInfo0 *PSV0,
414414
PSVRuntimeInfo1 *PSV1,
415415
PSVRuntimeInfo2 *PSV2) {
416-
PSVRuntimeInfo3 DMPSV;
417-
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3));
416+
PSVRuntimeInfo4 DMPSV;
417+
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4));
418418

419419
hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM);
420420
hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM);
421421
hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM);
422+
hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM);
422423
if (PSV1) {
423424
// Init things not set in InitPSVRuntimeInfo.
424425
DMPSV.ShaderStage = static_cast<uint8_t>(SM->GetKind());
@@ -447,7 +448,7 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
447448
if (Mismatched) {
448449
std::string Str;
449450
raw_string_ostream OS(Str);
450-
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
451+
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
451452
static_cast<uint8_t>(SM->GetKind()),
452453
DM.GetEntryFunctionName().c_str(), "");
453454
OS.flush();

0 commit comments

Comments
 (0)