@@ -318,15 +318,14 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
318318 // add the shadow as operand.
319319 auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
320320 if (!regionBranchOp) {
321- op->emitError () << " RegionBranchOpInterface not implemented for " << *op
322- << " \n " ;
323- return failure ();
321+ return op->emitError () << " RegionBranchOpInterface not implemented for "
322+ << *op << " \n " ;
324323 }
325324 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
326325 if (!iface) {
327- op->emitError () << " ControlFlowAutoDiffOpInterface not implemented for "
328- << *op << " \n " ;
329- return failure () ;
326+ return op->emitError ()
327+ << " ControlFlowAutoDiffOpInterface not implemented for " << *op
328+ << " \n " ;
330329 }
331330
332331 // TODO: we may need to record, for every successor, which of its inputs
@@ -352,8 +351,37 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
352351 // operands.
353352 for (auto &&[i, regionValue, operand] :
354353 llvm::enumerate (targetValues, operandRange)) {
355- if (gutils->isConstantValue (regionValue))
356- continue ;
354+
355+ // if all the possible predecessors for this value are also const, then
356+ // we can skip creating a shadow. Else we need to create a shadow for
357+ // activity correctness
358+ if (gutils->isConstantValue (regionValue)) {
359+ SmallVector<Value> possibleActivePreds;
360+ SmallVector<RegionBranchPoint> predecessors;
361+ regionBranchOp.getPredecessors (successor, predecessors);
362+ for (RegionBranchPoint predecessor : predecessors) {
363+ if (predecessor.isParent ()) {
364+ // if the predecessor is the parent itself, then it's just
365+ // `operand`
366+ possibleActivePreds.push_back (operand);
367+ continue ;
368+ }
369+ auto terminator = predecessor.getTerminatorPredecessorOrNull ();
370+ auto predecessorOperands = terminator.getSuccessorOperands (successor);
371+ if (i < predecessorOperands.size ())
372+ possibleActivePreds.push_back (predecessorOperands[i]);
373+ }
374+
375+ bool skipOpShadow = true ;
376+ for (auto pv : possibleActivePreds) {
377+ if (!skipOpShadow)
378+ break ;
379+ skipOpShadow = skipOpShadow && gutils->isConstantValue (pv);
380+ };
381+ if (skipOpShadow)
382+ continue ;
383+ // if there's any possible active predecessor, we create a shadow for it
384+ }
357385 operandPositionsToShadow.insert (operandRange.getBeginOperandIndex () + i);
358386 if (successor.isParent ())
359387 resultPositionsToShadow.insert (i);
@@ -372,9 +400,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
372400 Operation *op, OpBuilder &builder, MGradientUtils *gutils,
373401 const llvm::SmallDenseSet<unsigned > &operandPositionsToShadow,
374402 const llvm::SmallDenseSet<unsigned > &resultPositionsToShadow) {
403+
375404 // For all active results, add shadow types.
376405 // For now, assuming all results are relevant.
377406 Operation *newOp = gutils->getNewFromOriginal (op);
407+ bool hasConstOperandShadow = false ;
378408 SmallVector<Type> newOpResultTypes;
379409 newOpResultTypes.reserve (op->getNumResults () * 2 );
380410 for (auto result : op->getResults ()) {
@@ -388,33 +418,81 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
388418 continue ;
389419 auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType ());
390420 if (!typeIface) {
391- op->emitError () << " AutoDiffTypeInterface not implemented for "
392- << result.getType () << " \n " ;
393- return failure ();
421+ return op->emitError () << " AutoDiffTypeInterface not implemented for "
422+ << result.getType () << " \n " ;
394423 }
395424 newOpResultTypes.push_back (typeIface.getShadowType (gutils->width ));
396425 }
397426
398427 SmallVector<Value> newOperands;
399428 newOperands.reserve (op->getNumOperands () + operandPositionsToShadow.size ());
429+
430+ auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
431+ if (!iface) {
432+ return op->emitError ()
433+ << " ControlFlowAutoDiffOpInterface not implemented for " << *op
434+ << " \n " ;
435+ }
436+
437+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
438+ SmallVector<RegionSuccessor> entrySuccessors;
439+ regionBranchOp.getEntrySuccessorRegions (
440+ SmallVector<Attribute>(op->getNumOperands (), Attribute ()),
441+ entrySuccessors);
400442 for (OpOperand &operand : op->getOpOperands ()) {
401443 newOperands.push_back (gutils->getNewFromOriginal (operand.get ()));
402- if (operandPositionsToShadow.contains (operand.getOperandNumber ()))
403- newOperands.push_back (gutils->invertPointerM (operand.get (), builder));
444+ if (operandPositionsToShadow.contains (operand.getOperandNumber ())) {
445+ Value shadowValue = nullptr ;
446+ if (!gutils->isConstantValue (operand.get ()))
447+ shadowValue = gutils->invertPointerM (operand.get (), builder);
448+ else {
449+ auto Ty = operand.get ().getType ();
450+ auto shadowType =
451+ cast<AutoDiffTypeInterface>(Ty).getShadowType (gutils->width );
452+ shadowValue = cast<AutoDiffTypeInterface>(shadowType)
453+ .createNullValue (builder, operand.get ().getLoc ());
454+ hasConstOperandShadow = true ;
455+
456+ // modify block arguments for entry successors to newOp, since
457+ // forceAugmentedReturns will not shadow const operands. No need to add
458+ // to the invertPointers map since `operand` is const (the shadow will
459+ // be unused)
460+ for (const RegionSuccessor &successor : entrySuccessors) {
461+ if (successor.isParent ())
462+ continue ;
463+ auto &newOpRegion =
464+ newOp->getRegion (successor.getSuccessor ()->getRegionNumber ());
465+ OperandRange succOperands =
466+ iface.getSuccessorOperands (regionBranchOp, successor);
467+ ValueRange succInputs = regionBranchOp.getSuccessorInputs (successor);
468+ auto succInputPos =
469+ operand.getOperandNumber () - succOperands.getBeginOperandIndex ();
470+
471+ if (succInputPos >= 0 && succInputPos < succInputs.size ()) {
472+ auto newOpBlockVal = dyn_cast<BlockArgument>(
473+ gutils->getNewFromOriginal (succInputs[succInputPos]));
474+ auto i = newOpBlockVal.getArgNumber ();
475+ mlir::Value dval = nullptr ;
476+ if (i == newOpRegion.getNumArguments () - 1 ) {
477+ dval =
478+ newOpRegion.addArgument (shadowType, newOpBlockVal.getLoc ());
479+ } else {
480+ dval = newOpRegion.insertArgument (
481+ newOpRegion.args_begin () + i + 1 , shadowType,
482+ newOpBlockVal.getLoc ());
483+ }
484+ }
485+ }
486+ }
487+ newOperands.push_back (shadowValue);
488+ }
404489 }
405490 // We are assuming the op can forward additional operands, listed
406491 // immediately after the original operands, to the same regions.
407492 // ^^
408493 // Our interface guarantees this.
409494 // We also assume that the region-holding op returns all of the values
410495 // yielded by terminators, and only those values.
411-
412- auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
413- if (!iface) {
414- op->emitError () << " ControlFlowAutoDiffOpInterface not implemented for "
415- << *op << " \n " ;
416- return failure ();
417- }
418496 Operation *replacement = iface.createWithShadows (
419497 builder, gutils, op, newOperands, newOpResultTypes);
420498 assert (replacement->getNumResults () == newOpResultTypes.size ());
0 commit comments