@@ -337,14 +337,15 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
337337 // add the shadow as operand.
338338 auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
339339 if (!regionBranchOp) {
340- return op->emitError () << " RegionBranchOpInterface not implemented for "
341- << *op << " \n " ;
340+ op->emitError () << " RegionBranchOpInterface not implemented for " << *op
341+ << " \n " ;
342+ return failure ();
342343 }
343344 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
344345 if (!iface) {
345- return op->emitError ()
346- << " ControlFlowAutoDiffOpInterface not implemented for " << *op
347- << " \n " ;
346+ op->emitError () << " ControlFlowAutoDiffOpInterface not implemented for "
347+ << *op << " \n " ;
348+ return failure () ;
348349 }
349350
350351 // TODO: we may need to record, for every successor, which of its inputs
@@ -370,37 +371,8 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
370371 // operands.
371372 for (auto &&[i, regionValue, operand] :
372373 llvm::enumerate (targetValues, operandRange)) {
373-
374- // if all the possible predecessors for this value are also const, then
375- // we can skip creating a shadow. Else we need to create a shadow for
376- // activity correctness
377- if (gutils->isConstantValue (regionValue)) {
378- SmallVector<Value> possibleActivePreds;
379- SmallVector<RegionBranchPoint> predecessors;
380- regionBranchOp.getPredecessors (successor, predecessors);
381- for (RegionBranchPoint predecessor : predecessors) {
382- if (predecessor.isParent ()) {
383- // if the predecessor is the parent itself, then it's just
384- // `operand`
385- possibleActivePreds.push_back (operand);
386- continue ;
387- }
388- auto terminator = predecessor.getTerminatorPredecessorOrNull ();
389- auto predecessorOperands = terminator.getSuccessorOperands (successor);
390- if (i < predecessorOperands.size ())
391- possibleActivePreds.push_back (predecessorOperands[i]);
392- }
393-
394- bool skipOpShadow = true ;
395- for (auto pv : possibleActivePreds) {
396- if (!skipOpShadow)
397- break ;
398- skipOpShadow = skipOpShadow && gutils->isConstantValue (pv);
399- };
400- if (skipOpShadow)
401- continue ;
402- // if there's any possible active predecessor, we create a shadow for it
403- }
374+ if (gutils->isConstantValue (regionValue))
375+ continue ;
404376 operandPositionsToShadow.insert (operandRange.getBeginOperandIndex () + i);
405377 if (successor.isParent ())
406378 resultPositionsToShadow.insert (i);
@@ -435,98 +407,33 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
435407 continue ;
436408 auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType ());
437409 if (!typeIface) {
438- return op->emitError () << " AutoDiffTypeInterface not implemented for "
439- << result.getType () << " \n " ;
410+ op->emitError () << " AutoDiffTypeInterface not implemented for "
411+ << result.getType () << " \n " ;
412+ return failure ();
440413 }
441414 newOpResultTypes.push_back (typeIface.getShadowType (gutils->width ));
442415 }
443416
444417 SmallVector<Value> newOperands;
445418 newOperands.reserve (op->getNumOperands () + operandPositionsToShadow.size ());
446-
447- auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
448- if (!iface) {
449- return op->emitError ()
450- << " ControlFlowAutoDiffOpInterface not implemented for " << *op
451- << " \n " ;
452- }
453-
454- // Not all users of ControlFlowHandler(...) implement the
455- // RegionBranchOpInterface -- for example stablehlo.while doesn't implement
456- // this. We will still retain creating shadows for constant operands, but only
457- // restrict the behavior to RegionBranchOpInterface.
458- auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
459- SmallVector<RegionSuccessor> entrySuccessors;
460- if (regionBranchOp)
461- regionBranchOp.getEntrySuccessorRegions (
462- SmallVector<Attribute>(op->getNumOperands (), Attribute ()),
463- entrySuccessors);
464-
465419 for (OpOperand &operand : op->getOpOperands ()) {
466420 newOperands.push_back (gutils->getNewFromOriginal (operand.get ()));
467- if (operandPositionsToShadow.contains (operand.getOperandNumber ())) {
468- Value shadowValue = nullptr ;
469- if (!gutils->isConstantValue (operand.get ()))
470- shadowValue = gutils->invertPointerM (operand.get (), builder);
471- else if (regionBranchOp) {
472- auto Ty = operand.get ().getType ();
473- auto shadowType =
474- cast<AutoDiffTypeInterface>(Ty).getShadowType (gutils->width );
475- shadowValue = cast<AutoDiffTypeInterface>(shadowType)
476- .createNullValue (builder, operand.get ().getLoc ());
477-
478- // modify block arguments for entry successors to newOp, since
479- // forceAugmentedReturns will not shadow const operands. No need to add
480- // to the invertPointers map since `operand` is const (the shadow will
481- // be unused)
482- for (const RegionSuccessor &successor : entrySuccessors) {
483- if (successor.isParent ())
484- continue ;
485- auto &newOpRegion =
486- newOp->getRegion (successor.getSuccessor ()->getRegionNumber ());
487- OperandRange succOperands =
488- iface.getSuccessorOperands (regionBranchOp, successor);
489- ValueRange succInputs = regionBranchOp.getSuccessorInputs (successor);
490-
491- if (succOperands.empty ())
492- continue ;
493-
494- auto succInputPos =
495- operand.getOperandNumber () - succOperands.getBeginOperandIndex ();
496-
497- if (succInputPos >= 0 && succInputPos < succInputs.size ()) {
498- auto oldRegionInput =
499- dyn_cast<BlockArgument>(succInputs[succInputPos]);
500- if (!oldRegionInput)
501- continue ;
502- if (gutils->invertedPointers .contains (oldRegionInput))
503- continue ;
504- auto newOpBlockVal =
505- cast<BlockArgument>(gutils->getNewFromOriginal (oldRegionInput));
506- auto i = newOpBlockVal.getArgNumber ();
507- if (i == newOpRegion.getNumArguments () - 1 ) {
508- newOpRegion.addArgument (shadowType, newOpBlockVal.getLoc ());
509- } else {
510- newOpRegion.insertArgument (newOpRegion.args_begin () + i + 1 ,
511- shadowType, newOpBlockVal.getLoc ());
512- }
513- }
514- }
515- } else {
516- // TODO: a const operand, but it also has to be shadowed (but the op
517- // doesn't implement the RegionBranchOpInterface). Unimplemented for
518- // now.
519- }
520- newOperands.push_back (shadowValue);
521- }
421+ if (operandPositionsToShadow.contains (operand.getOperandNumber ()))
422+ newOperands.push_back (gutils->invertPointerM (operand.get (), builder));
522423 }
523-
524424 // We are assuming the op can forward additional operands, listed
525425 // immediately after the original operands, to the same regions.
526426 // ^^
527427 // Our interface guarantees this.
528428 // We also assume that the region-holding op returns all of the values
529429 // yielded by terminators, and only those values.
430+
431+ auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
432+ if (!iface) {
433+ op->emitError () << " ControlFlowAutoDiffOpInterface not implemented for "
434+ << *op << " \n " ;
435+ return failure ();
436+ }
530437 Operation *replacement = iface.createWithShadows (
531438 builder, gutils, op, newOperands, newOpResultTypes);
532439 assert (replacement->getNumResults () == newOpResultTypes.size ());
0 commit comments