Skip to content

Commit 66cb97f

Browse files
committed
Fix forward mode AD for for-like ops
1 parent 72e76ee commit 66cb97f

3 files changed

Lines changed: 138 additions & 4 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
334334
llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
335335
llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
336336

337+
// While these operands are inactive in the op region(s), we may still need to
338+
// create placeholder shadows for them to ensure syntactic correctness of the
339+
// IR
340+
llvm::SmallDenseSet<unsigned> constOperandPositionsToShadow;
341+
337342
SmallVector<RegionSuccessor> entrySuccessors;
338343
regionBranchOp.getEntrySuccessorRegions(
339344
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
@@ -352,8 +357,25 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
352357
// operands.
353358
for (auto &&[i, regionValue, operand] :
354359
llvm::enumerate(targetValues, operandRange)) {
355-
if (gutils->isConstantValue(regionValue))
360+
361+
// Check whether all the predecessor values are constant too.
362+
SmallVector<Value> possibleActivePreds;
363+
regionBranchOp.getPredecessorValues(successor, i, possibleActivePreds);
364+
365+
bool skipOpShadow = true;
366+
for (auto pv : possibleActivePreds) {
367+
if (!skipOpShadow)
368+
break;
369+
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
370+
};
371+
372+
if (!skipOpShadow)
373+
constOperandPositionsToShadow.insert(
374+
operandRange.getBeginOperandIndex() + i);
375+
376+
if (skipOpShadow && gutils->isConstantValue(regionValue))
356377
continue;
378+
357379
operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
358380
if (successor.isParent())
359381
resultPositionsToShadow.insert(i);
@@ -365,13 +387,16 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
365387
resultPositionsToShadow.insert(res.getResultNumber());
366388

367389
return controlFlowForwardHandler(
368-
op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
390+
op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow,
391+
constOperandPositionsToShadow);
369392
}
370393

371394
LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
372395
Operation *op, OpBuilder &builder, MGradientUtils *gutils,
373396
const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
374-
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow) {
397+
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow,
398+
const llvm::SmallDenseSet<unsigned> &constOperandPositionToShadow) {
399+
375400
// For all active results, add shadow types.
376401
// For now, assuming all results are relevant.
377402
Operation *newOp = gutils->getNewFromOriginal(op);
@@ -423,6 +448,63 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
423448
replacementRegion.takeBody(region);
424449
}
425450

451+
// Re-fix block arguments for the entry -> block successors.
452+
// If constOperandPositionToShadow is non-empty, the takeBody(...) call from
453+
// earlier replaces the body of the replacement (which has well-formed
454+
// successor-args) with newOp's successor regions
455+
//
456+
// We fix it by modifying the block arguments for all the entry successors,
457+
// adding a new block argument for every entry in constOperandPositionToShadow and
458+
// updating the invertedPointerMap
459+
if (!constOperandPositionToShadow.empty()) {
460+
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
461+
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
462+
463+
SmallVector<RegionSuccessor> entrySuccessors;
464+
regionBranchOp.getEntrySuccessorRegions(
465+
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
466+
entrySuccessors);
467+
468+
for (const RegionSuccessor &successor : entrySuccessors) {
469+
if (successor.isParent())
470+
continue;
471+
472+
OperandRange operandRange =
473+
iface.getSuccessorOperands(regionBranchOp, successor);
474+
ValueRange targetValues = regionBranchOp.getSuccessorInputs(successor);
475+
476+
for (int i = targetValues.size() - 1; i >= 0; --i) {
477+
unsigned operandPosition = operandRange.getBeginOperandIndex() + i;
478+
if (!constOperandPositionToShadow.contains(operandPosition))
479+
continue;
480+
481+
auto regionValue = dyn_cast<BlockArgument>(targetValues[i]);
482+
if (!regionValue || gutils->invertedPointers.contains(regionValue))
483+
continue;
484+
485+
auto replacementArg =
486+
cast<BlockArgument>(gutils->getNewFromOriginal(regionValue));
487+
Block *replacementBlock = replacementArg.getOwner();
488+
489+
Value shadowArg;
490+
if (replacementArg.getArgNumber() ==
491+
replacementBlock->getNumArguments() - 1) {
492+
shadowArg = replacementBlock->addArgument(
493+
gutils->getShadowType(regionValue.getType()),
494+
regionValue.getLoc());
495+
} else {
496+
shadowArg = replacementBlock->insertArgument(
497+
replacementBlock->args_begin() + replacementArg.getArgNumber() +
498+
1,
499+
gutils->getShadowType(regionValue.getType()),
500+
regionValue.getLoc());
501+
}
502+
503+
gutils->invertedPointers.map(regionValue, shadowArg);
504+
}
505+
}
506+
}
507+
426508
// Inject the mapping for the new results into GradientUtil's shadow
427509
// table.
428510
SmallVector<Value> reps;

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder,
4040
LogicalResult controlFlowForwardHandler(
4141
Operation *op, OpBuilder &builder, MGradientUtils *gutils,
4242
const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
43-
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow);
43+
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow,
44+
const llvm::SmallDenseSet<unsigned> &constOperandPositionToShadow);
4445

