Skip to content

Commit b6e4d27

Browse files
authored
[MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (llvm#188514)
The `PromotableRegionOpInterface` implementations use two helpers that are likely useful for other dialects implementing this interface as well: - `updateTerminator`: Appends the reaching definition as an operand to a block's terminator, falling back to a default when the block has no entry (e.g. dead code). - `replaceWithNewResults`: Clones an operation with additional result types while preserving its regions, then replaces the original. This PR extracts them into a common utility header so that downstream dialects can reuse them directly. I'm open to discussion about the location of these utilities.
1 parent 06725d7 commit b6e4d27

7 files changed

Lines changed: 319 additions & 63 deletions

File tree

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares common utilities for implementing MemorySlot interfaces,
10+
// in particular PromotableRegionOpInterface.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
15+
#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
16+
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
namespace mlir {
20+
namespace memoryslot {
21+
22+
/// Appends the reaching definition for the given block as an operand to its
23+
/// terminator. If the block has no entry in `reachingAtBlockEnd` (e.g. dead
24+
/// code or the region does not use the slot), `defaultReachingDef` is used.
25+
void updateTerminator(Block *block, Value defaultReachingDef,
26+
const DenseMap<Block *, Value> &reachingAtBlockEnd);
27+
28+
/// Creates a shallow copy of an operation with new result types, moving the
29+
/// regions out of the original operation and deleting the original operation.
30+
Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
31+
TypeRange resultTypes);
32+
33+
} // namespace memoryslot
34+
} // namespace mlir
35+
36+
#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H

mlir/lib/Dialect/SCF/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
1717
MLIRFunctionInterfaces
1818
MLIRIR
1919
MLIRLoopLikeInterface
20+
MLIRMemorySlotUtils
2021
MLIRSideEffectInterfaces
2122
MLIRTensorDialect
2223
MLIRValueBoundsOpInterface

mlir/lib/Dialect/SCF/IR/MemorySlot.cpp

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,54 +7,11 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/SCF/IR/SCF.h"
10-
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10+
#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
1111

1212
using namespace mlir;
1313
using namespace mlir::scf;
1414

15-
//===----------------------------------------------------------------------===//
16-
// Helper functions
17-
//===----------------------------------------------------------------------===//
18-
19-
/// Adds the corresponding reaching definition to the terminator of the block if
20-
/// the terminator is of the provided type.
21-
template <typename TermTy>
22-
static void
23-
updateTerminator(Block *block, Value defaultReachingDef,
24-
const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
25-
Operation *terminator = block->getTerminator();
26-
if (!isa<TermTy>(terminator))
27-
return;
28-
Value blockReachingDef = reachingAtBlockEnd.lookup(block);
29-
if (!blockReachingDef) {
30-
// Block is dead code or the region is not using the slot, so we use the
31-
// default provided reaching definition.
32-
blockReachingDef = defaultReachingDef;
33-
}
34-
terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
35-
}
36-
37-
/// Creates a shallow copy of an operation with new result types, moving the
38-
/// regions out of the original operation and deleting the original operation.
39-
static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
40-
TypeRange resultTypes) {
41-
RewriterBase::InsertionGuard guard(rewriter);
42-
rewriter.setInsertionPoint(op);
43-
Operation *newOp =
44-
mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
45-
rewriter.startOpModification(newOp);
46-
rewriter.startOpModification(op);
47-
for (unsigned int i : llvm::seq(op->getNumRegions()))
48-
newOp->getRegion(i).takeBody(op->getRegion(i));
49-
rewriter.finalizeOpModification(op);
50-
rewriter.finalizeOpModification(newOp);
51-
52-
SmallVector<Value> replacementValues(newOp->getResults().drop_back());
53-
rewriter.replaceAllOpUsesWith(op, replacementValues);
54-
rewriter.eraseOp(op);
55-
return newOp;
56-
}
57-
5815
//===----------------------------------------------------------------------===//
5916
// ExecuteRegionOp
6017
//===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
8037
// Update the yield terminators to return the newly defined reaching
8138
// definition.
8239
for (Block &block : getRegion().getBlocks())
83-
updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
40+
if (isa<YieldOp>(block.getTerminator()))
41+
memoryslot::updateTerminator(&block, reachingDef, reachingAtBlockEnd);
8442

