Skip to content

Commit 7c0141f

Browse files
authored
Revert tablegen handler (#2868)
* Revert "Optionally check regionBranchOp (#2861)" This reverts commit c47a056. * Revert "Add const shadows for forward mode AD on `RegionBranchOpInterface` (#2780)" This reverts commit 9102b25.
1 parent 8ca23de commit 7c0141f

6 files changed

Lines changed: 36 additions & 180 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp

Lines changed: 20 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,15 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
337337
// add the shadow as operand.
338338
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
339339
if (!regionBranchOp) {
340-
return op->emitError() << " RegionBranchOpInterface not implemented for "
341-
<< *op << "\n";
340+
op->emitError() << " RegionBranchOpInterface not implemented for " << *op
341+
<< "\n";
342+
return failure();
342343
}
343344
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
344345
if (!iface) {
345-
return op->emitError()
346-
<< " ControlFlowAutoDiffOpInterface not implemented for " << *op
347-
<< "\n";
346+
op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
347+
<< *op << "\n";
348+
return failure();
348349
}
349350

350351
// TODO: we may need to record, for every successor, which of its inputs
@@ -370,37 +371,8 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
370371
// operands.
371372
for (auto &&[i, regionValue, operand] :
372373
llvm::enumerate(targetValues, operandRange)) {
373-
374-
// if all the possible predecessors for this value are also const, then
375-
// we can skip creating a shadow. Else we need to create a shadow for
376-
// activity correctness
377-
if (gutils->isConstantValue(regionValue)) {
378-
SmallVector<Value> possibleActivePreds;
379-
SmallVector<RegionBranchPoint> predecessors;
380-
regionBranchOp.getPredecessors(successor, predecessors);
381-
for (RegionBranchPoint predecessor : predecessors) {
382-
if (predecessor.isParent()) {
383-
// if the predecessor is the parent itself, then it's just
384-
// `operand`
385-
possibleActivePreds.push_back(operand);
386-
continue;
387-
}
388-
auto terminator = predecessor.getTerminatorPredecessorOrNull();
389-
auto predecessorOperands = terminator.getSuccessorOperands(successor);
390-
if (i < predecessorOperands.size())
391-
possibleActivePreds.push_back(predecessorOperands[i]);
392-
}
393-
394-
bool skipOpShadow = true;
395-
for (auto pv : possibleActivePreds) {
396-
if (!skipOpShadow)
397-
break;
398-
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
399-
};
400-
if (skipOpShadow)
401-
continue;
402-
// if there's any possible active predecessor, we create a shadow for it
403-
}
374+
if (gutils->isConstantValue(regionValue))
375+
continue;
404376
operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
405377
if (successor.isParent())
406378
resultPositionsToShadow.insert(i);
@@ -435,98 +407,33 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
435407
continue;
436408
auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
437409
if (!typeIface) {
438-
return op->emitError() << " AutoDiffTypeInterface not implemented for "
439-
<< result.getType() << "\n";
410+
op->emitError() << " AutoDiffTypeInterface not implemented for "
411+
<< result.getType() << "\n";
412+
return failure();
440413
}
441414
newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
442415
}
443416

444417
SmallVector<Value> newOperands;
445418
newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size());
446-
447-
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
448-
if (!iface) {
449-
return op->emitError()
450-
<< " ControlFlowAutoDiffOpInterface not implemented for " << *op
451-
<< "\n";
452-
}
453-
454-
// Not all users of ControlFlowHandler(...) implement the
455-
// RegionBranchOpInterface -- for example stablehlo.while doesn't implement
456-
// this. We will still retain creating shadows for constant operands, but only
457-
// restrict the behavior to RegionBranchOpInterface.
458-
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
459-
SmallVector<RegionSuccessor> entrySuccessors;
460-
if (regionBranchOp)
461-
regionBranchOp.getEntrySuccessorRegions(
462-
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
463-
entrySuccessors);
464-
465419
for (OpOperand &operand : op->getOpOperands()) {
466420
newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
467-
if (operandPositionsToShadow.contains(operand.getOperandNumber())) {
468-
Value shadowValue = nullptr;
469-
if (!gutils->isConstantValue(operand.get()))
470-
shadowValue = gutils->invertPointerM(operand.get(), builder);
471-
else if (regionBranchOp) {
472-
auto Ty = operand.get().getType();
473-
auto shadowType =
474-
cast<AutoDiffTypeInterface>(Ty).getShadowType(gutils->width);
475-
shadowValue = cast<AutoDiffTypeInterface>(shadowType)
476-
.createNullValue(builder, operand.get().getLoc());
477-
478-
// modify block arguments for entry successors to newOp, since
479-
// forceAugmentedReturns will not shadow const operands. No need to add
480-
// to the invertPointers map since `operand` is const (the shadow will
481-
// be unused)
482-
for (const RegionSuccessor &successor : entrySuccessors) {
483-
if (successor.isParent())
484-
continue;
485-
auto &newOpRegion =
486-
newOp->getRegion(successor.getSuccessor()->getRegionNumber());
487-
OperandRange succOperands =
488-
iface.getSuccessorOperands(regionBranchOp, successor);
489-
ValueRange succInputs = regionBranchOp.getSuccessorInputs(successor);
490-
491-
if (succOperands.empty())
492-
continue;
493-
494-
auto succInputPos =
495-
operand.getOperandNumber() - succOperands.getBeginOperandIndex();
496-
497-
if (succInputPos >= 0 && succInputPos < succInputs.size()) {
498-
auto oldRegionInput =
499-
dyn_cast<BlockArgument>(succInputs[succInputPos]);
500-
if (!oldRegionInput)
501-
continue;
502-
if (gutils->invertedPointers.contains(oldRegionInput))
503-
continue;
504-
auto newOpBlockVal =
505-
cast<BlockArgument>(gutils->getNewFromOriginal(oldRegionInput));
506-
auto i = newOpBlockVal.getArgNumber();
507-
if (i == newOpRegion.getNumArguments() - 1) {
508-
newOpRegion.addArgument(shadowType, newOpBlockVal.getLoc());
509-
} else {
510-
newOpRegion.insertArgument(newOpRegion.args_begin() + i + 1,
511-
shadowType, newOpBlockVal.getLoc());
512-
}
513-
}
514-
}
515-
} else {
516-
// TODO: a const operand, but it also has to be shadowed (but the op
517-
// doesn't implement the RegionBranchOpInterface). Unimplemented for
518-
// now.
519-
}
520-
newOperands.push_back(shadowValue);
521-
}
421+
if (operandPositionsToShadow.contains(operand.getOperandNumber()))
422+
newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
522423
}
523-
524424
// We are assuming the op can forward additional operands, listed
525425
// immediately after the original operands, to the same regions.
526426
// ^^
527427
// Our interface guarantees this.
528428
// We also assume that the region-holding op returns all of the values
529429
// yielded by terminators, and only those values.
430+
431+
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
432+
if (!iface) {
433+
op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
434+
<< *op << "\n";
435+
return failure();
436+
}
530437
Operation *replacement = iface.createWithShadows(
531438
builder, gutils, op, newOperands, newOpResultTypes);
532439
assert(replacement->getNumResults() == newOpResultTypes.size());

