@@ -15840,8 +15840,10 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() {
1584015840 assert(spvContext.isCS());
1584115841
1584215842 SpirvExecutionMode *numThreadsEm =
15843- cast <SpirvExecutionMode>(spvBuilder.getModule()->findExecutionMode(
15843+ dyn_cast <SpirvExecutionMode>(spvBuilder.getModule()->findExecutionMode(
1584415844 entryFunction, spv::ExecutionMode::LocalSize));
15845+ if (!numThreadsEm)
15846+ return addDerivativeGroupExecutionModeId();
1584515847 auto numThreads = numThreadsEm->getParams();
1584615848
1584715849 // The layout of the quad is determined by the numer of threads in each
@@ -15868,6 +15870,47 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() {
1586815870 spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
1586915871}
1587015872
15873+ void SpirvEmitter::addDerivativeGroupExecutionModeId() {
15874+ assert(spvContext.isCS());
15875+
15876+ SpirvExecutionModeId *numThreadsEm =
15877+ dyn_cast<SpirvExecutionModeId>(spvBuilder.getModule()->findExecutionMode(
15878+ entryFunction, spv::ExecutionMode::LocalSizeId));
15879+ auto numThreads = numThreadsEm->getParams();
15880+ auto f = [this](SpirvInstruction *arg) -> llvm::Optional<unsigned> {
15881+ if (auto con = dyn_cast<SpirvConstantInteger>(arg)) {
15882+ return (unsigned)con->getValue().getZExtValue();
15883+ }
15884+ return llvm::None;
15885+ };
15886+
15887+ // The layout of the quad is determined by the numer of threads in each
15888+ // dimention. From the HLSL spec
15889+ // (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html):
15890+ //
15891+ // Where numthreads has an X value divisible by 4 and Y and Z are both 1, the
15892+ // quad layouts are determined according to 1D quad rules. Where numthreads X
15893+ // and Y values are divisible by 2, the quad layouts are determined according
15894+ // to 2D quad rules. Using derivative operations in any numthreads
15895+ // configuration not matching either of these is invalid and will produce an
15896+ // error.
15897+ static_assert(spv::ExecutionMode::DerivativeGroupQuadsNV ==
15898+ spv::ExecutionMode::DerivativeGroupQuadsKHR);
15899+ static_assert(spv::ExecutionMode::DerivativeGroupLinearNV ==
15900+ spv::ExecutionMode::DerivativeGroupLinearKHR);
15901+ spv::ExecutionMode em = spv::ExecutionMode::DerivativeGroupQuadsNV;
15902+ auto x = f(numThreads[0]), y = f(numThreads[1]), z = f(numThreads[2]);
15903+ if (x.hasValue() && x.getValue() % 4 == 0 && y.hasValue() &&
15904+ y.getValue() == 1 && z.hasValue() && z.getValue() == 1) {
15905+ em = spv::ExecutionMode::DerivativeGroupLinearNV;
15906+ } else {
15907+ assert((!x.hasValue() || x.getValue() % 2 == 0) &&
15908+ (!y.hasValue() || y.getValue() % 2 == 0));
15909+ }
15910+
15911+ spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
15912+ }
15913+
1587115914SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar(
1587215915 const ParmVarDecl *param) {
1587315916 const QualType type = param->getType();
0 commit comments