@@ -358,23 +358,42 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
358358 for (auto &&[i, regionValue, operand] :
359359 llvm::enumerate (targetValues, operandRange)) {
360360
361- // check if all the predecessorValues are const too
362- SmallVector<Value> possibleActivePreds;
363- regionBranchOp.getPredecessorValues (successor, i, possibleActivePreds);
361+ // check if we need to create a shadow for an inactive region value
362+ if (gutils->isConstantValue (regionValue)) {
363+
364+ // if all the possible predecessors for this value are also const, then
365+ // we can skip creating a shadow. Else we need to create a shadow for
366+ // syntactic correctness
367+
368+ SmallVector<Value> possibleActivePreds;
369+ SmallVector<RegionBranchPoint> predecessors;
370+ regionBranchOp.getPredecessors (successor, predecessors);
371+ for (RegionBranchPoint predecessor : predecessors) {
372+ if (predecessor.isParent ()) {
373+ // if the predecessor is the parent itself, then it's just
374+ // operand!
375+ possibleActivePreds.push_back (operand);
376+ continue ;
377+ }
378+ auto terminator = predecessor.getTerminatorPredecessorOrNull ();
379+ auto predecessorOperands = terminator.getSuccessorOperands (successor);
380+ if (i < predecessorOperands.size ())
381+ possibleActivePreds.push_back (predecessorOperands[i]);
382+ }
364383
365- bool skipOpShadow = true ;
366- for (auto pv : possibleActivePreds) {
384+ bool skipOpShadow = true ;
385+ for (auto pv : possibleActivePreds) {
386+ if (!skipOpShadow)
387+ break ;
388+ skipOpShadow = skipOpShadow && gutils->isConstantValue (pv);
389+ };
390+ // if there's any possible active predecessor, we create a shadow for it
367391 if (!skipOpShadow)
368- break ;
369- skipOpShadow = skipOpShadow && gutils->isConstantValue (pv);
370- };
371-
372- if (!skipOpShadow)
373- constOperandPositionsToShadow.insert (
374- operandRange.getBeginOperandIndex () + i);
375-
376- if (skipOpShadow && gutils->isConstantValue (regionValue))
377- continue ;
392+ constOperandPositionsToShadow.insert (
393+ operandRange.getBeginOperandIndex () + i);
394+ if (skipOpShadow)
395+ continue ;
396+ }
378397
379398 operandPositionsToShadow.insert (operandRange.getBeginOperandIndex () + i);
380399 if (successor.isParent ())
@@ -448,14 +467,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
448467 replacementRegion.takeBody (region);
449468 }
450469
451- // Re-fix block args for entry -> blk-successors
452- // if constOperandPositionsToShadow is non-empty, the takeBody(...) from
453- // earlier replaces the body for replacement(which has well formed
454- // successor-args) with newOp's successor regions
455- //
456- // We fix it by modifying the blockarguments for all the entry successors,
457- // adding a newblockarg for every entry in constOperandPositionsToShadow and
458- // updating the invertedPointerMap
470+ // Re-fix block args for all successor regions
471+ // Even though createWithShadows properly creates the differentiated control
472+ // flow op(accounting for any const args which might have shadows),
473+ // takeBody(...) replaces the successor regions entirely, including the block
474+ // arguments. We fix the block arguments here for the entry successor regions.
459475 if (!constOperandPositionToShadow.empty ()) {
460476 auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
461477 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
@@ -466,41 +482,55 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
466482 entrySuccessors);
467483
468484 for (const RegionSuccessor &successor : entrySuccessors) {
485+
469486 if (successor.isParent ())
470487 continue ;
471488
472- OperandRange operandRange =
489+ OperandRange oldRegionOperands =
473490 iface.getSuccessorOperands (regionBranchOp, successor);
474- ValueRange targetValues = regionBranchOp.getSuccessorInputs (successor);
491+ ValueRange oldRegionInputs = regionBranchOp.getSuccessorInputs (successor);
492+
493+ // the new region corresponding to this successor(we want to modify the
494+ // arguments of this region in-place)
495+ auto &newRegion =
496+ replacement->getRegion (successor.getSuccessor ()->getRegionNumber ());
475497
476- for (int i = targetValues .size () - 1 ; i >= 0 ; --i) {
477- unsigned operandPosition = operandRange .getBeginOperandIndex () + i;
498+ for (int i = oldRegionInputs .size () - 1 ; i >= 0 ; --i) {
499+ unsigned operandPosition = oldRegionOperands .getBeginOperandIndex () + i;
478500 if (!constOperandPositionToShadow.contains (operandPosition))
479501 continue ;
480502
481- auto regionValue = dyn_cast<BlockArgument>(targetValues[i]);
482- if (!regionValue || gutils->invertedPointers .contains (regionValue))
503+ auto oldRegionInput = dyn_cast<BlockArgument>(oldRegionInputs[i]);
504+
505+ if (!oldRegionInput ||
506+ gutils->invertedPointers .contains (oldRegionInput))
483507 continue ;
484508
485- auto replacementArg =
486- cast<BlockArgument>(gutils->getNewFromOriginal (regionValue));
487- Block *replacementBlock = replacementArg.getOwner ();
509+ auto newRegionInput =
510+ cast<BlockArgument>(gutils->getNewFromOriginal (oldRegionInput));
511+
512+ auto typeIface =
513+ dyn_cast<AutoDiffTypeInterface>(oldRegionInput.getType ());
514+
515+ if (!typeIface) {
516+ op->emitError () << " AutoDiffTypeInterface not implemented for "
517+ << oldRegionInput.getType () << " \n " ;
518+ return failure ();
519+ }
488520
489- Value shadowArg;
490- if (replacementArg.getArgNumber () ==
491- replacementBlock->getNumArguments () - 1 ) {
492- shadowArg = replacementBlock->addArgument (
493- gutils->getShadowType (regionValue.getType ()),
494- regionValue.getLoc ());
521+ Value newRegionShadow;
522+ if (newRegionInput.getArgNumber () == newRegion.getNumArguments () - 1 ) {
523+ newRegionShadow = newRegion.addArgument (
524+ typeIface.getShadowType (gutils->width ), newRegionInput.getLoc ());
495525 } else {
496- shadowArg = replacementBlock->insertArgument (
497- replacementBlock->args_begin () + replacementArg.getArgNumber () +
498- 1 ,
499- gutils->getShadowType (regionValue.getType ()),
500- regionValue.getLoc ());
526+ // insert at position i+1
527+ newRegionShadow = newRegion.insertArgument (
528+ newRegion.args_begin () + newRegionInput.getArgNumber () + 1 ,
529+ typeIface.getShadowType (gutils->width ), newRegionInput.getLoc ());
501530 }
502531
503- gutils->invertedPointers .map (regionValue, shadowArg);
532+ // update the inverted pointer map
533+ gutils->invertedPointers .map (oldRegionInput, newRegionShadow);
504534 }
505535 }
506536 }
0 commit comments