4546
// Implements forward-mode differentiation of branching operations.
4647
// Assumes that successive shadows are legal
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %eopt --enzyme %s | FileCheck %s
2+
3+
module {
4+
func.func @carry_mismatch_scf(%x : f64) -> f64 {
5+
%zero = arith.constant 0.0 : f64
6+
%c0 = arith.constant 0 : index
7+
%c1 = arith.constant 1 : index
8+
%c10 = arith.constant 10 : index
9+
%r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) {
10+
scf.yield %x : f64
11+
}
12+
return %r : f64
13+
}
14+
15+
func.func @dcarry_mismatch_scf(%x : f64, %dx : f64) -> f64 {
16+
%r = enzyme.fwddiff @carry_mismatch_scf(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
17+
return %r : f64
18+
}
19+
20+
func.func @carry_mismatch_affine(%x : f64) -> f64 {
21+
%zero = arith.constant 0.0 : f64
22+
%r = affine.for %i = 0 to 10 iter_args(%acc = %zero) -> (f64) {
23+
affine.yield %x : f64
24+
}
25+
return %r : f64
26+
}
27+
28+
func.func @dcarry_mismatch_affine(%x : f64, %dx : f64) -> f64 {
29+
%r = enzyme.fwddiff @carry_mismatch_affine(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
30+
return %r : f64
31+
}
32+
}
33+
34+
// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_scf(
35+
// CHECK-DAG: %[[ZERO0:.+]] = arith.constant 0.000000e+00 : f64
36+
// CHECK-DAG: %[[ZERO1:.+]] = arith.constant 0.000000e+00 : f64
37+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
38+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
39+
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
40+
// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[DACC:.+]] = %[[ZERO1]], %[[ACC:.+]] = %[[ZERO0]]) -> (f64, f64) {
41+
// CHECK-NEXT: scf.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
42+
// CHECK-NEXT: }
43+
// CHECK-NEXT: return %[[LOOP]]#1 : f64
44+
45+
// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_affine(
46+
// CHECK-DAG: %[[AZERO0:.+]] = arith.constant 0.000000e+00 : f64
47+
// CHECK-DAG: %[[AZERO1:.+]] = arith.constant 0.000000e+00 : f64
48+
// CHECK: %[[ALOOP:.+]]:2 = affine.for %[[AIV:.+]] = 0 to 10 iter_args(%[[ADACC:.+]] = %[[AZERO1]], %[[AACC:.+]] = %[[AZERO0]]) -> (f64, f64) {
49+
// CHECK-NEXT: affine.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
50+
// CHECK-NEXT: }
51+
// CHECK-NEXT: return %[[ALOOP]]#1 : f64

0 commit comments

Comments
 (0)