8543
SmallVector<Type> resultTypes(getResultTypes());
8644
resultTypes.push_back(slot.elemType);
8745

8846
IRRewriter rewriter(builder);
8947
Operation *newOp =
90-
replaceWithNewResults(rewriter, getOperation(), resultTypes);
48+
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
9149
return newOp->getResults().back();
9250
}
9351

@@ -123,14 +81,14 @@ Value ForOp::finalizePromotion(
12381

12482
// Update the yield terminator to return the newly defined reaching
12583
// definition.
126-
updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
84+
memoryslot::updateTerminator(getBody(), reachingDef, reachingAtBlockEnd);
12785

12886
SmallVector<Type> resultTypes(getResultTypes());
12987
resultTypes.push_back(slot.elemType);
13088

13189
IRRewriter rewriter(builder);
13290
Operation *newOp =
133-
replaceWithNewResults(rewriter, getOperation(), resultTypes);
91+
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
13492
return newOp->getResults().back();
13593
}
13694

@@ -187,11 +145,11 @@ Value IfOp::finalizePromotion(
187145

188146
// Update the yield terminators to return the newly defined reaching
189147
// definition.
190-
updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
191-
reachingAtBlockEnd);
148+
memoryslot::updateTerminator(&getThenRegion().back(), reachingDef,
149+
reachingAtBlockEnd);
192150
if (getElseRegion().hasOneBlock()) {
193-
updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
194-
reachingAtBlockEnd);
151+
memoryslot::updateTerminator(&getElseRegion().back(), reachingDef,
152+
reachingAtBlockEnd);
195153
} else {
196154
OpBuilder::InsertionGuard guard(rewriter);
197155
rewriter.createBlock(&getElseRegion());
@@ -202,7 +160,7 @@ Value IfOp::finalizePromotion(
202160
resultTypes.push_back(slot.elemType);
203161

204162
Operation *newOp =
205-
replaceWithNewResults(rewriter, getOperation(), resultTypes);
163+
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
206164
return newOp->getResults().back();
207165
}
208166

@@ -234,17 +192,17 @@ Value IndexSwitchOp::finalizePromotion(
234192

235193
// Update the yield terminators to return the newly defined reaching
236194
// definition.
237-
updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
238-
reachingAtBlockEnd);
195+
memoryslot::updateTerminator(&getDefaultRegion().back(), reachingDef,
196+
reachingAtBlockEnd);
239197
for (Region &caseRegion : getCaseRegions())
240-
updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
241-
reachingAtBlockEnd);
198+
memoryslot::updateTerminator(&caseRegion.back(), reachingDef,
199+
reachingAtBlockEnd);
242200

243201
SmallVector<Type> resultTypes(getResultTypes());
244202
resultTypes.push_back(slot.elemType);
245203

246204
Operation *newOp =
247-
replaceWithNewResults(rewriter, getOperation(), resultTypes);
205+
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
248206
return newOp->getResults().back();
249207
}
250208

