@@ -334,6 +334,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
334334 llvm::SmallDenseSet<unsigned > operandPositionsToShadow;
335335 llvm::SmallDenseSet<unsigned > resultPositionsToShadow;
336336
337+ // While these operands are inactive in the op region(s), we may still need
338+ // to create placeholder shadows for them to keep the IR syntactically
339+ // correct.
340+ llvm::SmallDenseSet<unsigned > constOperandPositionsToShadow;
341+
337342 SmallVector<RegionSuccessor> entrySuccessors;
338343 regionBranchOp.getEntrySuccessorRegions (
339344 SmallVector<Attribute>(op->getNumOperands (), Attribute ()),
@@ -352,8 +357,25 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
352357 // operands.
353358 for (auto &&[i, regionValue, operand] :
354359 llvm::enumerate (targetValues, operandRange)) {
355- if (gutils->isConstantValue (regionValue))
360+
361+ // check if all the predecessorValues are const too
362+ SmallVector<Value> possibleActivePreds;
363+ regionBranchOp.getPredecessorValues (successor, i, possibleActivePreds);
364+
365+ bool skipOpShadow = true ;
366+ for (auto pv : possibleActivePreds) {
367+ if (!skipOpShadow)
368+ break ;
369+ skipOpShadow = skipOpShadow && gutils->isConstantValue (pv);
370+ };
371+
372+ if (!skipOpShadow)
373+ constOperandPositionsToShadow.insert (
374+ operandRange.getBeginOperandIndex () + i);
375+
376+ if (skipOpShadow && gutils->isConstantValue (regionValue))
356377 continue ;
378+
357379 operandPositionsToShadow.insert (operandRange.getBeginOperandIndex () + i);
358380 if (successor.isParent ())
359381 resultPositionsToShadow.insert (i);
@@ -365,13 +387,16 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
365387 resultPositionsToShadow.insert (res.getResultNumber ());
366388
367389 return controlFlowForwardHandler (
368- op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
390+ op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow,
391+ constOperandPositionsToShadow);
369392}
370393
371394LogicalResult mlir::enzyme::detail::controlFlowForwardHandler (
372395 Operation *op, OpBuilder &builder, MGradientUtils *gutils,
373396 const llvm::SmallDenseSet<unsigned > &operandPositionsToShadow,
374- const llvm::SmallDenseSet<unsigned > &resultPositionsToShadow) {
397+ const llvm::SmallDenseSet<unsigned > &resultPositionsToShadow,
398+ const llvm::SmallDenseSet<unsigned > &constOperandPositionToShadow) {
399+
375400 // For all active results, add shadow types.
376401 // For now, assuming all results are relevant.
377402 Operation *newOp = gutils->getNewFromOriginal (op);
@@ -423,6 +448,63 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
423448 replacementRegion.takeBody (region);
424449 }
425450
451+ // Re-fix the block arguments for the entry -> block successors.
452+ // If constOperandPositionToShadow is non-empty, the takeBody(...) from
453+ // earlier replaces the body of the replacement (which has well-formed
454+ // successor args) with newOp's successor regions.
455+ //
456+ // We fix this by modifying the block arguments of all the entry successors:
457+ // adding a new block argument for every entry in
458+ // constOperandPositionToShadow and updating gutils->invertedPointers.
459+ if (!constOperandPositionToShadow.empty ()) {
460+ auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
461+ auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
462+
463+ SmallVector<RegionSuccessor> entrySuccessors;
464+ regionBranchOp.getEntrySuccessorRegions (
465+ SmallVector<Attribute>(op->getNumOperands (), Attribute ()),
466+ entrySuccessors);
467+
468+ for (const RegionSuccessor &successor : entrySuccessors) {
469+ if (successor.isParent ())
470+ continue ;
471+
472+ OperandRange operandRange =
473+ iface.getSuccessorOperands (regionBranchOp, successor);
474+ ValueRange targetValues = regionBranchOp.getSuccessorInputs (successor);
475+
476+ for (int i = targetValues.size () - 1 ; i >= 0 ; --i) {
477+ unsigned operandPosition = operandRange.getBeginOperandIndex () + i;
478+ if (!constOperandPositionToShadow.contains (operandPosition))
479+ continue ;
480+
481+ auto regionValue = dyn_cast<BlockArgument>(targetValues[i]);
482+ if (!regionValue || gutils->invertedPointers .contains (regionValue))
483+ continue ;
484+
485+ auto replacementArg =
486+ cast<BlockArgument>(gutils->getNewFromOriginal (regionValue));
487+ Block *replacementBlock = replacementArg.getOwner ();
488+
489+ Value shadowArg;
490+ if (replacementArg.getArgNumber () ==
491+ replacementBlock->getNumArguments () - 1 ) {
492+ shadowArg = replacementBlock->addArgument (
493+ gutils->getShadowType (regionValue.getType ()),
494+ regionValue.getLoc ());
495+ } else {
496+ shadowArg = replacementBlock->insertArgument (
497+ replacementBlock->args_begin () + replacementArg.getArgNumber () +
498+ 1 ,
499+ gutils->getShadowType (regionValue.getType ()),
500+ regionValue.getLoc ());
501+ }
502+
503+ gutils->invertedPointers .map (regionValue, shadowArg);
504+ }
505+ }
506+ }
507+
426508 // Inject the mapping for the new results into GradientUtil's shadow
427509 // table.
428510 SmallVector<Value> reps;
0 commit comments