Skip to content

Commit 423617d

Browse files
committed
Improves error handling and testing.
1 parent eb2dc82 commit 423617d

4 files changed

Lines changed: 80 additions & 49 deletions

File tree

tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7740,6 +7740,8 @@ def warn_hlsl_entry_attribute_without_shader_attribute : Warning<
77407740
InGroup<HLSLEntryAttributeWithoutShaderAttrType>;
77417741
def err_hlsl_attribute_expects_float_literal : Error<
77427742
"attribute %0 must have a float literal argument">;
7743+
def err_hlsl_attribute_expects_integer_const_expr : Error<
7744+
"attribute %0 argument %1 must be integer constant expression">;
77437745
def warn_hlsl_comma_in_init : Warning<
77447746
"comma expression used where a constructor list may have been intended">,
77457747
InGroup<HLSLCommaInInit>;

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,23 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
328328
};
329329
} // namespace
330330

331-
static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
332-
uint32_t defaultVal = 0) {
331+
static uint32_t
332+
getIntConstAttrArg(ASTContext &astContext, const Expr *expr,
333+
llvm::Optional<uint32_t> defaultVal = llvm::None) {
333334
if (expr) {
334335
llvm::APSInt apsInt;
335336
APValue apValue;
336337
if (expr->isIntegerConstantExpr(apsInt, astContext))
337338
return (uint32_t)apsInt.getSExtValue();
338339
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
339340
return (uint32_t)apValue.getInt().getSExtValue();
341+
llvm_unreachable(
342+
"Expression must be a constant expression or spec constant");
340343
}
341-
return defaultVal;
344+
if (!defaultVal.hasValue()) {
345+
DXASSERT(defaultVal.hasValue(), "missing attribute parameter");
346+
}
347+
return defaultVal.getValue();
342348
}
343349

344350
//------------------------------------------------------------------------------
@@ -1646,9 +1652,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16461652

16471653
// Populate numThreads
16481654
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
1649-
funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1);
1650-
funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1);
1651-
funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1);
1655+
funcProps->numThreads[0] = getIntConstAttrArg(astContext, Attr->getX());
1656+
funcProps->numThreads[1] = getIntConstAttrArg(astContext, Attr->getY());
1657+
funcProps->numThreads[2] = getIntConstAttrArg(astContext, Attr->getZ());
16521658

16531659
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
16541660
unsigned DiagID = Diags.getCustomDiagID(
@@ -1822,7 +1828,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18221828
if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
18231829
funcProps->NodeShaderID.Name = pAttr->getName().str();
18241830
funcProps->NodeShaderID.Index =
1825-
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
1831+
getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18261832
} else {
18271833
funcProps->NodeShaderID.Name = FD->getName().str();
18281834
funcProps->NodeShaderID.Index = 0;
@@ -1834,27 +1840,27 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18341840
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
18351841
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
18361842
funcProps->NodeShaderSharedInput.Index =
1837-
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
1843+
getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18381844
}
18391845
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
18401846
funcProps->Node.DispatchGrid[0] =
1841-
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
1847+
getIntConstAttrArg(astContext, pAttr->getX());
18421848
funcProps->Node.DispatchGrid[1] =
1843-
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
1849+
getIntConstAttrArg(astContext, pAttr->getY());
18441850
funcProps->Node.DispatchGrid[2] =
1845-
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
1851+
getIntConstAttrArg(astContext, pAttr->getZ());
18461852
}
18471853
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
18481854
funcProps->Node.MaxDispatchGrid[0] =
1849-
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
1855+
getIntConstAttrArg(astContext, pAttr->getX());
18501856
funcProps->Node.MaxDispatchGrid[1] =
1851-
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
1857+
getIntConstAttrArg(astContext, pAttr->getY());
18521858
funcProps->Node.MaxDispatchGrid[2] =
1853-
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
1859+
getIntConstAttrArg(astContext, pAttr->getZ());
18541860
}
18551861
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
18561862
funcProps->Node.MaxRecursionDepth =
1857-
GetIntConstAttrArg(astContext, pAttr->getCount(), 0);
1863+
getIntConstAttrArg(astContext, pAttr->getCount());
18581864
}
18591865
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
18601866
// NumThreads wasn't specified.
@@ -2368,9 +2374,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23682374
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23692375