@@ -339,17 +297,17 @@ Value WhileOp::finalizePromotion(
339297

340298
// Update the yield terminators to return the newly defined reaching
341299
// definition.
342-
updateTerminator<ConditionOp>(&getBefore().back(),
343-
getBefore().getArguments().back(),
344-
reachingAtBlockEnd);
345-
updateTerminator<YieldOp>(
300+
memoryslot::updateTerminator(&getBefore().back(),
301+
getBefore().getArguments().back(),
302+
reachingAtBlockEnd);
303+
memoryslot::updateTerminator(
346304
&getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
347305

348306
SmallVector<Type> resultTypes(getResultTypes());
349307
resultTypes.push_back(slot.elemType);
350308

351309
IRRewriter rewriter(builder);
352310
Operation *newOp =
353-
replaceWithNewResults(rewriter, getOperation(), resultTypes);
311+
memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
354312
return newOp->getResults().back();
355313
}

mlir/lib/Interfaces/Utils/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
set(LLVM_OPTIONAL_SOURCES
2+
InferIntRangeCommon.cpp
3+
MemorySlotUtils.cpp
4+
)
5+
16
add_mlir_library(MLIRInferIntRangeCommon
27
InferIntRangeCommon.cpp
38

@@ -12,3 +17,13 @@ add_mlir_library(MLIRInferIntRangeCommon
1217
MLIRInferIntRangeInterface
1318
MLIRIR
1419
)
20+
21+
add_mlir_library(MLIRMemorySlotUtils
22+
MemorySlotUtils.cpp
23+
24+
ADDITIONAL_HEADER_DIRS
25+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
26+
27+
LINK_LIBS PUBLIC
28+
MLIRIR
29+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements common utilities for implementing MemorySlot interfaces,
10+
// in particular PromotableRegionOpInterface.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
15+
16+
using namespace mlir;
17+
18+
void mlir::memoryslot::updateTerminator(
19+
Block *block, Value defaultReachingDef,
20+
const DenseMap<Block *, Value> &reachingAtBlockEnd) {
21+
Value blockReachingDef = reachingAtBlockEnd.lookup(block);
22+
if (!blockReachingDef)
23+
blockReachingDef = defaultReachingDef;
24+
Operation *terminator = block->getTerminator();
25+
terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
26+
}
27+
28+
Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
29+
Operation *op,
30+
TypeRange resultTypes) {
31+
RewriterBase::InsertionGuard guard(rewriter);
32+
rewriter.setInsertionPoint(op);
33+
OperationState state(op->getLoc(), op->getName(), op->getOperands(),
34+
resultTypes, op->getAttrs());
35+
state.propertiesAttr = op->getPropertiesAsAttribute();
36+
unsigned numRegions = op->getNumRegions();
37+
for (unsigned i = 0; i < numRegions; ++i)
38+
state.addRegion();
39+
Operation *newOp = rewriter.create(state);
40+
rewriter.startOpModification(newOp);
41+
rewriter.startOpModification(op);
42+
for (unsigned i = 0; i < numRegions; ++i)
43+
newOp->getRegion(i).takeBody(op->getRegion(i));
44+
rewriter.finalizeOpModification(op);
45+
rewriter.finalizeOpModification(newOp);
46+
47+
rewriter.replaceAllOpUsesWith(
48+
op, newOp->getResults().take_front(op->getNumResults()));
49+
rewriter.eraseOp(op);
50+
return newOp;
51+
}

mlir/unittests/Interfaces/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@ add_mlir_unittest(MLIRInterfacesTests
22
ControlFlowInterfacesTest.cpp
33
DataLayoutInterfacesTest.cpp
44
InferIntRangeInterfaceTest.cpp
5+
MemorySlotUtilsTest.cpp
56
SideEffectInterfacesTest.cpp
67
InferTypeOpInterfaceTest.cpp
8+
9+
DEPENDS
10+
MLIRTestInterfaceIncGen
711
)
12+
target_include_directories(MLIRInterfacesTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
813

914
mlir_target_link_libraries(MLIRInterfacesTests
1015
PRIVATE
@@ -15,6 +20,8 @@ mlir_target_link_libraries(MLIRInterfacesTests
1520
MLIRFuncDialect
1621
MLIRInferIntRangeInterface
1722
MLIRInferTypeOpInterface
23+
MLIRMemorySlotUtils
1824
MLIRParser
1925
MLIRSideEffectInterfaces
2026
)
27+
target_link_libraries(MLIRInterfacesTests PRIVATE MLIRTestDialect)

0 commit comments

Comments
 (0)