@@ -15874,8 +15874,12 @@ void SpirvEmitter::addDerivativeGroupExecutionModeId() {
1587415874 dyn_cast<SpirvExecutionModeId>(spvBuilder.getModule()->findExecutionMode(
1587515875 entryFunction, spv::ExecutionMode::LocalSizeId));
1587615876 auto numThreads = numThreadsEm->getParams();
15877- auto f = [this](SpirvInstruction *arg) -> llvm::Optional<unsigned> {
15877+ bool numThreadsHasSpecConst = false;
15878+ auto f = [&numThreadsHasSpecConst](
15879+ SpirvInstruction *arg) -> llvm::Optional<unsigned> {
1587815880 if (auto con = dyn_cast<SpirvConstantInteger>(arg)) {
15881+ if (con->isSpecConstant())
15882+ numThreadsHasSpecConst = true;
1587915883 return (unsigned)con->getValue().getZExtValue();
1588015884 }
1588115885 return llvm::None;
@@ -15905,6 +15909,16 @@ void SpirvEmitter::addDerivativeGroupExecutionModeId() {
1590515909 (!y.hasValue() || y.getValue() % 2 == 0));
1590615910 }
1590715911
15912+ if (numThreadsHasSpecConst) {
15913+ // This code probably belongs in DiagnoseNumThreadsForDerivativeOp() in
15914+ // SemaHLSL.cpp, but that function apparently isn't invoked in all
15915+ // applicable situations.
15916+ diags.Report(
15917+ numThreadsEm->getSourceLocation(),
15918+ diags.getCustomDiagID(DiagnosticsEngine::Level::Warning,
15919+ "NumThreads spec constant default value used to "
15920+ "determine derivative group mode"));
15921+ }
1590815922 spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
1590915923}
1591015924
0 commit comments