Skip to content

Commit d31bb21

Browse files
committed
Support batching properly
1 parent 66cb97f commit d31bb21

1 file changed

Lines changed: 74 additions & 44 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp

Lines changed: 74 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -358,23 +358,42 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
358358
for (auto &&[i, regionValue, operand] :
359359
llvm::enumerate(targetValues, operandRange)) {
360360

361-
// check if all the predecessorValues are const too
362-
SmallVector<Value> possibleActivePreds;
363-
regionBranchOp.getPredecessorValues(successor, i, possibleActivePreds);
361+
// check if we need to create a shadow for an inactive region value
362+
if (gutils->isConstantValue(regionValue)) {
363+
364+
// if all the possible predecessors for this value are also const, then
365+
// we can skip creating a shadow. Else we need to create a shadow for
366+
// syntactic correctness
367+
368+
SmallVector<Value> possibleActivePreds;
369+
SmallVector<RegionBranchPoint> predecessors;
370+
regionBranchOp.getPredecessors(successor, predecessors);
371+
for (RegionBranchPoint predecessor : predecessors) {
372+
if (predecessor.isParent()) {
373+
// if the predecessor is the parent itself, then it's just
374+
// operand!
375+
possibleActivePreds.push_back(operand);
376+
continue;
377+
}
378+
auto terminator = predecessor.getTerminatorPredecessorOrNull();
379+
auto predecessorOperands = terminator.getSuccessorOperands(successor);
380+
if (i < predecessorOperands.size())
381+
possibleActivePreds.push_back(predecessorOperands[i]);
382+
}
364383

365-
bool skipOpShadow = true;
366-
for (auto pv : possibleActivePreds) {
384+
bool skipOpShadow = true;
385+
for (auto pv : possibleActivePreds) {
386+
if (!skipOpShadow)
387+
break;
388+
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
389+
};
390+
// if there's any possible active predecessor, we create a shadow for it
367391
if (!skipOpShadow)
368-
break;
369-
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
370-
};
371-
372-
if (!skipOpShadow)
373-
constOperandPositionsToShadow.insert(
374-
operandRange.getBeginOperandIndex() + i);
375-
376-
if (skipOpShadow && gutils->isConstantValue(regionValue))
377-
continue;
392+
constOperandPositionsToShadow.insert(
393+
operandRange.getBeginOperandIndex() + i);
394+
if (skipOpShadow)
395+
continue;
396+
}
378397

379398
operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
380399
if (successor.isParent())
@@ -448,14 +467,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
448467
replacementRegion.takeBody(region);
449468
}
450469

451-
// Re-fix block args for entry -> blk-successors
452-
// if constOperandPositionsToShadow is non-empty, the takeBody(...) from
453-
// earlier replaces the body for replacement(which has well formed
454-
// successor-args) with newOp's successor regions
455-
//
456-
// We fix it by modifying the blockarguments for all the entry successors,
457-
// adding a newblockarg for every entry in constOperandPositionsToShadow and
458-
// updating the invertedPointerMap
470+
// Re-fix block args for all successor regions
471+
// Even though createWithShadows properly creates the differentiated control
472+
// flow op(accounting for any const args which might have shadows),
473+
// takeBody(...) replaces the successor regions entirely, including the block
474+
// arguments. We fix the block arguments here for the entry successor regions.
459475
if (!constOperandPositionToShadow.empty()) {
460476
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
461477
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
@@ -466,41 +482,55 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
466482
entrySuccessors);
467483

468484
for (const RegionSuccessor &successor : entrySuccessors) {
485+
469486
if (successor.isParent())
470487
continue;
471488

472-
OperandRange operandRange =
489+
OperandRange oldRegionOperands =
473490
iface.getSuccessorOperands(regionBranchOp, successor);
474-
ValueRange targetValues = regionBranchOp.getSuccessorInputs(successor);
491+
ValueRange oldRegionInputs = regionBranchOp.getSuccessorInputs(successor);
492+
493+
// the new region corresponding to this successor(we want to modify the
494+
// arguments of this region in-place)
495+
auto &newRegion =
496+
replacement->getRegion(successor.getSuccessor()->getRegionNumber());
475497

476-
for (int i = targetValues.size() - 1; i >= 0; --i) {
477-
unsigned operandPosition = operandRange.getBeginOperandIndex() + i;
498+
for (int i = oldRegionInputs.size() - 1; i >= 0; --i) {
499+
unsigned operandPosition = oldRegionOperands.getBeginOperandIndex() + i;
478500
if (!constOperandPositionToShadow.contains(operandPosition))
479501
continue;
480502

481-
auto regionValue = dyn_cast<BlockArgument>(targetValues[i]);
482-
if (!regionValue || gutils->invertedPointers.contains(regionValue))
503+
auto oldRegionInput = dyn_cast<BlockArgument>(oldRegionInputs[i]);
504+
505+
if (!oldRegionInput ||
506+
gutils->invertedPointers.contains(oldRegionInput))
483507
continue;
484508

485-
auto replacementArg =
486-
cast<BlockArgument>(gutils->getNewFromOriginal(regionValue));
487-
Block *replacementBlock = replacementArg.getOwner();
509+
auto newRegionInput =
510+
cast<BlockArgument>(gutils->getNewFromOriginal(oldRegionInput));
511+
512+
auto typeIface =
513+
dyn_cast<AutoDiffTypeInterface>(oldRegionInput.getType());
514+
515+
if (!typeIface) {
516+
op->emitError() << " AutoDiffTypeInterface not implemented for "
517+
<< oldRegionInput.getType() << "\n";
518+
return failure();
519+
}
488520

489-
Value shadowArg;
490-
if (replacementArg.getArgNumber() ==
491-
replacementBlock->getNumArguments() - 1) {
492-
shadowArg = replacementBlock->addArgument(
493-
gutils->getShadowType(regionValue.getType()),
494-
regionValue.getLoc());
521+
Value newRegionShadow;
522+
if (newRegionInput.getArgNumber() == newRegion.getNumArguments() - 1) {
523+
newRegionShadow = newRegion.addArgument(
524+
typeIface.getShadowType(gutils->width), newRegionInput.getLoc());
495525
} else {
496-
shadowArg = replacementBlock->insertArgument(
497-
replacementBlock->args_begin() + replacementArg.getArgNumber() +
498-
1,
499-
gutils->getShadowType(regionValue.getType()),
500-
regionValue.getLoc());
526+
// insert at position i+1
527+
newRegionShadow = newRegion.insertArgument(
528+
newRegion.args_begin() + newRegionInput.getArgNumber() + 1,
529+
typeIface.getShadowType(gutils->width), newRegionInput.getLoc());
501530
}
502531

503-
gutils->invertedPointers.map(regionValue, shadowArg);
532+
// update the inverted pointer map
533+
gutils->invertedPointers.map(oldRegionInput, newRegionShadow);
504534
}
505535
}
506536
}

0 commit comments

Comments
 (0)