|
14 | 14 | #include <llvm/Support/Casting.h> |
15 | 15 | #include <mlir/Dialect/Arith/IR/Arith.h> |
16 | 16 | #include <mlir/IR/Builders.h> |
| 17 | +#include <mlir/IR/IRMapping.h> |
17 | 18 | #include <mlir/IR/MLIRContext.h> |
18 | 19 | #include <mlir/IR/OperationSupport.h> |
19 | 20 | #include <mlir/IR/PatternMatch.h> |
20 | 21 | #include <mlir/Support/LLVM.h> |
21 | 22 | #include <mlir/Support/LogicalResult.h> |
22 | 23 |
|
| 24 | +#include <cmath> |
| 25 | +#include <numbers> |
| 26 | + |
23 | 27 | using namespace mlir; |
24 | 28 | using namespace mlir::qc; |
25 | 29 |
|
@@ -104,6 +108,248 @@ struct MoveCtrlOutside final : OpRewritePattern<PowOp> { |
104 | 108 | } |
105 | 109 | }; |
106 | 110 |
|
| 111 | +/// Check if a floating-point value is an integer |
| 112 | +bool isIntegerExponent(double r) { |
| 113 | + return r == std::floor(r) && std::isfinite(r); |
| 114 | +} |
| 115 | + |
| 116 | +/// Materialize r * constant as an arith.constant |
| 117 | +Value mulConst(double r, double c, PowOp op, PatternRewriter& rewriter) { |
| 118 | + return arith::ConstantOp::create(rewriter, op.getLoc(), |
| 119 | + rewriter.getF64FloatAttr(r * c)); |
| 120 | +} |
| 121 | + |
| 122 | +/// Materialize exponent * param as arith ops |
| 123 | +Value scaleByExponent(Value param, PowOp op, PatternRewriter& rewriter) { |
| 124 | + auto loc = op.getLoc(); |
| 125 | + auto exponent = |
| 126 | + arith::ConstantOp::create(rewriter, loc, op.getExponentAttr()); |
| 127 | + return arith::MulFOp::create(rewriter, loc, exponent, param); |
| 128 | +} |
| 129 | + |
| 130 | +template <typename GateOp> |
| 131 | +LogicalResult replaceOneTargetOneParam(auto theta, PowOp op, |
| 132 | + PatternRewriter& rewriter) { |
| 133 | + rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), theta); |
| 134 | + return success(); |
| 135 | +} |
| 136 | + |
| 137 | +template <typename GateOp> |
| 138 | +LogicalResult replaceTwoTargetsOneParam(auto theta, PowOp op, |
| 139 | + PatternRewriter& rewriter) { |
| 140 | + rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), op.getTarget(1), |
| 141 | + theta); |
| 142 | + return success(); |
| 143 | +} |
| 144 | + |
| 145 | +template <typename GateOp> |
| 146 | +LogicalResult replaceOneTargetTwoParams(auto theta, auto phi, PowOp op, |
| 147 | + PatternRewriter& rewriter) { |
| 148 | + rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), theta, phi); |
| 149 | + return success(); |
| 150 | +} |
| 151 | + |
| 152 | +template <typename GateOp> |
| 153 | +LogicalResult replaceTwoTargetsTwoParams(auto theta, auto beta, PowOp op, |
| 154 | + PatternRewriter& rewriter) { |
| 155 | + rewriter.replaceOpWithNewOp<GateOp>(op, op.getTarget(0), op.getTarget(1), |
| 156 | + theta, beta); |
| 157 | + return success(); |
| 158 | +} |
| 159 | + |
| 160 | +/** |
| 161 | + * @brief Fold pow(r) around gates into simpler operations. |
| 162 | + * |
| 163 | + * Rotation gates: multiply angle by exponent, e.g., pow(r) { rx(θ) } → rx(r*θ) |
| 164 | + * Phase/diagonal gates: convert to P gate, e.g., pow(r) { s } → p(r*π/2) |
| 165 | + * Hermitian gates (integer exponent): even → erase, odd → gate |
| 166 | + * Identity/barrier: pass through unchanged |
| 167 | + */ |
| 168 | +struct FoldPowIntoGate final : OpRewritePattern<PowOp> { |
| 169 | + using OpRewritePattern::OpRewritePattern; |
| 170 | + |
| 171 | + LogicalResult matchAndRewrite(PowOp op, |
| 172 | + PatternRewriter& rewriter) const override { |
| 173 | + auto* innerOp = op.getBodyUnitary().getOperation(); |
| 174 | + const double r = op.getExponentValue(); |
| 175 | + auto loc = op.getLoc(); |
| 176 | + |
| 177 | + // Move supporting ops (constants, arithmetic) out of the body so their |
| 178 | + // Values are accessible from outside and survive PowOp erasure. |
| 179 | + for (auto& bodyOp : llvm::make_early_inc_range(*op.getBody())) { |
| 180 | + if (&bodyOp != innerOp && !llvm::isa<YieldOp>(&bodyOp)) { |
| 181 | + rewriter.moveOpBefore(&bodyOp, op); |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + return llvm::TypeSwitch<Operation*, LogicalResult>(innerOp) |
| 186 | + // --- Rotation gates: multiply angle by exponent --- |
| 187 | + // pow(r) { gphase(θ) } → gphase(r*θ) |
| 188 | + .Case<GPhaseOp>([&](auto gate) { |
| 189 | + auto newParam = scaleByExponent(gate.getTheta(), op, rewriter); |
| 190 | + rewriter.replaceOpWithNewOp<GPhaseOp>(op, newParam); |
| 191 | + return success(); |
| 192 | + }) |
| 193 | + // pow(r) { rx/ry/rz/p(θ) } → rx/ry/rz/p(r*θ) |
| 194 | + .Case<RXOp, RYOp, RZOp, POp>([&](auto gate) { |
| 195 | + auto newParam = scaleByExponent(gate.getTheta(), op, rewriter); |
| 196 | + return replaceOneTargetOneParam<decltype(gate)>(newParam, op, |
| 197 | + rewriter); |
| 198 | + }) |
| 199 | + // pow(r) { rxx/ryy/rzx/rzz(θ) } → rxx/ryy/rzx/rzz(r*θ) |
| 200 | + .Case<RXXOp, RYYOp, RZXOp, RZZOp>([&](auto gate) { |
| 201 | + auto newParam = scaleByExponent(gate.getTheta(), op, rewriter); |
| 202 | + return replaceTwoTargetsOneParam<decltype(gate)>(newParam, op, |
| 203 | + rewriter); |
| 204 | + }) |
| 205 | + // pow(r) { r(θ, φ) } → r(r*θ, φ) |
| 206 | + .Case<ROp>([&](auto gate) { |
| 207 | + auto mul = scaleByExponent(gate.getTheta(), op, rewriter); |
| 208 | + return replaceOneTargetTwoParams<ROp>(mul, gate.getPhi(), op, |
| 209 | + rewriter); |
| 210 | + }) |
| 211 | + // pow(r) { xx±yy(θ, β) } → xx±yy(r*θ, β) |
| 212 | + .Case<XXPlusYYOp, XXMinusYYOp>([&](auto gate) { |
| 213 | + auto mul = scaleByExponent(gate.getTheta(), op, rewriter); |
| 214 | + return replaceTwoTargetsTwoParams<decltype(gate)>(mul, gate.getBeta(), |
| 215 | + op, rewriter); |
| 216 | + }) |
| 217 | + // --- Pauli gates: decompose to rotation + global phase --- |
| 218 | + // pow(r) { z } → p(r*π) |
| 219 | + .Case<ZOp>([&](auto) { |
| 220 | + rewriter.replaceOpWithNewOp<POp>( |
| 221 | + op, op.getTarget(0), mulConst(r, std::numbers::pi, op, rewriter)); |
| 222 | + return success(); |
| 223 | + }) |
| 224 | + // pow(r) { x } → gphase(-r*π/2); rx(r*π) |
| 225 | + .Case<XOp>([&](auto) { |
| 226 | + if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) { |
| 227 | + return failure(); |
| 228 | + } |
| 229 | + GPhaseOp::create(rewriter, loc, |
| 230 | + mulConst(r, -std::numbers::pi / 2.0, op, rewriter)); |
| 231 | + rewriter.replaceOpWithNewOp<RXOp>( |
| 232 | + op, op.getTarget(0), mulConst(r, std::numbers::pi, op, rewriter)); |
| 233 | + return success(); |
| 234 | + }) |
| 235 | + // pow(r) { y } → gphase(-r*π/2); ry(r*π) |
| 236 | + .Case<YOp>([&](auto) { |
| 237 | + if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) { |
| 238 | + return failure(); |
| 239 | + } |
| 240 | + GPhaseOp::create(rewriter, loc, |
| 241 | + mulConst(r, -std::numbers::pi / 2.0, op, rewriter)); |
| 242 | + rewriter.replaceOpWithNewOp<RYOp>( |
| 243 | + op, op.getTarget(0), mulConst(r, std::numbers::pi, op, rewriter)); |
| 244 | + return success(); |
| 245 | + }) |
| 246 | + // --- Phase/diagonal gates: convert to P gate --- |
| 247 | + // pow(r) { s } → p(r*π/2) |
| 248 | + .Case<SOp>([&](auto) { |
| 249 | + rewriter.replaceOpWithNewOp<POp>( |
| 250 | + op, op.getTarget(0), |
| 251 | + mulConst(r, std::numbers::pi / 2.0, op, rewriter)); |
| 252 | + return success(); |
| 253 | + }) |
| 254 | + // pow(r) { sdg } → p(-r*π/2) |
| 255 | + .Case<SdgOp>([&](auto) { |
| 256 | + rewriter.replaceOpWithNewOp<POp>( |
| 257 | + op, op.getTarget(0), |
| 258 | + mulConst(r, -std::numbers::pi / 2.0, op, rewriter)); |
| 259 | + return success(); |
| 260 | + }) |
| 261 | + // pow(r) { t } → p(r*π/4) |
| 262 | + .Case<TOp>([&](auto) { |
| 263 | + rewriter.replaceOpWithNewOp<POp>( |
| 264 | + op, op.getTarget(0), |
| 265 | + mulConst(r, std::numbers::pi / 4.0, op, rewriter)); |
| 266 | + return success(); |
| 267 | + }) |
| 268 | + // pow(r) { tdg } → p(-r*π/4) |
| 269 | + .Case<TdgOp>([&](auto) { |
| 270 | + rewriter.replaceOpWithNewOp<POp>( |
| 271 | + op, op.getTarget(0), |
| 272 | + mulConst(r, -std::numbers::pi / 4.0, op, rewriter)); |
| 273 | + return success(); |
| 274 | + }) |
| 275 | + // --- SX/SXdg gates: decompose to rotation + global phase --- |
| 276 | + // pow(r) { sx } → gphase(-r*π/4); rx(r*π/2) |
| 277 | + .Case<SXOp>([&](auto) { |
| 278 | + if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) { |
| 279 | + return failure(); |
| 280 | + } |
| 281 | + GPhaseOp::create(rewriter, loc, |
| 282 | + mulConst(r, -std::numbers::pi / 4.0, op, rewriter)); |
| 283 | + rewriter.replaceOpWithNewOp<RXOp>( |
| 284 | + op, op.getTarget(0), |
| 285 | + mulConst(r, std::numbers::pi / 2.0, op, rewriter)); |
| 286 | + return success(); |
| 287 | + }) |
| 288 | + // pow(r) { sxdg } → gphase(r*π/4); rx(-r*π/2) |
| 289 | + .Case<SXdgOp>([&](auto) { |
| 290 | + if (llvm::isa<CtrlOp, InvOp, PowOp>(op->getParentOp())) { |
| 291 | + return failure(); |
| 292 | + } |
| 293 | + GPhaseOp::create(rewriter, loc, |
| 294 | + mulConst(r, std::numbers::pi / 4.0, op, rewriter)); |
| 295 | + rewriter.replaceOpWithNewOp<RXOp>( |
| 296 | + op, op.getTarget(0), |
| 297 | + mulConst(r, -std::numbers::pi / 2.0, op, rewriter)); |
| 298 | + return success(); |
| 299 | + }) |
| 300 | + // --- Hermitian gates (integer exponent): even → erase, odd → gate --- |
| 301 | + // pow(n) { h/ecr } → erase (n even) | h/ecr (n odd) |
| 302 | + .Case<HOp, ECROp>([&](auto gate) { |
| 303 | + if (!isIntegerExponent(r)) { |
| 304 | + return failure(); |
| 305 | + } |
| 306 | + const auto n = static_cast<int64_t>(r); |
| 307 | + if (n % 2 == 0) { |
| 308 | + rewriter.eraseOp(op); |
| 309 | + } else { |
| 310 | + rewriter.moveOpBefore(gate, op); |
| 311 | + rewriter.eraseOp(op); |
| 312 | + } |
| 313 | + return success(); |
| 314 | + }) |
| 315 | + // pow(n) { swap } → erase (n even) | swap (n odd) |
| 316 | + .Case<SWAPOp>([&](auto gate) { |
| 317 | + if (!isIntegerExponent(r)) { |
| 318 | + return failure(); |
| 319 | + } |
| 320 | + const auto n = static_cast<int64_t>(r); |
| 321 | + if (n % 2 == 0) { |
| 322 | + rewriter.eraseOp(op); |
| 323 | + } else { |
| 324 | + rewriter.moveOpBefore(gate, op); |
| 325 | + rewriter.eraseOp(op); |
| 326 | + } |
| 327 | + return success(); |
| 328 | + }) |
| 329 | + // --- iSWAP: decompose to parametric gate --- |
| 330 | + // pow(r) { iswap } → xx_plus_yy(-r*π, 0) |
| 331 | + .Case<iSWAPOp>([&](auto) { |
| 332 | + rewriter.replaceOpWithNewOp<XXPlusYYOp>( |
| 333 | + op, op.getTarget(0), op.getTarget(1), |
| 334 | + mulConst(r, -std::numbers::pi, op, rewriter), |
| 335 | + mulConst(r, 0.0, op, rewriter)); |
| 336 | + return success(); |
| 337 | + }) |
| 338 | + // --- Identity and barrier: pass through unchanged --- |
| 339 | + // pow(r) { id } → id |
| 340 | + .Case<IdOp>([&](auto) { |
| 341 | + rewriter.replaceOpWithNewOp<IdOp>(op, op.getTarget(0)); |
| 342 | + return success(); |
| 343 | + }) |
| 344 | + // pow(r) { barrier } → barrier |
| 345 | + .Case<BarrierOp>([&](auto gate) { |
| 346 | + rewriter.replaceOpWithNewOp<BarrierOp>(op, gate.getTargets()); |
| 347 | + return success(); |
| 348 | + }) |
| 349 | + .Default([&](auto) { return failure(); }); |
| 350 | + } |
| 351 | +}; |
| 352 | + |
107 | 353 | } // namespace |
108 | 354 |
|
109 | 355 | UnitaryOpInterface PowOp::getBodyUnitary() { |
@@ -149,4 +395,6 @@ void PowOp::getCanonicalizationPatterns(RewritePatternSet& results, |
149 | 395 | MLIRContext* context) { |
150 | 396 | results.add<InlinePow1, ErasePow0, NegPowToInvPow, MergeNestedPow, |
151 | 397 | MoveCtrlOutside>(context); |
| 398 | + // Prefer Known Gate optimizations over everything else |
| 399 | + results.add<FoldPowIntoGate>(context, /*benefit=*/2); |
152 | 400 | } |
0 commit comments