Skip to content

Commit f0faa17

Browse files
committed
add Known Gate Specializations
1 parent aa22f38 commit f0faa17

2 files changed

Lines changed: 497 additions & 0 deletions

File tree

mlir/lib/Dialect/QC/IR/Modifiers/PowOp.cpp

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
#include <llvm/Support/Casting.h>
1515
#include <mlir/Dialect/Arith/IR/Arith.h>
1616
#include <mlir/IR/Builders.h>
17+
#include <mlir/IR/IRMapping.h>
1718
#include <mlir/IR/MLIRContext.h>
1819
#include <mlir/IR/OperationSupport.h>
1920
#include <mlir/IR/PatternMatch.h>
2021
#include <mlir/Support/LLVM.h>
2122
#include <mlir/Support/LogicalResult.h>
2223

24+
#include <cmath>
25+
#include <numbers>
26+
2327
using namespace mlir;
2428
using namespace mlir::qc;
2529

@@ -104,6 +108,248 @@ struct MoveCtrlOutside final : OpRewritePattern<PowOp> {
104108
}
105109
};
106110

111+
/// Check if a floating-point value is an integer
112+
bool isIntegerExponent(double r) {
113+
return r == std::floor(r) && std::isfinite(r);
114+
}
115+
116+
/// Materialize r * constant as an arith.constant
117+
Value mulConst(double r, double c, PowOp op, PatternRewriter& rewriter) {
118+
return arith::ConstantOp::create(rewriter, op.getLoc(),
119+
rewriter.getF64FloatAttr(r * c));
120+
}
121+
122+
/// Materialize exponent * param as arith ops
123+
Value scaleByExponent(Value param, PowOp op, PatternRewriter& rewriter) {
124+
auto loc = op.getLoc();
125+
auto exponent =
126+
arith::ConstantOp::create(rewriter, loc, op.getExponentAttr());
127+
return arith::MulFOp::create(rewriter, loc, exponent, param);
128+
}
129+
130+
template <typename GateOp>
131+
LogicalResult replaceOneTargetOneParam(auto theta, PowOp op,
132+
PatternRewriter& rewriter) {
133+
rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), theta);
134+
return success();
135+
}
136+
137+
template <typename GateOp>
138+
LogicalResult replaceTwoTargetsOneParam(auto theta, PowOp op,
139+
PatternRewriter& rewriter) {
140+
rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), op.getTarget(1),
141+
theta);
142+
return success();
143+
}
144+
145+
template <typename GateOp>
146+
LogicalResult replaceOneTargetTwoParams(auto theta, auto phi, PowOp op,
147+
PatternRewriter& rewriter) {
148+
rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), theta, phi);
149+
return success();
150+
}
151+
152+
template <typename GateOp>
153+
LogicalResult replaceTwoTargetsTwoParams(auto theta, auto beta, PowOp op,
154+
PatternRewriter& rewriter) {
155+
rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), op.getTarget(1),
156+
theta, beta);
157+
return success();
158+
}
159+
160+
/**
161+
* @brief Fold pow(r) around gates into simpler operations.
162+
*
163+
* Rotation gates: multiply angle by exponent, e.g., pow(r) { rx(θ) } → rx(r*θ)
164+
* Phase/diagonal gates: convert to P gate, e.g., pow(r) { s } → p(r*π/2)
165+
* Hermitian gates (integer exponent): even → erase, odd → gate
166+
* Identity/barrier: pass through unchanged
167+
*/
168+
struct FoldPowIntoGate final : OpRewritePattern<PowOp> {
169+
using OpRewritePattern::OpRewritePattern;
170+
171+
LogicalResult matchAndRewrite(PowOp op,
172+
PatternRewriter& rewriter) const override {
173+
auto* innerOp = op.getBodyUnitary().getOperation();
174+
const double r = op.getExponentValue();
175+
auto loc = op.getLoc();
176+
177+
// Move supporting ops (constants, arithmetic) out of the body so their
178+
// Values are accessible from outside and survive PowOp erasure.
179+
for (auto& bodyOp : llvm::make_early_inc_range(*op.getBody())) {
180+
if (&bodyOp != innerOp && !llvm::isa<YieldOp>(&bodyOp)) {
181+
rewriter.moveOpBefore(&bodyOp, op);
182+
}
183+
}
184+
185+
return llvm::TypeSwitch<Operation*, LogicalResult>(innerOp)
186+
// --- Rotation gates: multiply angle by exponent ---
187+
// pow(r) { gphase(θ) } → gphase(r*θ)
188+
.Case<GPhaseOp>([&](auto gate) {
189+
auto newParam = scaleByExponent(gate.getTheta(), op, rewriter);
190+
rewriter.replaceOpWithNewOp<GPhaseOp>(op, newParam);
191+
return success();
192+
})
193+
// pow(r) { rx/ry/rz/p(θ) } → rx/ry/rz/p(r*θ)
194+
.Case<RXOp, RYOp, RZOp, POp>([&](auto gate) {
195+
auto newParam = scaleByExponent(gate.getTheta(), op, rewriter);
196+
return replaceOneTargetOneParam<decltype(gate)>(newParam, op,
197+
rewriter);
198+
})
199+
// pow(r) { rxx/ryy/rzx/rzz(θ) } → rxx/ryy/rzx/rzz(r*θ)
200+
.Case<RXXOp, RYYOp, RZXOp, RZZOp>([&](auto gate) {
201+
auto newParam = scaleByExponent(gate.getTheta(), op, rewriter);
202+
return replaceTwoTargetsOneParam<decltype(gate)>(newParam, op,
203+
rewriter);
204+
})
205+
// pow(r) { r(θ, φ) } → r(r*θ, φ)
206+
.Case<ROp>([&](auto gate) {
207+
auto mul = scaleByExponent(gate.getTheta(), op, rewriter);
208+
return replaceOneTargetTwoParams<ROp>(mul, gate.getPhi(), op,
209+
rewriter);
210+
})
211+
// pow(r) { xx±yy(θ, β) } → xx±yy(r*θ, β)
212+
.Case<XXPlusYYOp, XXMinusYYOp>([&](auto gate) {
213+
auto mul = scaleByExponent(gate.getTheta(), op, rewriter);
214+
return replaceTwoTargetsTwoParams<decltype(gate)>(mul, gate.getBeta(),
215+
op, rewriter);
216+
})
217+
// --- Pauli gates: decompose to rotation + global phase ---
218+
// pow(r) { z } → p(r*π)
219+
.Case<ZOp>([&](auto) {
220+
rewriter.replaceOpWithNewOp<POp>(
221+
op, op.getTarget(0), mulConst(r, std::numbers::pi, op, rewriter));
222+
return success();
223+
})
224+
// pow(r) { x } → gphase(-r*π/2); rx(r*π)
225+
.Case<XOp>([&](auto) {
226+
if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) {
227+
return failure();
228+
}
229+
GPhaseOp::create(rewriter, loc,
230+
mulConst(r, -std::numbers::pi / 2.0, op, rewriter));
231+
rewriter.replaceOpWithNewOp<RXOp>(
232+
op, op.getTarget(0), mulConst(r, std::numbers::pi, op, rewriter));
233+
return success();
234+
})
235+
// pow(r) { y } → gphase(-r*π/2); ry(r*π)
236+
.Case<YOp>([&](auto) {
237+
if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) {
238+
return failure();
239+
}
240+
GPhaseOp::create(rewriter, loc,
241+
mulConst(r, -std::numbers::pi / 2.0, op, rewriter));
242+
rewriter.replaceOpWithNewOp<RYOp>(
243+
op, op.getTarget(0), mulConst(r, std::numbers::pi, op, rewriter));
244+
return success();
245+
})
246+
// --- Phase/diagonal gates: convert to P gate ---
247+
// pow(r) { s } → p(r*π/2)
248+
.Case<SOp>([&](auto) {
249+
rewriter.replaceOpWithNewOp<POp>(
250+
op, op.getTarget(0),
251+
mulConst(r, std::numbers::pi / 2.0, op, rewriter));
252+
return success();
253+
})
254+
// pow(r) { sdg } → p(-r*π/2)
255+
.Case<SdgOp>([&](auto) {
256+
rewriter.replaceOpWithNewOp<POp>(
257+
op, op.getTarget(0),
258+
mulConst(r, -std::numbers::pi / 2.0, op, rewriter));
259+
return success();
260+
})
261+
// pow(r) { t } → p(r*π/4)
262+
.Case<TOp>([&](auto) {
263+
rewriter.replaceOpWithNewOp<POp>(
264+
op, op.getTarget(0),
265+
mulConst(r, std::numbers::pi / 4.0, op, rewriter));
266+
return success();
267+
})
268+
// pow(r) { tdg } → p(-r*π/4)
269+
.Case<TdgOp>([&](auto) {
270+
rewriter.replaceOpWithNewOp<POp>(
271+
op, op.getTarget(0),
272+
mulConst(r, -std::numbers::pi / 4.0, op, rewriter));
273+
return success();
274+
})
275+
// --- SX/SXdg gates: decompose to rotation + global phase ---
276+
// pow(r) { sx } → gphase(-r*π/4); rx(r*π/2)
277+
.Case<SXOp>([&](auto) {
278+
if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) {
279+
return failure();
280+
}
281+
GPhaseOp::create(rewriter, loc,
282+
mulConst(r, -std::numbers::pi / 4.0, op, rewriter));
283+
rewriter.replaceOpWithNewOp<RXOp>(
284+
op, op.getTarget(0),
285+
mulConst(r, std::numbers::pi / 2.0, op, rewriter));
286+
return success();
287+
})
288+
// pow(r) { sxdg } → gphase(r*π/4); rx(-r*π/2)
289+
.Case<SXdgOp>([&](auto) {
290+
if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) {
291+
return failure();
292+
}
293+
GPhaseOp::create(rewriter, loc,
294+
mulConst(r, std::numbers::pi / 4.0, op, rewriter));
295+
rewriter.replaceOpWithNewOp<RXOp>(
296+
op, op.getTarget(0),
297+
mulConst(r, -std::numbers::pi / 2.0, op, rewriter));
298+
return success();
299+
})
300+
// --- Hermitian gates (integer exponent): even → erase, odd → gate ---
301+
// pow(n) { h/ecr } → erase (n even) | h/ecr (n odd)
302+
.Case<HOp, ECROp>([&](auto gate) {
303+
if (!isIntegerExponent(r)) {
304+
return failure();
305+
}
306+
const auto n = static_cast<int64_t>(r);
307+
if (n % 2 == 0) {
308+
rewriter.eraseOp(op);
309+
} else {
310+
rewriter.moveOpBefore(gate, op);
311+
rewriter.eraseOp(op);
312+
}
313+
return success();
314+
})
315+
// pow(n) { swap } → erase (n even) | swap (n odd)
316+
.Case<SWAPOp>([&](auto gate) {
317+
if (!isIntegerExponent(r)) {
318+
return failure();
319+
}
320+
const auto n = static_cast<int64_t>(r);
321+
if (n % 2 == 0) {
322+
rewriter.eraseOp(op);
323+
} else {
324+
rewriter.moveOpBefore(gate, op);
325+
rewriter.eraseOp(op);
326+
}
327+
return success();
328+
})
329+
// --- iSWAP: decompose to parametric gate ---
330+
// pow(r) { iswap } → xx_plus_yy(-r*π, 0)
331+
.Case<iSWAPOp>([&](auto) {
332+
rewriter.replaceOpWithNewOp<XXPlusYYOp>(
333+
op, op.getTarget(0), op.getTarget(1),
334+
mulConst(r, -std::numbers::pi, op, rewriter),
335+
mulConst(r, 0.0, op, rewriter));
336+
return success();
337+
})
338+
// --- Identity and barrier: pass through unchanged ---
339+
// pow(r) { id } → id
340+
.Case<IdOp>([&](auto) {
341+
rewriter.replaceOpWithNewOp<IdOp>(op, op.getTarget(0));
342+
return success();
343+
})
344+
// pow(r) { barrier } → barrier
345+
.Case<BarrierOp>([&](auto gate) {
346+
rewriter.replaceOpWithNewOp<BarrierOp>(op, gate.getTargets());
347+
return success();
348+
})
349+
.Default([&](auto) { return failure(); });
350+
}
351+
};
352+
107353
} // namespace
108354

109355
UnitaryOpInterface PowOp::getBodyUnitary() {
@@ -149,4 +395,6 @@ void PowOp::getCanonicalizationPatterns(RewritePatternSet& results,
149395
MLIRContext* context) {
150396
results.add<InlinePow1, ErasePow0, NegPowToInvPow, MergeNestedPow,
151397
MoveCtrlOutside>(context);
398+
// Prefer Known Gate optimizations over everything else
399+
results.add<FoldPowIntoGate>(context, /*benefit=*/2);
152400
}

0 commit comments

Comments
 (0)