77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Dialect/SCF/IR/SCF.h"
10- #include " mlir/Dialect /Utils/StructuredOpsUtils .h"
10+ #include " mlir/Interfaces /Utils/MemorySlotUtils .h"
1111
1212using namespace mlir ;
1313using namespace mlir ::scf;
1414
15- // ===----------------------------------------------------------------------===//
16- // Helper functions
17- // ===----------------------------------------------------------------------===//
18-
19- // / Adds the corresponding reaching definition to the terminator of the block if
20- // / the terminator is of the provided type.
21- template <typename TermTy>
22- static void
23- updateTerminator (Block *block, Value defaultReachingDef,
24- const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
25- Operation *terminator = block->getTerminator ();
26- if (!isa<TermTy>(terminator))
27- return ;
28- Value blockReachingDef = reachingAtBlockEnd.lookup (block);
29- if (!blockReachingDef) {
30- // Block is dead code or the region is not using the slot, so we use the
31- // default provided reaching definition.
32- blockReachingDef = defaultReachingDef;
33- }
34- terminator->insertOperands (terminator->getNumOperands (), {blockReachingDef});
35- }
36-
37- // / Creates a shallow copy of an operation with new result types, moving the
38- // / regions out of the original operation and deleting the original operation.
39- static Operation *replaceWithNewResults (RewriterBase &rewriter, Operation *op,
40- TypeRange resultTypes) {
41- RewriterBase::InsertionGuard guard (rewriter);
42- rewriter.setInsertionPoint (op);
43- Operation *newOp =
44- mlir::cloneWithoutRegions (rewriter, op, resultTypes, op->getOperands ());
45- rewriter.startOpModification (newOp);
46- rewriter.startOpModification (op);
47- for (unsigned int i : llvm::seq (op->getNumRegions ()))
48- newOp->getRegion (i).takeBody (op->getRegion (i));
49- rewriter.finalizeOpModification (op);
50- rewriter.finalizeOpModification (newOp);
51-
52- SmallVector<Value> replacementValues (newOp->getResults ().drop_back ());
53- rewriter.replaceAllOpUsesWith (op, replacementValues);
54- rewriter.eraseOp (op);
55- return newOp;
56- }
57-
5815// ===----------------------------------------------------------------------===//
5916// ExecuteRegionOp
6017// ===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
8037 // Update the yield terminators to return the newly defined reaching
8138 // definition.
8239 for (Block &block : getRegion ().getBlocks ())
83- updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
40+ if (isa<YieldOp>(block.getTerminator ()))
41+ memoryslot::updateTerminator (&block, reachingDef, reachingAtBlockEnd);
8442
8543 SmallVector<Type> resultTypes (getResultTypes ());
8644 resultTypes.push_back (slot.elemType );
8745
8846 IRRewriter rewriter (builder);
8947 Operation *newOp =
90- replaceWithNewResults (rewriter, getOperation (), resultTypes);
48+ memoryslot:: replaceWithNewResults (rewriter, getOperation (), resultTypes);
9149 return newOp->getResults ().back ();
9250}
9351
@@ -123,14 +81,14 @@ Value ForOp::finalizePromotion(
12381
12482 // Update the yield terminator to return the newly defined reaching
12583 // definition.
126- updateTerminator<YieldOp> (getBody (), reachingDef, reachingAtBlockEnd);
84+ memoryslot:: updateTerminator (getBody (), reachingDef, reachingAtBlockEnd);
12785
12886 SmallVector<Type> resultTypes (getResultTypes ());
12987 resultTypes.push_back (slot.elemType );
13088
13189 IRRewriter rewriter (builder);
13290 Operation *newOp =
133- replaceWithNewResults (rewriter, getOperation (), resultTypes);
91+ memoryslot:: replaceWithNewResults (rewriter, getOperation (), resultTypes);
13492 return newOp->getResults ().back ();
13593}
13694
@@ -187,11 +145,11 @@ Value IfOp::finalizePromotion(
187145
188146 // Update the yield terminators to return the newly defined reaching
189147 // definition.
190- updateTerminator<YieldOp> (&getThenRegion ().back (), reachingDef,
191- reachingAtBlockEnd);
148+ memoryslot:: updateTerminator (&getThenRegion ().back (), reachingDef,
149+ reachingAtBlockEnd);
192150 if (getElseRegion ().hasOneBlock ()) {
193- updateTerminator<YieldOp> (&getElseRegion ().back (), reachingDef,
194- reachingAtBlockEnd);
151+ memoryslot:: updateTerminator (&getElseRegion ().back (), reachingDef,
152+ reachingAtBlockEnd);
195153 } else {
196154 OpBuilder::InsertionGuard guard (rewriter);
197155 rewriter.createBlock (&getElseRegion ());
@@ -202,7 +160,7 @@ Value IfOp::finalizePromotion(
202160 resultTypes.push_back (slot.elemType );
203161
204162 Operation *newOp =
205- replaceWithNewResults (rewriter, getOperation (), resultTypes);
163+ memoryslot:: replaceWithNewResults (rewriter, getOperation (), resultTypes);
206164 return newOp->getResults ().back ();
207165}
208166
@@ -234,17 +192,17 @@ Value IndexSwitchOp::finalizePromotion(
234192
235193 // Update the yield terminators to return the newly defined reaching
236194 // definition.
237- updateTerminator<YieldOp> (&getDefaultRegion ().back (), reachingDef,
238- reachingAtBlockEnd);
195+ memoryslot:: updateTerminator (&getDefaultRegion ().back (), reachingDef,
196+ reachingAtBlockEnd);
239197 for (Region &caseRegion : getCaseRegions ())
240- updateTerminator<YieldOp> (&caseRegion.back (), reachingDef,
241- reachingAtBlockEnd);
198+ memoryslot:: updateTerminator (&caseRegion.back (), reachingDef,
199+ reachingAtBlockEnd);
242200
243201 SmallVector<Type> resultTypes (getResultTypes ());
244202 resultTypes.push_back (slot.elemType );
245203
246204 Operation *newOp =
247- replaceWithNewResults (rewriter, getOperation (), resultTypes);
205+ memoryslot:: replaceWithNewResults (rewriter, getOperation (), resultTypes);
248206 return newOp->getResults ().back ();
249207}
250208
@@ -339,17 +297,17 @@ Value WhileOp::finalizePromotion(
339297
340298 // Update the yield terminators to return the newly defined reaching
341299 // definition.
342- updateTerminator<ConditionOp> (&getBefore ().back (),
343- getBefore ().getArguments ().back (),
344- reachingAtBlockEnd);
345- updateTerminator<YieldOp> (
300+ memoryslot:: updateTerminator (&getBefore ().back (),
301+ getBefore ().getArguments ().back (),
302+ reachingAtBlockEnd);
303+ memoryslot:: updateTerminator (
346304 &getAfter ().back (), getAfter ().getArguments ().back (), reachingAtBlockEnd);
347305
348306 SmallVector<Type> resultTypes (getResultTypes ());
349307 resultTypes.push_back (slot.elemType );
350308
351309 IRRewriter rewriter (builder);
352310 Operation *newOp =
353- replaceWithNewResults (rewriter, getOperation (), resultTypes);
311+ memoryslot:: replaceWithNewResults (rewriter, getOperation (), resultTypes);
354312 return newOp->getResults ().back ();
355313}
0 commit comments