From 9259e1868433335cb13aba5cf9fe4fa1028d2bcd Mon Sep 17 00:00:00 2001 From: Vimarsh Sathia Date: Wed, 8 Apr 2026 15:22:10 -0500 Subject: [PATCH] Add const shadows for forward mode AD on `RegionBranchOpInterface` When an operand has const activity, we may still need to create a shadow for it (particularly in scf.for, where a const iter_arg may still have to be shadowed if the result and the terminator are active) --- .../CoreDialectsAutoDiffImplementations.cpp | 122 +++++++++++++++--- enzyme/test/MLIR/ForwardMode/affine.mlir | 6 +- enzyme/test/MLIR/ForwardMode/for3.mlir | 51 ++++++++ enzyme/test/MLIR/ForwardMode/parallel.mlir | 4 +- .../MLIR/ForwardMode/parallel_reduce.mlir | 10 +- enzyme/test/MLIR/ForwardMode/while.mlir | 12 +- 6 files changed, 169 insertions(+), 36 deletions(-) create mode 100644 enzyme/test/MLIR/ForwardMode/for3.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index aceebb951173..abb3dc4ddbab 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -318,15 +318,14 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( // add the shadow as operand. auto regionBranchOp = dyn_cast(op); if (!regionBranchOp) { - op->emitError() << " RegionBranchOpInterface not implemented for " << *op - << "\n"; - return failure(); + return op->emitError() << " RegionBranchOpInterface not implemented for " + << *op << "\n"; } auto iface = dyn_cast(op); if (!iface) { - op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for " - << *op << "\n"; - return failure(); + return op->emitError() + << " ControlFlowAutoDiffOpInterface not implemented for " << *op + << "\n"; } // TODO: we may need to record, for every successor, which of its inputs @@ -352,8 +351,37 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( // operands. for (auto &&[i, regionValue, operand] : llvm::enumerate(targetValues, operandRange)) { - if (gutils->isConstantValue(regionValue)) - continue; + + // if all the possible predecessors for this value are also const, then + // we can skip creating a shadow. Else we need to create a shadow for + // activity correctness + if (gutils->isConstantValue(regionValue)) { + SmallVector possibleActivePreds; + SmallVector predecessors; + regionBranchOp.getPredecessors(successor, predecessors); + for (RegionBranchPoint predecessor : predecessors) { + if (predecessor.isParent()) { + // if the predecessor is the parent itself, then it's just + // `operand` + possibleActivePreds.push_back(operand); + continue; + } + auto terminator = predecessor.getTerminatorPredecessorOrNull(); + auto predecessorOperands = terminator.getSuccessorOperands(successor); + if (i < predecessorOperands.size()) + possibleActivePreds.push_back(predecessorOperands[i]); + } + + bool skipOpShadow = true; + for (auto pv : possibleActivePreds) { + if (!skipOpShadow) + break; + skipOpShadow = skipOpShadow && gutils->isConstantValue(pv); + }; + if (skipOpShadow) + continue; + // if there's any possible active predecessor, we create a shadow for it + } operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); if (successor.isParent()) resultPositionsToShadow.insert(i); @@ -388,19 +416,80 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( continue; auto typeIface = dyn_cast(result.getType()); if (!typeIface) { - op->emitError() << " AutoDiffTypeInterface not implemented for " - << result.getType() << "\n"; - return failure(); + return op->emitError() << " AutoDiffTypeInterface not implemented for " + << result.getType() << "\n"; } newOpResultTypes.push_back(typeIface.getShadowType(gutils->width)); } SmallVector newOperands; newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size()); + + auto iface = dyn_cast(op); + if (!iface) { + return op->emitError() + << " ControlFlowAutoDiffOpInterface not implemented for " << *op + << "\n"; + } + + auto regionBranchOp = cast(op); + SmallVector entrySuccessors; + regionBranchOp.getEntrySuccessorRegions( + SmallVector(op->getNumOperands(), Attribute()), + entrySuccessors); for (OpOperand &operand : op->getOpOperands()) { newOperands.push_back(gutils->getNewFromOriginal(operand.get())); - if (operandPositionsToShadow.contains(operand.getOperandNumber())) - newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + if (operandPositionsToShadow.contains(operand.getOperandNumber())) { + Value shadowValue = nullptr; + if (!gutils->isConstantValue(operand.get())) + shadowValue = gutils->invertPointerM(operand.get(), builder); + else { + auto Ty = operand.get().getType(); + auto shadowType = + cast(Ty).getShadowType(gutils->width); + shadowValue = cast(shadowType) + .createNullValue(builder, operand.get().getLoc()); + + // modify block arguments for entry successors to newOp, since + // forceAugmentedReturns will not shadow const operands. No need to add + // to the invertPointers map since `operand` is const (the shadow will + // be unused) + for (const RegionSuccessor &successor : entrySuccessors) { + if (successor.isParent()) + continue; + auto &newOpRegion = + newOp->getRegion(successor.getSuccessor()->getRegionNumber()); + OperandRange succOperands = + iface.getSuccessorOperands(regionBranchOp, successor); + ValueRange succInputs = regionBranchOp.getSuccessorInputs(successor); + + if (succOperands.empty()) + continue; + + auto succInputPos = + operand.getOperandNumber() - succOperands.getBeginOperandIndex(); + + if (succInputPos >= 0 && succInputPos < succInputs.size()) { + auto oldRegionInput = + dyn_cast(succInputs[succInputPos]); + if (!oldRegionInput) + continue; + if (gutils->invertedPointers.contains(oldRegionInput)) + continue; + auto newOpBlockVal = + cast(gutils->getNewFromOriginal(oldRegionInput)); + auto i = newOpBlockVal.getArgNumber(); + if (i == newOpRegion.getNumArguments() - 1) { + newOpRegion.addArgument(shadowType, newOpBlockVal.getLoc()); + } else { + newOpRegion.insertArgument(newOpRegion.args_begin() + i + 1, + shadowType, newOpBlockVal.getLoc()); + } + } + } + } + newOperands.push_back(shadowValue); + } } // We are assuming the op can forward additional operands, listed // immediately after the original operands, to the same regions. @@ -408,13 +497,6 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( // Our interface guarantees this. // We also assume that the region-holding op returns all of the values // yielded by terminators, and only those values. - - auto iface = dyn_cast(op); - if (!iface) { - op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for " - << *op << "\n"; - return failure(); - } Operation *replacement = iface.createWithShadows( builder, gutils, op, newOperands, newOpResultTypes); assert(replacement->getNumResults() == newOpResultTypes.size()); diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir index 15454cf1e03b..75c4e4a82b65 100644 --- a/enzyme/test/MLIR/ForwardMode/affine.mlir +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -15,9 +15,9 @@ module { } // CHECK: @fwddiffeloop // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) - // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 - // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 - // CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, f64) { + // CHECK-DAG: %[[TEN:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[TEN]], %[[arg4:.+]] = %[[ZERO]]) -> (f64, f64) { // CHECK: %[[v1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 // CHECK: %[[v2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 // CHECK: affine.yield %[[v2]], %[[v1]] : f64, f64 diff --git a/enzyme/test/MLIR/ForwardMode/for3.mlir b/enzyme/test/MLIR/ForwardMode/for3.mlir new file mode 100644 index 000000000000..a958a00c5680 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/for3.mlir @@ -0,0 +1,51 @@ +// RUN: %eopt --enzyme --split-input-file %s | FileCheck %s + +module { + func.func @carry_mismatch_scf(%x : f64) -> f64 { + %zero = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) { + scf.yield %x : f64 + } + return %r : f64 + } + + func.func @dcarry_mismatch_scf(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @carry_mismatch_scf(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_scf( +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ACC:.+]] = %{{.+}}, %[[DACC:.+]] = %{{.+}}) -> (f64, f64) { +// CHECK-NEXT: scf.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[LOOP]]#1 : f64 + +// ----- + +module { + func.func @carry_mismatch_affine(%x : f64) -> f64 { + %zero = arith.constant 0.0 : f64 + %r = affine.for %i = 0 to 10 iter_args(%acc = %zero) -> (f64) { + affine.yield %x : f64 + } + return %r : f64 + } + + func.func @dcarry_mismatch_affine(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @carry_mismatch_affine(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_affine( +// CHECK: %[[ALOOP:.+]]:2 = affine.for %[[AIV:.+]] = 0 to 10 iter_args(%[[AACC:.+]] = %{{.+}}, %[[ADACC:.+]] = %{{.+}}) -> (f64, f64) { +// CHECK-NEXT: affine.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[ALOOP]]#1 : f64 diff --git a/enzyme/test/MLIR/ForwardMode/parallel.mlir b/enzyme/test/MLIR/ForwardMode/parallel.mlir index 6e566652b0d3..4b368906a764 100644 --- a/enzyme/test/MLIR/ForwardMode/parallel.mlir +++ b/enzyme/test/MLIR/ForwardMode/parallel.mlir @@ -32,13 +32,13 @@ module { // CHECK: @fwddiffematvec(%[[arg0:.+]]: memref, %[[arg1:.+]]: memref, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref, %[[arg4:.+]]: memref, %[[arg5:.+]]: memref) { // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 // CHECK: %[[c1:.+]] = arith.constant 1 : index // CHECK: %[[c0:.+]] = arith.constant 0 : index // CHECK: %[[dim:.+]] = memref.dim %[[arg0:.+]], %[[c0]] : memref // CHECK: %[[dim_1:.+]] = memref.dim %[[arg0:.+]], %[[c1]] : memref // CHECK: scf.parallel (%[[arg6:.+]]) = (%[[c0]]) to (%dim) step (%[[c1]]) { -// CHECK: %[[x0:.+]]:2 = scf.for %[[arg7:.+]] = %[[c0]] to %dim_1 step %[[c1]] iter_args(%[[arg8:.+]] = %[[cst_0]], %[[arg9:.+]] = %[[cst]]) -> (f64, f64) { +// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[x0:.+]]:2 = scf.for %[[arg7:.+]] = %[[c0]] to %[[dim_1]] step %[[c1]] iter_args(%[[arg8:.+]] = %[[cst]], %[[arg9:.+]] = %[[cst_0]]) -> (f64, f64) { // CHECK: %[[x1:.+]] = memref.load %[[arg1]][%[[arg6]], %[[arg7]]] : memref // CHECK: %[[x2:.+]] = memref.load %[[arg0]][%[[arg6]], %[[arg7]]] : memref // CHECK: %[[x3:.+]] = memref.load %[[arg3]][%[[arg7]]] : memref diff --git a/enzyme/test/MLIR/ForwardMode/parallel_reduce.mlir b/enzyme/test/MLIR/ForwardMode/parallel_reduce.mlir index 77c9cd5b3335..a4c21dcf965c 100644 --- a/enzyme/test/MLIR/ForwardMode/parallel_reduce.mlir +++ b/enzyme/test/MLIR/ForwardMode/parallel_reduce.mlir @@ -27,12 +27,12 @@ module { } // CHECK: @fwddiffenrm2(%[[arg0:.+]]: memref, %[[arg1:.+]]: memref) -> f64 { - // CHECK: %[[c0:.+]] = arith.constant 0 : index - // CHECK: %[[c1:.+]] = arith.constant 1 : index - // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK-DAG: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref // CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 - // CHECK: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref - // CHECK: %[[x0:.+]]:2 = scf.parallel (%[[arg2:.+]]) = (%[[c0]]) to (%dim) step (%[[c1]]) init (%[[cst_0]], %[[cst]]) -> (f64, f64) { + // CHECK: %[[x0:.+]]:2 = scf.parallel (%[[arg2:.+]]) = (%[[c0]]) to (%[[dim]]) step (%[[c1]]) init (%[[cst]], %[[cst_0]]) -> (f64, f64) { // CHECK: %[[x1:.+]] = memref.load %[[arg1]][%[[arg2]]] : memref // CHECK: %[[x2:.+]] = memref.load %[[arg0]][%[[arg2]]] : memref // CHECK: %[[x3:.+]] = arith.mulf %[[x1]], %[[x2]] : f64 diff --git a/enzyme/test/MLIR/ForwardMode/while.mlir b/enzyme/test/MLIR/ForwardMode/while.mlir index 0f5d9fbfdc04..ba2e2e158fc4 100644 --- a/enzyme/test/MLIR/ForwardMode/while.mlir +++ b/enzyme/test/MLIR/ForwardMode/while.mlir @@ -24,12 +24,12 @@ module { } // CHECK: @fwddiffewhile // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { - // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 - // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 - // CHECK: %[[c0:.+]] = arith.constant 0 : index - // CHECK: %[[c1:.+]] = arith.constant 1 : index - // CHECK: %[[c10:.+]] = arith.constant 10 : index - // CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) : (index, f64, f64) -> (index, f64, f64) { + // CHECK-DAG: %[[TEN:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index + // CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[TEN]], %[[arg4:.+]] = %[[ZERO]]) : (index, f64, f64) -> (index, f64, f64) { // CHECK: %[[v1:.+]] = arith.cmpi slt, %[[arg2]], %[[c10]] : index // CHECK: scf.condition(%[[v1]]) %[[arg2]], %[[arg3]], %[[arg4]] : index, f64, f64 // CHECK: } do {