Skip to content

Commit edb0ed5

Browse files
committed
canonical ordering for pow and inv
1 parent bee4465 commit edb0ed5

2 files changed

Lines changed: 64 additions & 4 deletions

File tree

mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,28 @@ struct MoveCtrlOutside final : OpRewritePattern<InvOp> {
5555
}
5656
};
5757

58+
/**
59+
* @brief Reorder inv around pow, i.e., `inv(pow(p, g)) => pow(p, inv(g))`.
60+
*/
61+
struct MovePowOutside final : OpRewritePattern<InvOp> {
62+
using OpRewritePattern::OpRewritePattern;
63+
LogicalResult matchAndRewrite(InvOp invOp,
64+
PatternRewriter& rewriter) const override {
65+
auto innerPow =
66+
llvm::dyn_cast<PowOp>(invOp.getBodyUnitary().getOperation());
67+
if (!innerPow) {
68+
return failure();
69+
}
70+
const double exponent = innerPow.getExponentValue();
71+
rewriter.replaceOpWithNewOp<PowOp>(invOp, exponent, [&] {
72+
InvOp::create(rewriter, invOp.getLoc(), [&] {
73+
rewriter.clone(*innerPow.getBodyUnitary().getOperation());
74+
});
75+
});
76+
return success();
77+
}
78+
};
79+
5880
/**
5981
* @brief Remove inverse modifiers around self-adjoint gates.
6082
*
@@ -299,6 +321,6 @@ LogicalResult InvOp::verify() {
299321

300322
void InvOp::getCanonicalizationPatterns(RewritePatternSet& results,
301323
MLIRContext* context) {
302-
results.add<CancelNestedInv, MoveCtrlOutside, InlineSelfAdjoint,
303-
ReplaceWithKnownGates>(context);
324+
results.add<CancelNestedInv, MoveCtrlOutside, MovePowOutside,
325+
InlineSelfAdjoint, ReplaceWithKnownGates>(context);
304326
}

mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,44 @@ struct MoveCtrlOutside final : OpRewritePattern<InvOp> {
8181
}
8282
};
8383

84+
/**
85+
* @brief Reorder inv around pow, i.e., `inv(pow(p, g)) => pow(p, inv(g))`.
86+
*/
87+
struct MovePowOutside final : OpRewritePattern<InvOp> {
88+
using OpRewritePattern::OpRewritePattern;
89+
90+
LogicalResult matchAndRewrite(InvOp invOp,
91+
PatternRewriter& rewriter) const override {
92+
auto innerPow =
93+
llvm::dyn_cast<PowOp>(invOp.getBodyUnitary().getOperation());
94+
if (!innerPow) {
95+
return failure();
96+
}
97+
98+
const double exponent = innerPow.getExponentValue();
99+
100+
rewriter.replaceOpWithNewOp<PowOp>(
101+
invOp, invOp.getQubitsIn(), exponent,
102+
[&](ValueRange powArgs) -> llvm::SmallVector<Value> {
103+
return InvOp::create(
104+
rewriter, invOp.getLoc(), powArgs,
105+
[&](ValueRange invArgs) -> llvm::SmallVector<Value> {
106+
IRMapping mapping;
107+
auto* innerBody = innerPow.getBody();
108+
for (size_t i = 0; i < innerPow.getNumTargets(); ++i) {
109+
mapping.map(innerBody->getArgument(i), invArgs[i]);
110+
}
111+
return rewriter
112+
.clone(*innerPow.getBodyUnitary().getOperation(),
113+
mapping)
114+
->getResults();
115+
})
116+
.getResults();
117+
});
118+
return success();
119+
}
120+
};
121+
84122
/**
85123
* @brief Remove inverse modifiers around self-adjoint gates.
86124
*
@@ -401,8 +439,8 @@ LogicalResult InvOp::verify() {
401439

402440
void InvOp::getCanonicalizationPatterns(RewritePatternSet& results,
403441
MLIRContext* context) {
404-
results.add<MoveCtrlOutside, InlineSelfAdjoint, ReplaceWithKnownGates,
405-
CancelNestedInv>(context);
442+
results.add<MoveCtrlOutside, MovePowOutside, InlineSelfAdjoint,
443+
ReplaceWithKnownGates, CancelNestedInv>(context);
406444
}
407445

408446
std::optional<Eigen::MatrixXcd> InvOp::getUnitaryMatrix() {

0 commit comments

Comments
 (0)