Skip to content

Commit 1597298

Browse files
♻️ Apply CRTP to Unit class (#1379)
## Description This pull requests updates the unit implementations such that the base class `Unit` is an actual interface. For that purpose, I implemented the Curiously Recurring Template Pattern ([CRTP](https://en.cppreference.com/w/cpp/language/crtp.html)) to avoid virtual tables. This pattern is commonly used in the LLVM code base so it is also a good fit here, I think. I've also removed the `restore` flag from the Unit because the routing pass already knows when to restore and when not to. It didn't really serve any purpose in the unit itself. ## Rationale While I was writing the respective section for my thesis, I noticed that the `Unit` class is semantically an interface - but the implementation isn't. Hence, this PR. I think units could be useful elsewhere also, with potentially different strategies, so refactoring the class to an interface makes sense imo. ## Checklist: <!--- This checklist serves as a reminder of a couple of things that ensure your pull request will be merged swiftly. --> - [x] The pull request only contains commits that are focused and relevant to this change. - [x] I have added appropriate tests that cover the new/changed functionality. - [x] I have updated the documentation to reflect these changes. - [ ] I have added entries to the changelog for any noteworthy additions, changes, fixes, or removals. - [x] I have added migration instructions to the upgrade guide (if needed). - [x] The changes follow the project's style guidelines and introduce no new warnings. - [x] The changes are fully tested and pass the CI checks. - [x] I have reviewed my own code changes.
2 parents 68391db + 4750804 commit 1597298

8 files changed

Lines changed: 60 additions & 57 deletions

File tree

mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,13 @@ struct Layer {
4343
};
4444

4545
/// @brief A LayeredUnit traverses a program layer-by-layer.
46-
class LayeredUnit : public Unit {
46+
class LayeredUnit : public Unit<LayeredUnit> {
4747
public:
48-
using Layers = mlir::SmallVector<Layer, 0>;
49-
5048
[[nodiscard]] static LayeredUnit
5149
fromEntryPointFunction(mlir::func::FuncOp func, std::size_t nqubits);
5250

53-
LayeredUnit(Layout layout, mlir::Region* region, bool restore = false);
51+
LayeredUnit(Layout layout, mlir::Region* region);
5452

55-
[[nodiscard]] mlir::SmallVector<LayeredUnit, 3> next();
56-
[[nodiscard]] Layers::const_iterator begin() const { return layers_.begin(); }
57-
[[nodiscard]] Layers::const_iterator end() const { return layers_.end(); }
5853
[[nodiscard]] const Layer& operator[](std::size_t i) const {
5954
return layers_[i];
6055
}
@@ -65,6 +60,14 @@ class LayeredUnit : public Unit {
6560
#endif
6661

6762
private:
63+
friend class Unit<LayeredUnit>;
64+
using Layers = mlir::SmallVector<Layer, 0>;
65+
using const_iterator = Layers::const_iterator;
66+
67+
[[nodiscard]] mlir::SmallVector<LayeredUnit, 3> nextImpl();
68+
[[nodiscard]] const_iterator beginImpl() const { return layers_.begin(); }
69+
[[nodiscard]] const_iterator endImpl() const { return layers_.end(); }
70+
6871
Layers layers_;
6972
};
7073
} // namespace mqt::ir::opt

mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,24 @@
2222
namespace mqt::ir::opt {
2323

2424
/// @brief A SequentialUnit traverses a program sequentially.
25-
class SequentialUnit : public Unit {
25+
class SequentialUnit : public Unit<SequentialUnit> {
2626
public:
2727
[[nodiscard]] static SequentialUnit
2828
fromEntryPointFunction(mlir::func::FuncOp func, std::size_t nqubits);
2929

3030
SequentialUnit(Layout layout, mlir::Region* region,
31-
mlir::Region::OpIterator start, bool restore = false);
31+
mlir::Region::OpIterator start);
3232

33-
SequentialUnit(Layout layout, mlir::Region* region, bool restore = false)
34-
: SequentialUnit(std::move(layout), region, region->op_begin(), restore) {
35-
}
36-
37-
[[nodiscard]] mlir::SmallVector<SequentialUnit, 3> next();
38-
[[nodiscard]] mlir::Region::OpIterator begin() const { return start_; }
39-
[[nodiscard]] mlir::Region::OpIterator end() const { return end_; }
33+
SequentialUnit(Layout layout, mlir::Region* region)
34+
: SequentialUnit(std::move(layout), region, region->op_begin()) {}
4035

4136
private:
37+
friend class Unit<SequentialUnit>;
38+
39+
[[nodiscard]] mlir::SmallVector<SequentialUnit, 3> nextImpl();
40+
[[nodiscard]] mlir::Region::OpIterator beginImpl() const { return start_; }
41+
[[nodiscard]] mlir::Region::OpIterator endImpl() const { return end_; }
42+
4243
mlir::Region::OpIterator start_;
4344
mlir::Region::OpIterator end_;
4445
};

mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/Unit.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,36 @@
1717
namespace mqt::ir::opt {
1818

1919
/// @brief A Unit divides a quantum-classical program into routable sections.
20-
class Unit {
20+
template <class Derived> class Unit {
2121
public:
22-
Unit(Layout layout, mlir::Region* region, bool restore = false)
23-
: layout_(std::move(layout)), region_(region), restore_(restore) {}
22+
/// @brief Compute and return subsequent units.
23+
[[nodiscard]] mlir::SmallVector<Derived, 3> next() {
24+
return static_cast<Derived*>(this)->nextImpl();
25+
}
26+
27+
/// @returns an iterator pointing at the first element of the unit.
28+
[[nodiscard]] auto begin() const {
29+
return static_cast<const Derived*>(this)->beginImpl();
30+
}
31+
32+
/// @returns an iterator pointing at the past-the-end position.
33+
[[nodiscard]] auto end() const {
34+
return static_cast<const Derived*>(this)->endImpl();
35+
}
2436

2537
/// @returns the managed layout.
2638
[[nodiscard]] Layout& layout() { return layout_; }
2739

28-
/// @returns true iff. the unit has to be restored.
29-
[[nodiscard]] bool restore() const { return restore_; }
30-
3140
protected:
41+
Unit(Layout layout, mlir::Region* region)
42+
: layout_(std::move(layout)), region_(region) {}
43+
3244
/// @brief The layout this unit manages.
3345
Layout layout_;
3446
/// @brief The region this unit belongs to.
3547
mlir::Region* region_;
3648
/// @brief Pointer to the next dividing operation.
3749
mlir::Operation* divider_{};
38-
/// @brief Whether to uncompute the inserted SWAP sequence.
39-
bool restore_;
4050
};
4151

4252
} // namespace mqt::ir::opt

mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ LayeredUnit LayeredUnit::fromEntryPointFunction(mlir::func::FuncOp func,
141141
return {std::move(layout), &func.getBody()};
142142
}
143143

144-
LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region, bool restore)
145-
: Unit(std::move(layout), region, restore) {
144+
LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region)
145+
: Unit(std::move(layout), region) {
146146
SynchronizationMap sync;
147147

148148
mlir::SmallVector<Wire, 0> curr;
@@ -272,7 +272,7 @@ LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region, bool restore)
272272
};
273273
}
274274

275-
mlir::SmallVector<LayeredUnit, 3> LayeredUnit::next() {
275+
mlir::SmallVector<LayeredUnit, 3> LayeredUnit::nextImpl() {
276276
if (divider_ == nullptr) {
277277
return {};
278278
}
@@ -283,14 +283,14 @@ mlir::SmallVector<LayeredUnit, 3> LayeredUnit::next() {
283283
Layout forLayout(layout_); // Copy layout.
284284
forLayout.remapToLoopBody(op);
285285
layout_.remapToLoopResults(op);
286-
units.emplace_back(std::move(layout_), region_, restore_);
287-
units.emplace_back(std::move(forLayout), &op.getRegion(), true);
286+
units.emplace_back(std::move(layout_), region_);
287+
units.emplace_back(std::move(forLayout), &op.getRegion());
288288
})
289289
.Case<mlir::scf::IfOp>([&](mlir::scf::IfOp op) {
290-
units.emplace_back(layout_, &op.getThenRegion(), true);
291-
units.emplace_back(layout_, &op.getElseRegion(), true);
290+
units.emplace_back(layout_, &op.getThenRegion());
291+
units.emplace_back(layout_, &op.getElseRegion());
292292
layout_.remapIfResults(op);
293-
units.emplace_back(layout_, region_, restore_);
293+
units.emplace_back(layout_, region_);
294294
})
295295
.Default([](auto) { llvm_unreachable("invalid 'next' operation"); });
296296

mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ SequentialUnit::fromEntryPointFunction(mlir::func::FuncOp func,
4040
}
4141

4242
SequentialUnit::SequentialUnit(Layout layout, mlir::Region* region,
43-
mlir::Region::OpIterator start, bool restore)
44-
: Unit(std::move(layout), region, restore), start_(start),
45-
end_(region->op_end()) {
43+
mlir::Region::OpIterator start)
44+
: Unit(std::move(layout), region), start_(start), end_(region->op_end()) {
4645
mlir::Region::OpIterator it = start_;
4746
for (; it != end_; ++it) {
4847
mlir::Operation* op = &*it;
@@ -54,7 +53,7 @@ SequentialUnit::SequentialUnit(Layout layout, mlir::Region* region,
5453
end_ = it;
5554
}
5655

57-
mlir::SmallVector<SequentialUnit, 3> SequentialUnit::next() {
56+
mlir::SmallVector<SequentialUnit, 3> SequentialUnit::nextImpl() {
5857
if (divider_ == nullptr) {
5958
return {};
6059
}
@@ -65,16 +64,14 @@ mlir::SmallVector<SequentialUnit, 3> SequentialUnit::next() {
6564
Layout forLayout(layout_); // Copy layout.
6665
forLayout.remapToLoopBody(op);
6766
layout_.remapToLoopResults(op);
68-
units.emplace_back(std::move(layout_), region_, std::next(end_),
69-
restore_);
70-
units.emplace_back(std::move(forLayout), &op.getRegion(), true);
67+
units.emplace_back(std::move(layout_), region_, std::next(end_));
68+
units.emplace_back(std::move(forLayout), &op.getRegion());
7169
})
7270
.Case<mlir::scf::IfOp>([&](mlir::scf::IfOp op) {
73-
units.emplace_back(layout_, &op.getThenRegion(), true);
74-
units.emplace_back(layout_, &op.getElseRegion(), true);
71+
units.emplace_back(layout_, &op.getThenRegion());
72+
units.emplace_back(layout_, &op.getElseRegion());
7573
layout_.remapIfResults(op);
76-
units.emplace_back(std::move(layout_), region_, std::next(end_),
77-
restore_);
74+
units.emplace_back(std::move(layout_), region_, std::next(end_));
7875
})
7976
.Default([](auto) { llvm_unreachable("invalid 'next' operation"); });
8077

mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/AStarRoutingPass.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,9 @@ struct AStarRoutingPassSC final
166166
.Case<ResetOp>([&](ResetOp op) { unit.layout().remap(op); })
167167
.Case<MeasureOp>([&](MeasureOp op) { unit.layout().remap(op); })
168168
.Case<scf::YieldOp>([&](scf::YieldOp op) {
169-
if (unit.restore()) {
170-
rewriter.setInsertionPoint(op);
171-
insertSWAPs(op.getLoc(), llvm::reverse(history),
172-
unit.layout(), rewriter);
173-
}
169+
rewriter.setInsertionPoint(op);
170+
insertSWAPs(op.getLoc(), llvm::reverse(history),
171+
unit.layout(), rewriter);
174172
})
175173
.Default([](auto) {
176174
llvm_unreachable("unhandled 'curr' operation");

mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/NaiveRoutingPass.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,9 @@ struct NaiveRoutingPassSC final
141141
.Case<ResetOp>([&](ResetOp op) { unit.layout().remap(op); })
142142
.Case<MeasureOp>([&](MeasureOp op) { unit.layout().remap(op); })
143143
.Case<scf::YieldOp>([&](scf::YieldOp op) {
144-
if (unit.restore()) {
145-
rewriter.setInsertionPointAfter(op->getPrevNode());
146-
insertSWAPs(op.getLoc(), llvm::reverse(history),
147-
unit.layout(), rewriter);
148-
}
144+
rewriter.setInsertionPoint(op);
145+
insertSWAPs(op.getLoc(), llvm::reverse(history), unit.layout(),
146+
rewriter);
149147
});
150148
}
151149

mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/RoutingVerificationPass.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,6 @@ struct RoutingVerificationPassSC final
125125
return success();
126126
})
127127
.Case<scf::YieldOp>([&](scf::YieldOp op) -> LogicalResult {
128-
if (!unit.restore()) {
129-
return success();
130-
}
131-
132128
/// Verify that the layouts match at the end.
133129
const auto mappingBefore = unmodified.getCurrentLayout();
134130
const auto mappingNow = unit.layout().getCurrentLayout();

0 commit comments

Comments
 (0)