23702376
if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
2371-
node.MaxRecords = GetIntConstAttrArg(
2377+
node.MaxRecords = getIntConstAttrArg(
23722378
astContext,
2373-
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount(), 1);
2379+
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount());
23742380
}
23752381
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
23762382
node.Flags.SetGloballyCoherent();
@@ -2402,7 +2408,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24022408
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
24032409
node.OutputID.Name = Attr->getName().str();
24042410
node.OutputID.Index =
2405-
GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
2411+
getIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
24062412
} else {
24072413
node.OutputID.Name = parmDecl->getName().str();
24082414
node.OutputID.Index = 0;
@@ -2461,7 +2467,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24612467
node.MaxRecordsSharedWith = ix;
24622468
}
24632469
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
2464-
node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0);
2470+
node.MaxRecords = getIntConstAttrArg(astContext, Attr->getMaxCount());
24652471
}
24662472

24672473
if (inputPatchCount > 1) {

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12283,17 +12283,29 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
1228312283
}
1228412284
}
1228512285

12286-
static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
12287-
uint32_t defaultVal = 0) {
12286+
static uint32_t
12287+
getIntConstAttrArg(Sema &S, const Attr *attr, unsigned argNum, const Expr *expr,
12288+
llvm::Optional<uint32_t> defaultVal = llvm::None) {
1228812289
if (expr) {
1228912290
llvm::APSInt apsInt;
1229012291
APValue apValue;
12291-
if (expr->isIntegerConstantExpr(apsInt, astContext))
12292+
if (expr->isIntegerConstantExpr(apsInt, S.getASTContext()))
1229212293
return (uint32_t)apsInt.getSExtValue();
12293-
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
12294+
if (expr->isVulkanSpecConstantExpr(S.getASTContext(), &apValue) &&
12295+
apValue.isInt())
1229412296
return (uint32_t)apValue.getInt().getSExtValue();
12297+
S.Diag(expr->getExprLoc(),
12298+
diag::err_hlsl_attribute_expects_integer_const_expr)
12299+
<< attr->getSpelling() << argNum;
12300+
return 0;
12301+
}
12302+
if (!defaultVal.hasValue()) {
12303+
S.Diag(attr->getLocation(),
12304+
diag::err_hlsl_attribute_expects_integer_const_expr)
12305+
<< attr->getSpelling() << argNum;
12306+
return 0;
1229512307
}
12296-
return defaultVal;
12308+
return defaultVal.getValue();
1229712309
}
1229812310

1229912311
/////////////////////////////////////////////////////////////////////////////
@@ -12303,10 +12315,9 @@ static void DiagnoseNumThreadsForDerivativeOp(
1230312315
Sema &S, const HLSLNumThreadsAttr *Attr, SourceLocation LocDeriv,
1230412316
FunctionDecl *FD, const FunctionDecl *EntryDecl, DiagnosticsEngine &Diags) {
1230512317
bool invalidNumThreads = false;
12306-
ASTContext &astContext = S.getASTContext();
12307-
uint32_t x = GetIntConstAttrArg(astContext, Attr->getX(), 1);
12308-
uint32_t y = GetIntConstAttrArg(astContext, Attr->getY(), 1);
12309-
uint32_t z = GetIntConstAttrArg(astContext, Attr->getZ(), 1);
12318+
uint32_t x = getIntConstAttrArg(S, Attr, 1, Attr->getX());
12319+
uint32_t y = getIntConstAttrArg(S, Attr, 2, Attr->getY());
12320+
uint32_t z = getIntConstAttrArg(S, Attr, 3, Attr->getZ());
1231012321

1231112322
if (y != 1) {
1231212323
// 2D mode requires x and y to be multiple of 2.
@@ -14233,7 +14244,7 @@ HLSLMaxRecordsAttr *ValidateMaxRecordsAttributes(Sema &S, Decl *D,
1423314244
Loc = ExistingMRSWA->getLocation();
1423414245
} else if (ExistingMRA) {
1423514246
uint32_t maxCount =
14236-
GetIntConstAttrArg(S.getASTContext(), ExistingMRA->getMaxCount(), 0);
14247+
getIntConstAttrArg(S, ExistingMRA, 1, ExistingMRA->getMaxCount(), 0);
1423714248
if (LiteralInt->getValue().getLimitedValue() != maxCount)
1423814249
Loc = ExistingMRA->getLocation();
1423914250
}
@@ -14459,17 +14470,16 @@ void Sema::DiagnoseCoherenceMismatch(const Expr *SrcExpr, QualType TargetType,
1445914470
void ValidateDispatchGridValues(Sema &S, const AttributeList &A,
1446014471
Attr *declAttr) {
1446114472
unsigned x = 1, y = 1, z = 1;
14462-
ASTContext &astContext = S.getASTContext();
1446314473
if (HLSLNodeDispatchGridAttr *pA =
1446414474
dyn_cast<HLSLNodeDispatchGridAttr>(declAttr)) {
14465-
x = GetIntConstAttrArg(astContext, pA->getX(), 1);
14466-
y = GetIntConstAttrArg(astContext, pA->getY(), 1);
14467-
z = GetIntConstAttrArg(astContext, pA->getZ(), 1);
14475+
x = getIntConstAttrArg(S, pA, 1, pA->getX());
14476+
y = getIntConstAttrArg(S, pA, 2, pA->getY());
14477+
z = getIntConstAttrArg(S, pA, 3, pA->getZ());
1446814478
} else if (HLSLNodeMaxDispatchGridAttr *pA =
1446914479
dyn_cast<HLSLNodeMaxDispatchGridAttr>(declAttr)) {
14470-
x = GetIntConstAttrArg(astContext, pA->getX(), 1);
14471-
y = GetIntConstAttrArg(astContext, pA->getY(), 1);
14472-
z = GetIntConstAttrArg(astContext, pA->getZ(), 1);
14480+
x = getIntConstAttrArg(S, pA, 1, pA->getX());
14481+
y = getIntConstAttrArg(S, pA, 2, pA->getY());
14482+
z = getIntConstAttrArg(S, pA, 3, pA->getZ());
1447314483
} else {
1447414484
llvm_unreachable("ValidateDispatchGridValues() called for wrong attribute");
1447514485
}
@@ -17227,10 +17237,9 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName,
1722717237
// thread group size is (1,1,1)
1722817238
if (NodeLaunchTy == DXIL::NodeLaunchType::Thread) {
1722917239
if (auto NumThreads = FD->getAttr<HLSLNumThreadsAttr>()) {
17230-
ASTContext &astContext = S.getASTContext();
17231-
uint32_t x = GetIntConstAttrArg(astContext, NumThreads->getX(), 1);
17232-
uint32_t y = GetIntConstAttrArg(astContext, NumThreads->getY(), 1);
17233-
uint32_t z = GetIntConstAttrArg(astContext, NumThreads->getZ(), 1);
17240+
uint32_t x = getIntConstAttrArg(S, NumThreads, 1, NumThreads->getX());
17241+
uint32_t y = getIntConstAttrArg(S, NumThreads, 2, NumThreads->getY());
17242+
uint32_t z = getIntConstAttrArg(S, NumThreads, 3, NumThreads->getZ());
1723417243
if (x != 1 || y != 1 || z != 1) {
1723517244
S.Diags.Report(NumThreads->getLocation(),
1723617245
diag::err_hlsl_wg_thread_launch_group_size)

tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s
1+
// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 -DSPEC=1 %s | FileCheck %s
2+
// RUN: not %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s 2>&1 | FileCheck --check-prefix=NOSPEC %s
23

34
// Note: validation disabled until NodePayloadAMDX pointers are allowed
45
// as function arguments
@@ -11,14 +12,22 @@ struct OutputPayload {
1112
uint foo;
1213
};
1314

14-
[[vk::constant_id(0)]]
15-
const uint MaxPayloads = 1;
16-
[[vk::constant_id(1)]]
17-
const uint WorkgroupSizeX = 1;
18-
[[vk::constant_id(2)]]
19-
const uint ShaderIndex = 0;
20-
[[vk::constant_id(3)]]
21-
const uint NumThreadsX = 512;
15+
#ifdef SPEC
16+
[[vk::constant_id(0)]] const
17+
#endif
18+
uint MaxPayloads = 1;
19+
#ifdef SPEC
20+
[[vk::constant_id(1)]] const
21+
#endif
22+
uint WorkgroupSizeX = 1;
23+
#ifdef SPEC
24+
[[vk::constant_id(2)]] const
25+
#endif
26+
uint ShaderIndex = 0;
27+
#ifdef SPEC
28+
[[vk::constant_id(3)]] const
29+
#endif
30+
uint NumThreadsX = 512;
2231

2332
[Shader("node")]
2433
[NodeLaunch("broadcasting")]
@@ -51,3 +60,8 @@ void main(const uint svGroupIndex : SV_GroupIndex,
5160
// CHECK-DAG: [[SHADERINDEX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 0
5261
// CHECK-DAG: [[NUMTHREADSX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 512
5362

63+
// NOSPEC-DAG: error: 'MaxRecords' attribute requires an integer constant
64+
// NOSPEC-DAG: error: 'NodeID' attribute requires an integer constant
65+
// NOSPEC-DAG: error: 'NodeDispatchGrid' attribute requires an integer constant
66+
// NOSPEC-DAG: error: 'NumThreads' attribute requires an integer constant
67+

0 commit comments

Comments
 (0)