enzyme/test/MLIR/ForwardMode/affine.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ module {
1515
}
1616
// CHECK: @fwddiffeloop
1717
// CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64)
18-
// CHECK-DAG: %[[TEN:.+]] = arith.constant 1.000000e+01 : f64
19-
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
20-
// CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[TEN]], %[[arg4:.+]] = %[[ZERO]]) -> (f64, f64) {
18+
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
19+
// CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
20+
// CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, f64) {
2121
// CHECK: %[[v1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64
2222
// CHECK: %[[v2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64
2323
// CHECK: affine.yield %[[v2]], %[[v1]] : f64, f64

enzyme/test/MLIR/ForwardMode/for3.mlir

Lines changed: 0 additions & 51 deletions
This file was deleted.

enzyme/test/MLIR/ForwardMode/parallel.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ module {
3232

3333
// CHECK: @fwddiffematvec(%[[arg0:.+]]: memref<?x?xf64>, %[[arg1:.+]]: memref<?x?xf64>, %[[arg2:.+]]: memref<?xf64>, %[[arg3:.+]]: memref<?xf64>, %[[arg4:.+]]: memref<?xf64>, %[[arg5:.+]]: memref<?xf64>) {
3434
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
35+
// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
3536
// CHECK: %[[c1:.+]] = arith.constant 1 : index
3637
// CHECK: %[[c0:.+]] = arith.constant 0 : index
3738
// CHECK: %[[dim:.+]] = memref.dim %[[arg0:.+]], %[[c0]] : memref<?x?xf64>
3839
// CHECK: %[[dim_1:.+]] = memref.dim %[[arg0:.+]], %[[c1]] : memref<?x?xf64>
3940
// CHECK: scf.parallel (%[[arg6:.+]]) = (%[[c0]]) to (%dim) step (%[[c1]]) {
40-
// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
41-
// CHECK: %[[x0:.+]]:2 = scf.for %[[arg7:.+]] = %[[c0]] to %[[dim_1]] step %[[c1]] iter_args(%[[arg8:.+]] = %[[cst]], %[[arg9:.+]] = %[[cst_0]]) -> (f64, f64) {
41+
// CHECK: %[[x0:.+]]:2 = scf.for %[[arg7:.+]] = %[[c0]] to %dim_1 step %[[c1]] iter_args(%[[arg8:.+]] = %[[cst_0]], %[[arg9:.+]] = %[[cst]]) -> (f64, f64) {
4242
// CHECK: %[[x1:.+]] = memref.load %[[arg1]][%[[arg6]], %[[arg7]]] : memref<?x?xf64>
4343
// CHECK: %[[x2:.+]] = memref.load %[[arg0]][%[[arg6]], %[[arg7]]] : memref<?x?xf64>
4444
// CHECK: %[[x3:.+]] = memref.load %[[arg3]][%[[arg7]]] : memref<?xf64>

enzyme/test/MLIR/ForwardMode/parallel_reduce.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ module {
2727
}
2828

2929
// CHECK: @fwddiffenrm2(%[[arg0:.+]]: memref<?xf64>, %[[arg1:.+]]: memref<?xf64>) -> f64 {
30-
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
31-
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
32-
// CHECK-DAG: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
33-
// CHECK-DAG: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf64>
30+
// CHECK: %[[c0:.+]] = arith.constant 0 : index
31+
// CHECK: %[[c1:.+]] = arith.constant 1 : index
32+
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
3433
// CHECK: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
35-
// CHECK: %[[x0:.+]]:2 = scf.parallel (%[[arg2:.+]]) = (%[[c0]]) to (%[[dim]]) step (%[[c1]]) init (%[[cst]], %[[cst_0]]) -> (f64, f64) {
34+
// CHECK: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf64>
35+
// CHECK: %[[x0:.+]]:2 = scf.parallel (%[[arg2:.+]]) = (%[[c0]]) to (%dim) step (%[[c1]]) init (%[[cst_0]], %[[cst]]) -> (f64, f64) {
3636
// CHECK: %[[x1:.+]] = memref.load %[[arg1]][%[[arg2]]] : memref<?xf64>
3737
// CHECK: %[[x2:.+]] = memref.load %[[arg0]][%[[arg2]]] : memref<?xf64>
3838
// CHECK: %[[x3:.+]] = arith.mulf %[[x1]], %[[x2]] : f64

enzyme/test/MLIR/ForwardMode/while.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ module {
2424
}
2525
// CHECK: @fwddiffewhile
2626
// CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
27-
// CHECK-DAG: %[[TEN:.+]] = arith.constant 1.000000e+01 : f64
28-
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
29-
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
30-
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
31-
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
32-
// CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[TEN]], %[[arg4:.+]] = %[[ZERO]]) : (index, f64, f64) -> (index, f64, f64) {
27+
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
28+
// CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
29+
// CHECK: %[[c0:.+]] = arith.constant 0 : index
30+
// CHECK: %[[c1:.+]] = arith.constant 1 : index
31+
// CHECK: %[[c10:.+]] = arith.constant 10 : index
32+
// CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) : (index, f64, f64) -> (index, f64, f64) {
3333
// CHECK: %[[v1:.+]] = arith.cmpi slt, %[[arg2]], %[[c10]] : index
3434
// CHECK: scf.condition(%[[v1]]) %[[arg2]], %[[arg3]], %[[arg4]] : index, f64, f64
3535
// CHECK: } do {

0 commit comments

Comments
 (0)