Skip to content

Commit b847f3b

Browse files
authored
MLIR: Fix a couple of issues with reverse mode applied to scf.if + scf.parallel (#2723)
* Implement MemRefAutoDiffTypeInterface::createNullValue. This is needed for reverse mode of scf.if. It is implemented by creating a memref of the given type with 0 as the size for all dynamic dimensions.
* Fix ParallelOpEnzymeOpsRemover::getCanonicalLoopIVs. When the lower bound is nonzero, the step was being subtracted from the induction variable instead of the lower bound itself.
* Add a reverse mode test for scf.if + scf.parallel.
* Run clang-format.
1 parent 62b3124 commit b847f3b

3 files changed

Lines changed: 70 additions & 2 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,15 @@ class MemRefAutoDiffTypeInterface
242242
}
243243
mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
244244
Location loc) const {
245-
llvm_unreachable("Cannot create null of memref (todo polygeist null)");
245+
// Create a memref of the given type with the required number of
246+
// dynamic dimensions, all set to 0
247+
MemRefType MT = cast<MemRefType>(self);
248+
unsigned numDynamicDims = MT.getNumDynamicDims();
249+
SmallVector<mlir::Value> dynamicSizes(numDynamicDims);
250+
for (unsigned i = 0; i < numDynamicDims; ++i) {
251+
dynamicSizes[i] = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
252+
}
253+
return mlir::memref::AllocOp::create(builder, loc, MT, dynamicSizes);
246254
}
247255

248256
Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -675,7 +675,7 @@ struct ParallelOpEnzymeOpsRemover
675675
parOp.getStep())) {
676676
Value val = iv;
677677
if (!matchPattern(lb, m_Zero())) {
678-
val = arith::SubIOp::create(builder, parOp.getLoc(), val, step);
678+
val = arith::SubIOp::create(builder, parOp.getLoc(), val, lb);
679679
}
680680

681681
if (!matchPattern(step, m_One())) {
Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,60 @@
1+
// RUN: %eopt %s --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse | FileCheck %s
2+
3+
module {
4+
func.func @scale(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>) {
5+
%c0 = arith.constant 0 : index
6+
%c1 = arith.constant 1 : index
7+
%dim = memref.dim %arg0, %c0 : memref<?xf64>
8+
%0 = arith.cmpi sgt, %dim, %c0 : index
9+
scf.if %0 {
10+
scf.parallel (%arg3) = (%c0) to (%dim) step (%c1) {
11+
%1 = memref.load %arg0[%arg3] : memref<?xf64>
12+
%2 = memref.load %arg1[%arg3] : memref<?xf64>
13+
%3 = arith.mulf %1, %2 : f64
14+
memref.store %3, %arg2[%arg3] : memref<?xf64>
15+
}
16+
}
17+
return
18+
}
19+
20+
func.func @dscale(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>, %arg3: memref<?xf64>, %arg4: memref<?xf64>) {
21+
enzyme.autodiff @scale(%arg0, %arg3, %arg1, %arg2, %arg4) {
22+
activity=[#enzyme<activity enzyme_dup>,
23+
#enzyme<activity enzyme_const>,
24+
#enzyme<activity enzyme_dup>],
25+
ret_activity=[]
26+
} : (memref<?xf64>, memref<?xf64>, memref<?xf64>, memref<?xf64>, memref<?xf64>) -> ()
27+
28+
return
29+
}
30+
31+
// CHECK: @diffescale(%[[arg0:.+]]: memref<?xf64>, %[[arg1:.+]]: memref<?xf64>, %[[arg2:.+]]: memref<?xf64>, %[[arg3:.+]]: memref<?xf64>, %[[arg4:.+]]: memref<?xf64>) {
32+
// CHECK: %[[c1:.+]] = arith.constant 1 : index
33+
// CHECK: %[[c0:.+]] = arith.constant 0 : index
34+
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
35+
// CHECK: %[[dim:.+]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf64>
36+
// CHECK: %[[x0:.+]] = arith.cmpi sgt, %[[dim]], %[[c0]] : index
37+
// CHECK: scf.if %[[x0]] {
38+
// CHECK: %[[alloc:.+]] = memref.alloc(%[[dim]]) : memref<?xf64>
39+
// CHECK: scf.parallel (%[[arg5:.+]]) = (%[[c0]]) to (%[[dim]]) step (%[[c1]]) {
40+
// CHECK: %[[x1:.+]] = memref.load %arg0[%[[arg5]]] : memref<?xf64>
41+
// CHECK: %[[x2:.+]] = memref.load %arg2[%[[arg5]]] : memref<?xf64>
42+
// CHECK: memref.store %[[x2]], %[[alloc]][%[[arg5]]] : memref<?xf64>
43+
// CHECK: %[[x3:.+]] = arith.mulf %[[x1]], %[[x2]] : f64
44+
// CHECK: memref.store %[[x3]], %arg3[%[[arg5]]] : memref<?xf64>
45+
// CHECK: scf.reduce
46+
// CHECK: }
47+
// CHECK: scf.parallel (%[[arg5:.+]]) = (%[[c0]]) to (%[[dim]]) step (%[[c1]]) {
48+
// CHECK: %[[x1:.+]] = memref.load %[[alloc]][%[[arg5]]] : memref<?xf64>
49+
// CHECK: %[[x2:.+]] = memref.load %arg4[%[[arg5]]] : memref<?xf64>
50+
// CHECK: memref.store %[[cst]], %arg4[%[[arg5]]] : memref<?xf64>
51+
// CHECK: %[[x3:.+]] = arith.mulf %[[x2]], %[[x1]] : f64
52+
// CHECK: %[[x4:.+]] = memref.atomic_rmw addf %[[x3]], %arg1[%[[arg5]]] : (f64, memref<?xf64>) -> f64
53+
// CHECK: scf.reduce
54+
// CHECK: }
55+
// CHECK: memref.dealloc %[[alloc]] : memref<?xf64>
56+
// CHECK: }
57+
// CHECK: return
58+
// CHECK: }
59+
60+
}

0 commit comments

Comments
 (0)