Skip to content

Commit 736e317

Browse files
committed
Add const shadows for forward mode AD on RegionBranchOpInterface
When an operand has const activity, we may still need to create a shadow for it (particularly in scf.for, where a const iter_arg may still have to be shadowed if the loop result and the value yielded by the terminator are active).
1 parent 72e76ee commit 736e317

2 files changed

Lines changed: 149 additions & 20 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp

Lines changed: 98 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,14 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
318318
// add the shadow as operand.
319319
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
320320
if (!regionBranchOp) {
321-
op->emitError() << " RegionBranchOpInterface not implemented for " << *op
322-
<< "\n";
323-
return failure();
321+
return op->emitError() << " RegionBranchOpInterface not implemented for "
322+
<< *op << "\n";
324323
}
325324
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
326325
if (!iface) {
327-
op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
328-
<< *op << "\n";
329-
return failure();
326+
return op->emitError()
327+
<< " ControlFlowAutoDiffOpInterface not implemented for " << *op
328+
<< "\n";
330329
}
331330

332331
// TODO: we may need to record, for every successor, which of its inputs
@@ -352,8 +351,37 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
352351
// operands.
353352
for (auto &&[i, regionValue, operand] :
354353
llvm::enumerate(targetValues, operandRange)) {
355-
if (gutils->isConstantValue(regionValue))
356-
continue;
354+
355+
// if all the possible predecessors for this value are also const, then
356+
// we can skip creating a shadow. Else we need to create a shadow for
357+
// activity correctness
358+
if (gutils->isConstantValue(regionValue)) {
359+
SmallVector<Value> possibleActivePreds;
360+
SmallVector<RegionBranchPoint> predecessors;
361+
regionBranchOp.getPredecessors(successor, predecessors);
362+
for (RegionBranchPoint predecessor : predecessors) {
363+
if (predecessor.isParent()) {
364+
// if the predecessor is the parent itself, then it's just
365+
// `operand`
366+
possibleActivePreds.push_back(operand);
367+
continue;
368+
}
369+
auto terminator = predecessor.getTerminatorPredecessorOrNull();
370+
auto predecessorOperands = terminator.getSuccessorOperands(successor);
371+
if (i < predecessorOperands.size())
372+
possibleActivePreds.push_back(predecessorOperands[i]);
373+
}
374+
375+
bool skipOpShadow = true;
376+
for (auto pv : possibleActivePreds) {
377+
if (!skipOpShadow)
378+
break;
379+
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
380+
};
381+
if (skipOpShadow)
382+
continue;
383+
// if there's any possible active predecessor, we create a shadow for it
384+
}
357385
operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
358386
if (successor.isParent())
359387
resultPositionsToShadow.insert(i);
@@ -372,9 +400,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
372400
Operation *op, OpBuilder &builder, MGradientUtils *gutils,
373401
const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
374402
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow) {
403+
375404
// For all active results, add shadow types.
376405
// For now, assuming all results are relevant.
377406
Operation *newOp = gutils->getNewFromOriginal(op);
407+
bool hasConstOperandShadow = false;
378408
SmallVector<Type> newOpResultTypes;
379409
newOpResultTypes.reserve(op->getNumResults() * 2);
380410
for (auto result : op->getResults()) {
@@ -388,33 +418,81 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
388418
continue;
389419
auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
390420
if (!typeIface) {
391-
op->emitError() << " AutoDiffTypeInterface not implemented for "
392-
<< result.getType() << "\n";
393-
return failure();
421+
return op->emitError() << " AutoDiffTypeInterface not implemented for "
422+
<< result.getType() << "\n";
394423
}
395424
newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
396425
}
397426

398427
SmallVector<Value> newOperands;
399428
newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size());
429+
430+
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
431+
if (!iface) {
432+
return op->emitError()
433+
<< " ControlFlowAutoDiffOpInterface not implemented for " << *op
434+
<< "\n";
435+
}
436+
437+
auto regionBranchOp = cast<RegionBranchOpInterface>(op);
438+
SmallVector<RegionSuccessor> entrySuccessors;
439+
regionBranchOp.getEntrySuccessorRegions(
440+
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
441+
entrySuccessors);
400442
for (OpOperand &operand : op->getOpOperands()) {
401443
newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
402-
if (operandPositionsToShadow.contains(operand.getOperandNumber()))
403-
newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
444+
if (operandPositionsToShadow.contains(operand.getOperandNumber())) {
445+
Value shadowValue = nullptr;
446+
if (!gutils->isConstantValue(operand.get()))
447+
shadowValue = gutils->invertPointerM(operand.get(), builder);
448+
else {
449+
auto Ty = operand.get().getType();
450+
auto shadowType =
451+
cast<AutoDiffTypeInterface>(Ty).getShadowType(gutils->width);
452+
shadowValue = cast<AutoDiffTypeInterface>(shadowType)
453+
.createNullValue(builder, operand.get().getLoc());
454+
hasConstOperandShadow = true;
455+
456+
// modify block arguments for entry successors to newOp, since
457+
// forceAugmentedReturns will not shadow const operands. No need to add
458+
// to the invertPointers map since `operand` is const (the shadow will
459+
// be unused)
460+
for (const RegionSuccessor &successor : entrySuccessors) {
461+
if (successor.isParent())
462+
continue;
463+
auto &newOpRegion =
464+
newOp->getRegion(successor.getSuccessor()->getRegionNumber());
465+
OperandRange succOperands =
466+
iface.getSuccessorOperands(regionBranchOp, successor);
467+
ValueRange succInputs = regionBranchOp.getSuccessorInputs(successor);
468+
auto succInputPos =
469+
operand.getOperandNumber() - succOperands.getBeginOperandIndex();
470+
471+
if (succInputPos >= 0 && succInputPos < succInputs.size()) {
472+
auto newOpBlockVal = dyn_cast<BlockArgument>(
473+
gutils->getNewFromOriginal(succInputs[succInputPos]));
474+
auto i = newOpBlockVal.getArgNumber();
475+
mlir::Value dval = nullptr;
476+
if (i == newOpRegion.getNumArguments() - 1) {
477+
dval =
478+
newOpRegion.addArgument(shadowType, newOpBlockVal.getLoc());
479+
} else {
480+
dval = newOpRegion.insertArgument(
481+
newOpRegion.args_begin() + i + 1, shadowType,
482+
newOpBlockVal.getLoc());
483+
}
484+
}
485+
}
486+
}
487+
newOperands.push_back(shadowValue);
488+
}
404489
}
405490
// We are assuming the op can forward additional operands, listed
406491
// immediately after the original operands, to the same regions.
407492
// ^^
408493
// Our interface guarantees this.
409494
// We also assume that the region-holding op returns all of the values
410495
// yielded by terminators, and only those values.
411-
412-
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
413-
if (!iface) {
414-
op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
415-
<< *op << "\n";
416-
return failure();
417-
}
418496
Operation *replacement = iface.createWithShadows(
419497
builder, gutils, op, newOperands, newOpResultTypes);
420498
assert(replacement->getNumResults() == newOpResultTypes.size());
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %eopt --enzyme --split-input-file %s | FileCheck %s
2+
3+
module {
4+
func.func @carry_mismatch_scf(%x : f64) -> f64 {
5+
%zero = arith.constant 0.0 : f64
6+
%c0 = arith.constant 0 : index
7+
%c1 = arith.constant 1 : index
8+
%c10 = arith.constant 10 : index
9+
%r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) {
10+
scf.yield %x : f64
11+
}
12+
return %r : f64
13+
}
14+
15+
func.func @dcarry_mismatch_scf(%x : f64, %dx : f64) -> f64 {
16+
%r = enzyme.fwddiff @carry_mismatch_scf(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
17+
return %r : f64
18+
}
19+
}
20+
21+
// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_scf(
22+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
23+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
24+
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
25+
// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ACC:.+]] = %{{.+}}, %[[DACC:.+]] = %{{.+}}) -> (f64, f64) {
26+
// CHECK-NEXT: scf.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
27+
// CHECK-NEXT: }
28+
// CHECK-NEXT: return %[[LOOP]]#1 : f64
29+
30+
// -----
31+
32+
module {
33+
func.func @carry_mismatch_affine(%x : f64) -> f64 {
34+
%zero = arith.constant 0.0 : f64
35+
%r = affine.for %i = 0 to 10 iter_args(%acc = %zero) -> (f64) {
36+
affine.yield %x : f64
37+
}
38+
return %r : f64
39+
}
40+
41+
func.func @dcarry_mismatch_affine(%x : f64, %dx : f64) -> f64 {
42+
%r = enzyme.fwddiff @carry_mismatch_affine(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
43+
return %r : f64
44+
}
45+
}
46+
47+
// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_affine(
48+
// CHECK: %[[ALOOP:.+]]:2 = affine.for %[[AIV:.+]] = 0 to 10 iter_args(%[[AACC:.+]] = %{{.+}}, %[[ADACC:.+]] = %{{.+}}) -> (f64, f64) {
49+
// CHECK-NEXT: affine.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
50+
// CHECK-NEXT: }
51+
// CHECK-NEXT: return %[[ALOOP]]#1 : f64

0 commit comments

Comments
 (0)