Skip to content

Commit cb9ca83

Browse files
committed
mlir: mark loads of read-only pointers as movable during mincut
* implement MemoryEffectOpInterface for enzyme.push/pop * fix warnings with int signedness
1 parent b847f3b commit cb9ca83

5 files changed

Lines changed: 117 additions & 69 deletions

File tree

enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,8 @@ std::optional<Value> getStored(Operation *op) {
477477
return storeOp.getValue();
478478
} else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
479479
return storeOp.getValue();
480+
} else if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
481+
return pushOp.getValue();
480482
}
481483
return std::nullopt;
482484
}

enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ using namespace mlir;
4848
using namespace mlir::dataflow;
4949

5050
static bool isPointerLike(Type type) {
51-
return isa<MemRefType, LLVM::LLVMPointerType>(type);
51+
return isa<MemRefType, LLVM::LLVMPointerType, enzyme::CacheType>(type);
5252
}
5353

5454
//===----------------------------------------------------------------------===//
@@ -121,29 +121,6 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
121121
[&](DistinctAttr dest, AliasClassSet::State state) {
122122
assert(state == AliasClassSet::State::Defined &&
123123
"unknown must have been handled above");
124-
#ifndef NDEBUG
125-
if (replace) {
126-
auto it = map.find(dest);
127-
if (it != map.end()) {
128-
// Check that we are updating to a state that's >= in the
129-
// lattice.
130-
// TODO: consider a stricter check that we only replace unknown
131-
// values or a value with itself, currently blocked by memalign.
132-
AliasClassSet valuesCopy(values);
133-
(void)valuesCopy.join(it->getSecond());
134-
values.print(llvm::errs());
135-
llvm::errs() << "\n";
136-
it->getSecond().print(llvm::errs());
137-
llvm::errs() << "\n";
138-
valuesCopy.print(llvm::errs());
139-
llvm::errs() << "\n";
140-
assert(valuesCopy == values &&
141-
"attempting to replace a pointsTo entry with an alias class "
142-
"set that is ordered _before_ the existing one -> "
143-
"non-monotonous update ");
144-
}
145-
}
146-
#endif // NDEBUG
147124
return joinPotentiallyMissing(dest, values);
148125
});
149126
}
@@ -278,7 +255,7 @@ LogicalResult enzyme::PointsToPointerAnalysis::visitOperation(
278255
// fixpoint and bail.
279256
auto memory = dyn_cast<MemoryEffectOpInterface>(op);
280257
if (!memory) {
281-
if (isNoOp(op))
258+
if (isNoOp(op) || isMemoryEffectFree(op))
282259
return success();
283260
propagateIfChanged(after, after->markAllPointToUnknown());
284261
return success();
@@ -557,7 +534,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
557534
std::optional<LLVM::ModRefInfo> otherModRef =
558535
getFunctionOtherModRef(callee);
559536

560-
SmallVector<int> pointerLikeOperands;
537+
SmallVector<unsigned> pointerLikeOperands;
561538
for (auto &&[i, operand] : llvm::enumerate(call.getArgOperands())) {
562539
if (isPointerLike(operand.getType()))
563540
pointerLikeOperands.push_back(i);
@@ -575,7 +552,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
575552
// unknown alias sets into any writable pointer.
576553
(void)functionMayCapture.markUnknown();
577554
} else {
578-
for (int pointerAsData : pointerLikeOperands) {
555+
for (unsigned pointerAsData : pointerLikeOperands) {
579556
// If not captured, it cannot be stored in anything.
580557
if ((pointerAsData < numArguments &&
581558
!!callee.getArgAttr(pointerAsData,
@@ -593,7 +570,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
593570
AliasClassSet writableClasses = AliasClassSet::getUndefined();
594571
AliasClassSet nonWritableOperandClasses = AliasClassSet::getUndefined();
595572
ChangeResult changed = ChangeResult::NoChange;
596-
for (int pointerOperand : pointerLikeOperands) {
573+
for (unsigned pointerOperand : pointerLikeOperands) {
597574
auto *destClasses = getOrCreateFor<AliasClassLattice>(
598575
getProgramPointAfter(call), call.getArgOperands()[pointerOperand]);
599576

@@ -696,7 +673,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
696673
continue;
697674
}
698675

699-
for (int operandNo : pointerLikeOperands) {
676+
for (unsigned operandNo : pointerLikeOperands) {
700677
const auto *srcClasses = getOrCreateFor<AliasClassLattice>(
701678
getProgramPointAfter(call), call.getArgOperands()[operandNo]);
702679
if (mayReadArg(callee, operandNo, argModRef)) {
@@ -840,7 +817,8 @@ static bool isAliasTransferFullyDescribedByMemoryEffects(Operation *op) {
840817
}
841818
}
842819
}
843-
return isa<memref::LoadOp, memref::StoreOp, LLVM::LoadOp, LLVM::StoreOp>(op);
820+
return isa<memref::LoadOp, memref::StoreOp, LLVM::LoadOp, LLVM::StoreOp,
821+
enzyme::PushOp, enzyme::PopOp>(op);
844822
}
845823

846824
void enzyme::AliasAnalysis::transfer(

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,23 +279,23 @@ def PushOp : Enzyme_Op<"push", [
279279
"cache", "value",
280280
"::llvm::cast<enzyme::CacheType>($_self).getType()">]> {
281281
let summary = "Push value to cache";
282-
let arguments = (ins AnyType : $cache, AnyType : $value);
282+
let arguments = (ins Arg<AnyType, "the cache to push to", [MemWrite]>:$cache, AnyType:$value);
283283
}
284284

285285
def PopOp : Enzyme_Op<"pop", [
286286
TypesMatchWith<"type of 'output' matches element type of 'cache'",
287287
"cache", "output",
288288
"::llvm::cast<enzyme::CacheType>($_self).getType()">]> {
289289
let summary = "Retrieve information for the reverse mode pass.";
290-
let arguments = (ins AnyType : $cache);
290+
let arguments = (ins Arg<AnyType, "the cache to pop from", [MemRead]>:$cache);
291291
let results = (outs AnyType:$output);
292292
}
293293

294294
def InitOp : Enzyme_Op<"init",
295295
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
296296
let summary = "Create enzyme.gradient and enzyme.cache";
297297
let arguments = (ins );
298-
let results = (outs AnyType);
298+
let results = (outs Res<AnyType, "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
299299
}
300300

301301
def Cache : Enzyme_Type<"Cache"> {

enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#include <cassert>
1717
#include <deque>
1818

19+
#include "Analysis/DataFlowAliasAnalysis.h"
20+
#include "mlir/Analysis/DataFlow/Utils.h"
21+
1922
#include "llvm/ADT/MapVector.h"
2023

2124
using namespace mlir;
@@ -263,11 +266,76 @@ static inline void bfs(const Graph &G, const llvm::SetVector<Value> &Sources,
263266
}
264267
}
265268

269+
struct OverwriteAnalyzer {
270+
static OverwriteAnalyzer analyzeFunc(FunctionOpInterface funcOp) {
271+
return OverwriteAnalyzer(funcOp);
272+
}
273+
274+
bool isPtrPotentiallyModified(Value ptr) const {
275+
// If the alias analysis failed, conservatively assume all pointers may
276+
// be modified
277+
if (!valid)
278+
return true;
279+
280+
// Check if the pointer's alias classes intersect the modified alias classes
281+
auto *ptrClass = solver.lookupState<AliasClassLattice>(ptr);
282+
return !ptrClass->alias(modified).isNo();
283+
}
284+
285+
private:
286+
DataFlowSolver solver;
287+
// The set of all alias classes that are potentially modified in the function
288+
AliasClassLattice modified;
289+
bool valid = true;
290+
291+
OverwriteAnalyzer(FunctionOpInterface funcOp)
292+
: solver(DataFlowConfig().setInterprocedural(false)), modified(nullptr) {
293+
dataflow::loadBaselineAnalyses(solver);
294+
solver.load<enzyme::AliasAnalysis>(funcOp.getContext(), /*relative=*/false);
295+
solver.load<enzyme::PointsToPointerAnalysis>();
296+
if (failed(solver.initializeAndRun(funcOp))) {
297+
assert(false && "dataflow analysis failed");
298+
valid = false;
299+
} else {
300+
funcOp.walk([&](MemoryEffectOpInterface memory) {
301+
SmallVector<MemoryEffects::EffectInstance> effects;
302+
memory.getEffects(effects);
303+
for (const auto &effect : effects) {
304+
if (isa<MemoryEffects::Write>(effect.getEffect())) {
305+
Value val = effect.getValue();
306+
if (val) {
307+
(void)modified.join(*solver.lookupState<AliasClassLattice>(val));
308+
} else {
309+
(void)modified.markUnknown();
310+
}
311+
}
312+
}
313+
});
314+
}
315+
}
316+
};
317+
318+
bool isLoadMovable(const OverwriteAnalyzer &analyzer, Operation *op) {
319+
if (!hasSingleEffect<MemoryEffects::Read>(op)) {
320+
return false;
321+
}
322+
auto memory = cast<MemoryEffectOpInterface>(op);
323+
SmallVector<MemoryEffects::EffectInstance> effects;
324+
memory.getEffects(effects);
325+
assert(effects.size() == 1 &&
326+
isa<MemoryEffects::Read>(effects.front().getEffect()));
327+
Value ptr = effects.front().getValue();
328+
329+
// The load can be re-done if the pointer's contents are never modified
330+
// by the function.
331+
return !analyzer.isPtrPotentiallyModified(ptr);
332+
}
333+
266334
// Whether or not an operation can be moved from the forward region to the
267335
// reverse region or vice-versa.
268-
static inline bool isMovable(Operation *op) {
336+
static inline bool isMovable(const OverwriteAnalyzer &analyzer, Operation *op) {
269337
return op->getNumRegions() == 0 && op->getBlock()->getTerminator() != op &&
270-
mlir::isPure(op);
338+
(mlir::isPure(op) || isLoadMovable(analyzer, op));
271339
}
272340

273341
// Given a graph `G`, construct a new graph `G2`, where all paths must terminate
@@ -487,6 +555,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
487555
}
488556

489557
Graph G;
558+
auto overwriteAnalyzer = OverwriteAnalyzer::analyzeFunc(
559+
forward->getParent()->getParentOfType<FunctionOpInterface>());
490560

491561
LLVM_DEBUG(llvm::dbgs() << "trying min/cut\n");
492562
LLVM_DEBUG(
@@ -518,7 +588,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
518588
}
519589

520590
Operation *owner = todo.getDefiningOp();
521-
if (!owner || !isMovable(owner)) {
591+
if (!owner || !isMovable(overwriteAnalyzer, owner)) {
522592
roots.insert(todo);
523593
continue;
524594
}
@@ -544,7 +614,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
544614

545615
bool isRequired = false;
546616
for (auto user : poped.getUsers()) {
547-
if (user->getBlock() != reverse || !isMovable(user)) {
617+
if (user->getBlock() != reverse ||
618+
!isMovable(overwriteAnalyzer, user)) {
548619
G[info.pushedValue()].insert(Node(user));
549620
Required.insert(user);
550621
isRequired = true;
@@ -567,7 +638,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
567638

568639
bool isRequired = false;
569640
for (auto user : todo.getUsers()) {
570-
if (user->getBlock() != reverse || !isMovable(user)) {
641+
if (user->getBlock() != reverse ||
642+
!isMovable(overwriteAnalyzer, user)) {
571643
G[todo].insert(Node(user));
572644
Required.insert(user);
573645
isRequired = true;
Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: %eopt %s --pass-pipeline="builtin.module(enzyme,canonicalize,remove-unnecessary-enzyme-ops)" | FileCheck %s
2-
func.func @foo(%x: memref<?xf32>, %y: memref<?xf32>) {
1+
// RUN: %eopt %s --pass-pipeline="builtin.module(enzyme,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math)" | FileCheck %s
2+
func.func @foo(%x: memref<?xf32> {llvm.noalias}, %y: memref<?xf32> {llvm.noalias}) {
33
%c0 = arith.constant 0 : index
44
%c1 = arith.constant 1 : index
55
%c4 = arith.constant 4 : index
@@ -12,39 +12,35 @@ func.func @foo(%x: memref<?xf32>, %y: memref<?xf32>) {
1212
return
1313
}
1414

15-
func.func @dfoo(%x: memref<?xf32>, %dx: memref<?xf32>, %y: memref<?xf32>, %dy: memref<?xf32>) {
15+
func.func @dfoo(%x: memref<?xf32> {llvm.noalias}, %dx: memref<?xf32> {llvm.noalias}, %y: memref<?xf32> {llvm.noalias}, %dy: memref<?xf32> {llvm.noalias}) {
1616
enzyme.autodiff @foo(%x, %dx, %y, %dy) {
1717
activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>],
1818
ret_activity = []
1919
} : (memref<?xf32>, memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
2020
return
2121
}
2222

23-
// CHECK: func.func private @diffefoo(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>, %arg3: memref<?xf32>) {
24-
// CHECK-NEXT: %c4 = arith.constant 4 : index
25-
// CHECK-NEXT: %c1 = arith.constant 1 : index
26-
// CHECK-NEXT: %c0 = arith.constant 0 : index
27-
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
28-
// CHECK-NEXT: %alloc = memref.alloc() : memref<4xf32>
29-
// CHECK-NEXT: scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
30-
// CHECK-NEXT: %0 = memref.load %arg0[%arg4] : memref<?xf32>
31-
// CHECK-NEXT: memref.store %0, %alloc[%arg4] : memref<4xf32>
32-
// CHECK-NEXT: %1 = arith.mulf %0, %0 : f32
33-
// CHECK-NEXT: memref.store %1, %arg2[%arg4] : memref<?xf32>
34-
// CHECK-NEXT: scf.reduce
35-
// CHECK-NEXT: }
36-
// CHECK-NEXT: scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
37-
// CHECK-NEXT: %0 = memref.load %alloc[%arg4] : memref<4xf32>
38-
// CHECK-NEXT: %1 = memref.load %arg3[%arg4] : memref<?xf32>
39-
// CHECK-NEXT: %2 = arith.addf %1, %cst : f32
40-
// CHECK-NEXT: memref.store %cst, %arg3[%arg4] : memref<?xf32>
41-
// CHECK-NEXT: %3 = arith.mulf %2, %0 : f32
42-
// CHECK-NEXT: %4 = arith.addf %3, %cst : f32
43-
// CHECK-NEXT: %5 = arith.mulf %2, %0 : f32
44-
// CHECK-NEXT: %6 = arith.addf %4, %5 : f32
45-
// CHECK-NEXT: %7 = memref.atomic_rmw addf %6, %arg1[%arg4] : (f32, memref<?xf32>) -> f32
46-
// CHECK-NEXT: scf.reduce
47-
// CHECK-NEXT: }
48-
// CHECK-NEXT: memref.dealloc %alloc : memref<4xf32>
49-
// CHECK-NEXT: return
50-
// CHECK-NEXT: }
23+
// CHECK-LABEL: func.func private @diffefoo(
24+
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32> {llvm.noalias}, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32> {llvm.noalias}, %[[ARG3:.*]]: memref<?xf32>) {
25+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4 : index
26+
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
27+
// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : index
28+
// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0.000000e+00 : f32
29+
// CHECK: scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_2]]) to (%[[CONSTANT_0]]) step (%[[CONSTANT_1]]) {
30+
// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<?xf32>
31+
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_0]], %[[LOAD_0]] : f32
32+
// CHECK: memref.store %[[MULF_0]], %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
33+
// CHECK: scf.reduce
34+
// CHECK: }
35+
// CHECK: scf.parallel (%[[VAL_1:.*]]) = (%[[CONSTANT_2]]) to (%[[CONSTANT_0]]) step (%[[CONSTANT_1]]) {
36+
// CHECK: %[[LOAD_1:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_1]]] : memref<?xf32>
37+
// CHECK: %[[LOAD_2:.*]] = memref.load %[[ARG3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
38+
// CHECK: memref.store %[[CONSTANT_3]], %[[ARG3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
39+
// CHECK: %[[MULF_1:.*]] = arith.mulf %[[LOAD_2]], %[[LOAD_1]] : f32
40+
// CHECK: %[[MULF_2:.*]] = arith.mulf %[[LOAD_2]], %[[LOAD_1]] : f32
41+
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_1]], %[[MULF_2]] : f32
42+
// CHECK: %[[ATOMIC_RMW_0:.*]] = memref.atomic_rmw addf %[[ADDF_0]], %[[ARG1]]{{\[}}%[[VAL_1]]] : (f32, memref<?xf32>) -> f32
43+
// CHECK: scf.reduce
44+
// CHECK: }
45+
// CHECK: return
46+
// CHECK: }

0 commit comments

Comments
 (0)