Skip to content

Commit 27547de

Browse files
committed
Addresses PR comments.
1 parent 8e898bf commit 27547de

3 files changed

Lines changed: 54 additions & 13 deletions

File tree

tools/clang/lib/AST/ExprConstant.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9450,18 +9450,15 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
94509450

94519451
bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx,
94529452
APValue *Result) const {
9453-
if (auto *D = dyn_cast<DeclRefExpr>(this)) {
9454-
if (auto *V = dyn_cast<VarDecl>(D->getDecl())) {
9455-
if (V->hasAttr<VKConstantIdAttr>()) {
9456-
if (const Expr *I = V->getAnyInitializer()) {
9457-
if (!I->isCXX11ConstantExpr(Ctx, Result))
9458-
return false;
9459-
}
9460-
return true;
9461-
}
9462-
}
9463-
}
9464-
return false;
9453+
auto *D = dyn_cast<DeclRefExpr>(this);
9454+
if (!D)
9455+
return false;
9456+
auto *V = dyn_cast<VarDecl>(D->getDecl());
9457+
if (!V || !V->hasAttr<VKConstantIdAttr>())
9458+
return false;
9459+
if (const Expr *I = V->getAnyInitializer())
9460+
return I->isCXX11ConstantExpr(Ctx, Result);
9461+
return true;
94659462
}
94669463

94679464
bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15840,8 +15840,10 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() {
1584015840
assert(spvContext.isCS());
1584115841

1584215842
SpirvExecutionMode *numThreadsEm =
15843-
cast<SpirvExecutionMode>(spvBuilder.getModule()->findExecutionMode(
15843+
dyn_cast<SpirvExecutionMode>(spvBuilder.getModule()->findExecutionMode(
1584415844
entryFunction, spv::ExecutionMode::LocalSize));
15845+
if (!numThreadsEm)
15846+
return addDerivativeGroupExecutionModeId();
1584515847
auto numThreads = numThreadsEm->getParams();
1584615848

1584715849
// The layout of the quad is determined by the numer of threads in each
@@ -15868,6 +15870,47 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() {
1586815870
spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
1586915871
}
1587015872

15873+
void SpirvEmitter::addDerivativeGroupExecutionModeId() {
15874+
assert(spvContext.isCS());
15875+
15876+
SpirvExecutionModeId *numThreadsEm =
15877+
dyn_cast<SpirvExecutionModeId>(spvBuilder.getModule()->findExecutionMode(
15878+
entryFunction, spv::ExecutionMode::LocalSizeId));
15879+
auto numThreads = numThreadsEm->getParams();
15880+
auto f = [this](SpirvInstruction *arg) -> llvm::Optional<unsigned> {
15881+
if (auto con = dyn_cast<SpirvConstantInteger>(arg)) {
15882+
return (unsigned)con->getValue().getZExtValue();
15883+
}
15884+
return llvm::None;
15885+
};
15886+
15887+
// The layout of the quad is determined by the numer of threads in each
15888+
// dimention. From the HLSL spec
15889+
// (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html):
15890+
//
15891+
// Where numthreads has an X value divisible by 4 and Y and Z are both 1, the
15892+
// quad layouts are determined according to 1D quad rules. Where numthreads X
15893+
// and Y values are divisible by 2, the quad layouts are determined according
15894+
// to 2D quad rules. Using derivative operations in any numthreads
15895+
// configuration not matching either of these is invalid and will produce an
15896+
// error.
15897+
static_assert(spv::ExecutionMode::DerivativeGroupQuadsNV ==
15898+
spv::ExecutionMode::DerivativeGroupQuadsKHR);
15899+
static_assert(spv::ExecutionMode::DerivativeGroupLinearNV ==
15900+
spv::ExecutionMode::DerivativeGroupLinearKHR);
15901+
spv::ExecutionMode em = spv::ExecutionMode::DerivativeGroupQuadsNV;
15902+
auto x = f(numThreads[0]), y = f(numThreads[1]), z = f(numThreads[2]);
15903+
if (x.hasValue() && x.getValue() % 4 == 0 && y.hasValue() &&
15904+
y.getValue() == 1 && z.hasValue() && z.getValue() == 1) {
15905+
em = spv::ExecutionMode::DerivativeGroupLinearNV;
15906+
} else {
15907+
assert((!x.hasValue() || x.getValue() % 2 == 0) &&
15908+
(!y.hasValue() || y.getValue() % 2 == 0));
15909+
}
15910+
15911+
spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
15912+
}
15913+
1587115914
SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar(
1587215915
const ParmVarDecl *param) {
1587315916
const QualType type = param->getType();

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,7 @@ class SpirvEmitter : public ASTConsumer {
13711371
/// This decision is made according to the rules in
13721372
/// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html.
13731373
void addDerivativeGroupExecutionMode();
1374+
void addDerivativeGroupExecutionModeId();
13741375

13751376
/// Creates an input variable for `param` that will be used by the patch
13761377
/// constant function. The parameter is also added to the patch constant

0 commit comments

Comments
 (0)