Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,14 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
// add the shadow as operand.
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
if (!regionBranchOp) {
op->emitError() << " RegionBranchOpInterface not implemented for " << *op
<< "\n";
return failure();
return op->emitError() << " RegionBranchOpInterface not implemented for "
<< *op << "\n";
}
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
if (!iface) {
op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
<< *op << "\n";
return failure();
return op->emitError()
<< " ControlFlowAutoDiffOpInterface not implemented for " << *op
<< "\n";
}

// TODO: we may need to record, for every successor, which of its inputs
Expand All @@ -352,8 +351,37 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
// operands.
for (auto &&[i, regionValue, operand] :
llvm::enumerate(targetValues, operandRange)) {
if (gutils->isConstantValue(regionValue))
continue;

// if all the possible predecessors for this value are also const, then
// we can skip creating a shadow. Else we need to create a shadow for
// activity correctness
if (gutils->isConstantValue(regionValue)) {
SmallVector<Value> possibleActivePreds;
SmallVector<RegionBranchPoint> predecessors;
regionBranchOp.getPredecessors(successor, predecessors);
for (RegionBranchPoint predecessor : predecessors) {
if (predecessor.isParent()) {
// if the predecessor is the parent itself, then it's just
// `operand`
possibleActivePreds.push_back(operand);
continue;
}
auto terminator = predecessor.getTerminatorPredecessorOrNull();
auto predecessorOperands = terminator.getSuccessorOperands(successor);
if (i < predecessorOperands.size())
possibleActivePreds.push_back(predecessorOperands[i]);
}

bool skipOpShadow = true;
for (auto pv : possibleActivePreds) {
if (!skipOpShadow)
break;
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
};
if (skipOpShadow)
continue;
// if there's any possible active predecessor, we create a shadow for it
}
operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
if (successor.isParent())
resultPositionsToShadow.insert(i);
Expand Down Expand Up @@ -388,33 +416,87 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
continue;
auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
if (!typeIface) {
op->emitError() << " AutoDiffTypeInterface not implemented for "
<< result.getType() << "\n";
return failure();
return op->emitError() << " AutoDiffTypeInterface not implemented for "
<< result.getType() << "\n";
}
newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
}

SmallVector<Value> newOperands;
newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size());

auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
if (!iface) {
return op->emitError()
<< " ControlFlowAutoDiffOpInterface not implemented for " << *op
<< "\n";
}

auto regionBranchOp = cast<RegionBranchOpInterface>(op);
SmallVector<RegionSuccessor> entrySuccessors;
regionBranchOp.getEntrySuccessorRegions(
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
entrySuccessors);
for (OpOperand &operand : op->getOpOperands()) {
newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
if (operandPositionsToShadow.contains(operand.getOperandNumber()))
newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
if (operandPositionsToShadow.contains(operand.getOperandNumber())) {
Value shadowValue = nullptr;
if (!gutils->isConstantValue(operand.get()))
shadowValue = gutils->invertPointerM(operand.get(), builder);
else {
auto Ty = operand.get().getType();
auto shadowType =
cast<AutoDiffTypeInterface>(Ty).getShadowType(gutils->width);
shadowValue = cast<AutoDiffTypeInterface>(shadowType)
.createNullValue(builder, operand.get().getLoc());

// modify block arguments for entry successors to newOp, since
// forceAugmentedReturns will not shadow const operands. No need to add
// to the invertPointers map since `operand` is const (the shadow will
// be unused)
for (const RegionSuccessor &successor : entrySuccessors) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be here, but within the impl of createWithShadows, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it has to be here. We are shadowing only the const operands for any augmented newOp.

Copy link
Copy Markdown
Member Author

@vimarsh6739 vimarsh6739 May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

createWithShadows correctly adds const shadows. But newOp doesn't have them (since it was created with forceAugmentedReturns), so takeBody() discards the args in the replacement op.

We have 2 choices here - either add the const shadows to newOp or the replacement (after takeBody). The end result is the same.

if (successor.isParent())
continue;
auto &newOpRegion =
newOp->getRegion(successor.getSuccessor()->getRegionNumber());
OperandRange succOperands =
iface.getSuccessorOperands(regionBranchOp, successor);
ValueRange succInputs = regionBranchOp.getSuccessorInputs(successor);

if (succOperands.empty())
continue;

auto succInputPos =
operand.getOperandNumber() - succOperands.getBeginOperandIndex();

if (succInputPos >= 0 && succInputPos < succInputs.size()) {
auto oldRegionInput =
dyn_cast<BlockArgument>(succInputs[succInputPos]);
if (!oldRegionInput)
continue;
if (gutils->invertedPointers.contains(oldRegionInput))
continue;
auto newOpBlockVal =
cast<BlockArgument>(gutils->getNewFromOriginal(oldRegionInput));
auto i = newOpBlockVal.getArgNumber();
if (i == newOpRegion.getNumArguments() - 1) {
newOpRegion.addArgument(shadowType, newOpBlockVal.getLoc());
} else {
newOpRegion.insertArgument(newOpRegion.args_begin() + i + 1,
shadowType, newOpBlockVal.getLoc());
}
}
}
}
newOperands.push_back(shadowValue);
}
}
// We are assuming the op can forward additional operands, listed
// immediately after the original operands, to the same regions.
// ^^
// Our interface guarantees this.
// We also assume that the region-holding op returns all of the values
// yielded by terminators, and only those values.

auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
if (!iface) {
op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
<< *op << "\n";
return failure();
}
Operation *replacement = iface.createWithShadows(
builder, gutils, op, newOperands, newOpResultTypes);
assert(replacement->getNumResults() == newOpResultTypes.size());
Expand Down
6 changes: 3 additions & 3 deletions enzyme/test/MLIR/ForwardMode/affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ module {
}
// CHECK: @fwddiffeloop
// CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64)
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
// CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, f64) {
// CHECK-DAG: %[[TEN:.+]] = arith.constant 1.000000e+01 : f64
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[TEN]], %[[arg4:.+]] = %[[ZERO]]) -> (f64, f64) {
// CHECK: %[[v1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64
// CHECK: %[[v2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64
// CHECK: affine.yield %[[v2]], %[[v1]] : f64, f64
Expand Down
51 changes: 51 additions & 0 deletions enzyme/test/MLIR/ForwardMode/for3.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %eopt --enzyme --split-input-file %s | FileCheck %s

// Test input for an scf.for whose loop-carried value %acc is initialized from
// the constant %zero while the loop body yields the function argument %x.
module {
func.func @carry_mismatch_scf(%x : f64) -> f64 {
%zero = arith.constant 0.0 : f64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
// The init operand (%zero) is a constant, but the value yielded on every
// iteration is %x, so the init and the yield disagree on what is carried.
%r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) {
scf.yield %x : f64
}
return %r : f64
}

// Driver: applies enzyme.fwddiff to @carry_mismatch_scf, passing %x together
// with %dx (argument activity enzyme_dup, return activity enzyme_dupnoneed).
func.func @dcarry_mismatch_scf(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @carry_mismatch_scf(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}

// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_scf(
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ACC:.+]] = %{{.+}}, %[[DACC:.+]] = %{{.+}}) -> (f64, f64) {
// CHECK-NEXT: scf.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: return %[[LOOP]]#1 : f64

// -----

// Test input for an affine.for whose loop-carried value %acc is initialized
// from the constant %zero while the loop body yields the function argument %x.
module {
func.func @carry_mismatch_affine(%x : f64) -> f64 {
%zero = arith.constant 0.0 : f64
// The init operand (%zero) is a constant, but the value yielded on every
// iteration is %x, so the init and the yield disagree on what is carried.
%r = affine.for %i = 0 to 10 iter_args(%acc = %zero) -> (f64) {
affine.yield %x : f64
}
return %r : f64
}

// Driver: applies enzyme.fwddiff to @carry_mismatch_affine, passing %x together
// with %dx (argument activity enzyme_dup, return activity enzyme_dupnoneed).
func.func @dcarry_mismatch_affine(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @carry_mismatch_affine(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}

// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_affine(
// CHECK: %[[ALOOP:.+]]:2 = affine.for %[[AIV:.+]] = 0 to 10 iter_args(%[[AACC:.+]] = %{{.+}}, %[[ADACC:.+]] = %{{.+}}) -> (f64, f64) {
// CHECK-NEXT: affine.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: return %[[ALOOP]]#1 : f64
4 changes: 2 additions & 2 deletions enzyme/test/MLIR/ForwardMode/parallel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ module {

// CHECK: @fwddiffematvec(%[[arg0:.+]]: memref<?x?xf64>, %[[arg1:.+]]: memref<?x?xf64>, %[[arg2:.+]]: memref<?xf64>, %[[arg3:.+]]: memref<?xf64>, %[[arg4:.+]]: memref<?xf64>, %[[arg5:.+]]: memref<?xf64>) {
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[dim:.+]] = memref.dim %[[arg0:.+]], %[[c0]] : memref<?x?xf64>
// CHECK: %[[dim_1:.+]] = memref.dim %[[arg0:.+]], %[[c1]] : memref<?x?xf64>
// CHECK: scf.parallel (%[[arg6:.+]]) = (%[[c0]]) to (%dim) step (%[[c1]]) {
// CHECK: %[[x0:.+]]:2 = scf.for %[[arg7:.+]] = %[[c0]] to %dim_1 step %[[c1]] iter_args(%[[arg8:.+]] = %[[cst_0]], %[[arg9:.+]] = %[[cst]]) -> (f64, f64) {
// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[x0:.+]]:2 = scf.for %[[arg7:.+]] = %[[c0]] to %[[dim_1]] step %[[c1]] iter_args(%[[arg8:.+]] = %[[cst]], %[[arg9:.+]] = %[[cst_0]]) -> (f64, f64) {
// CHECK: %[[x1:.+]] = memref.load %[[arg1]][%[[arg6]], %[[arg7]]] : memref<?x?xf64>
// CHECK: %[[x2:.+]] = memref.load %[[arg0]][%[[arg6]], %[[arg7]]] : memref<?x?xf64>
// CHECK: %[[x3:.+]] = memref.load %[[arg3]][%[[arg7]]] : memref<?xf64>
Expand Down
10 changes: 5 additions & 5 deletions enzyme/test/MLIR/ForwardMode/parallel_reduce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ module {
}

// CHECK: @fwddiffenrm2(%[[arg0:.+]]: memref<?xf64>, %[[arg1:.+]]: memref<?xf64>) -> f64 {
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf64>
// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf64>
// CHECK: %[[x0:.+]]:2 = scf.parallel (%[[arg2:.+]]) = (%[[c0]]) to (%dim) step (%[[c1]]) init (%[[cst_0]], %[[cst]]) -> (f64, f64) {
// CHECK: %[[x0:.+]]:2 = scf.parallel (%[[arg2:.+]]) = (%[[c0]]) to (%[[dim]]) step (%[[c1]]) init (%[[cst]], %[[cst_0]]) -> (f64, f64) {
// CHECK: %[[x1:.+]] = memref.load %[[arg1]][%[[arg2]]] : memref<?xf64>
// CHECK: %[[x2:.+]] = memref.load %[[arg0]][%[[arg2]]] : memref<?xf64>
// CHECK: %[[x3:.+]] = arith.mulf %[[x1]], %[[x2]] : f64
Expand Down
12 changes: 6 additions & 6 deletions enzyme/test/MLIR/ForwardMode/while.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ module {
}
// CHECK: @fwddiffewhile
// CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[c10:.+]] = arith.constant 10 : index
// CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) : (index, f64, f64) -> (index, f64, f64) {
// CHECK-DAG: %[[TEN:.+]] = arith.constant 1.000000e+01 : f64
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[TEN]], %[[arg4:.+]] = %[[ZERO]]) : (index, f64, f64) -> (index, f64, f64) {
// CHECK: %[[v1:.+]] = arith.cmpi slt, %[[arg2]], %[[c10]] : index
// CHECK: scf.condition(%[[v1]]) %[[arg2]], %[[arg3]], %[[arg4]] : index, f64, f64
// CHECK: } do {
Expand Down
Loading