diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f5b0045e82..ce4b2ed980 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,6 +5,13 @@

Improvements 🛠

+* The `decompose-lowering` pass now supports applying a selection of the available decomposition rules via the `target_rules` parameter. + The pass also no longer applies the `inline`, `cse` and `canonicalize` passes to avoid unnecessary IR mutations. + Instead, decomposition rules are deterministically inlined by a custom function (`inline` is non-deterministic, using an estimated benefit and threshold as criteria for inlining). + Decomposition rules are no longer removed after the `decompose-lowering` pass, which allows them to be used by subsequent passes, namely `graph-decomposition`. + Instead, rules are removed by the `symbol-dce` pass at the end of the `QuantumCompilationStage`. + [(#2973)](https://github.com/PennyLaneAI/catalyst/pull/2973) + * The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators without non-lowerable arguments. [(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/) diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index e8da695ae2..16740f815c 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -189,6 +189,14 @@ def GraphDecompositionPass : Pass<"graph-decomposition", "mlir::ModuleOp"> { def DecomposeLoweringPass : Pass<"decompose-lowering"> { let summary = "Replace quantum operations with compiled decomposition rules."; + + let options = [ + ListOption< + /*C++ name*/"targetRulesOption", + /*CLI name*/"target-rules", + /*Type*/"std::string", + /*Description*/"The set of decomposition rules to apply. If empty, applies all rules."> + ]; } def DisentangleCNOTPass : Pass<"disentangle-cnot"> { diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 4fdc0d421c..ce8d37d971 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -13,15 +13,26 @@ // limitations under the License. #include // std::move_backward +#include +#include +#include +#include +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/IR/QuantumTypes.h" @@ -34,8 +45,8 @@ namespace catalyst { namespace quantum { // The goal of this class is to analyze the signature of a custom operation to get the enough -// information to prepare the call operands and results for replacing the op to calling the -// decomposition function. +// information to prepare the operands and results for replacing the op with the decomposition +// function. class BaseSignatureAnalyzer { protected: bool isValid = true; @@ -177,7 +188,7 @@ class BaseSignatureAnalyzer { return latestQreg; } - // Prepare the operands for calling the decomposition function + // Prepare the operands for the decomposition function // There are two cases: // 1. The first input is a qreg, which means the decomposition function is a qreg mode function // 2. Otherwise, the decomposition function is a qubit mode function @@ -187,10 +198,10 @@ class BaseSignatureAnalyzer { // - func(qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg // 2. qubit mode: // - func(param*, inQubits*, inCtrlQubits*?, inCtrlValues*?) -> outQubits* - llvm::SmallVector prepareCallOperands(func::FuncOp decompFunc, PatternRewriter &rewriter, - Location loc) + llvm::SmallVector prepareOperands(func::FuncOp rule, PatternRewriter &rewriter, + Location loc) { - auto funcType = decompFunc.getFunctionType(); + auto funcType = rule.getFunctionType(); auto funcInputs = funcType.getInputs(); SmallVector funcInputsNoQreg; @@ -202,10 +213,10 @@ class BaseSignatureAnalyzer { SmallVector operands(funcInputs.size()); - auto qregIt = llvm::find_if(decompFunc.getFunctionType().getInputs(), + auto qregIt = llvm::find_if(rule.getFunctionType().getInputs(), [](mlir::Type t) { return isa(t); }); - int qregIdx = std::distance(decompFunc.getFunctionType().getInputs().begin(), qregIt); - bool hasQreg = (qregIt != decompFunc.getFunctionType().getInputs().end()); + int qregIdx = std::distance(rule.getFunctionType().getInputs().begin(), qregIt); + bool hasQreg = (qregIt != rule.getFunctionType().getInputs().end()); int operandIdx = 0; if (!signature.params.empty()) { @@ -264,22 +275,18 @@ class BaseSignatureAnalyzer { return operands; } - // Prepare the results for the call operation - SmallVector prepareCallResultForQreg(func::CallOp callOp, PatternRewriter &rewriter) + // Prepare the results produced by a qreg-mode decomposition rule + SmallVector prepareResultsForQreg(Value qreg, Location loc, PatternRewriter &rewriter) { - assert(callOp.getNumResults() == 1 && "only one qreg result for qreg mode is allowed"); - - auto qreg = callOp.getResult(0); assert(isa(qreg.getType()) && "only allow to have qreg result"); SmallVector newResults; - rewriter.setInsertionPointAfter(callOp); for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) { for (const auto &index : indices) { auto extractOp = quantum::ExtractOp::create( - rewriter, callOp.getLoc(), rewriter.getType(), qreg, - index.getValue(), index.getAttr()); + rewriter, loc, rewriter.getType(), qreg, index.getValue(), + index.getAttr()); newResults.emplace_back(extractOp.getResult()); } } @@ -325,7 +332,7 @@ class BaseSignatureAnalyzer { return {startIdx, paramTypeEnd}; } - // generate params for calling the decomposition function based on function type requirements + // generate params for the decomposition function based on function type requirements SmallVector generateParams(ValueRange signatureParams, ArrayRef funcParamTypes, PatternRewriter &rewriter, Location loc) { diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index a4f0fdc30d..5f7c88878c 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -12,13 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define DEBUG_TYPE "decompose-lowering" +#include +#include +#include +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/AllocatorBase.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "Quantum/IR/QuantumInterfaces.h" #include "Quantum/IR/QuantumOps.h" @@ -26,12 +40,47 @@ #include "DecomposeLoweringImpl.hpp" +#define DEBUG_TYPE "decompose-lowering" + using namespace mlir; using namespace catalyst::quantum; namespace catalyst { namespace quantum { +SmallVector getDecompRuleResults(func::FuncOp rule, ValueRange operands, + PatternRewriter &rewriter) +{ + Block &body = rule.front(); + auto returnOp = cast(body.getTerminator()); + + IRMapping mapping; + mapping.map(body.getArguments(), operands); + + for (Operation &op : body.without_terminator()) { + rewriter.clone(op, mapping); + } + + SmallVector results; + for (Value operand : returnOp.getOperands()) { + results.push_back(mapping.lookupOrDefault(operand)); + } + return results; +} + +bool isInDecompRule(Operation *op) +{ + while (auto parentOp = op->getParentOp()) { + if (auto funcOp = dyn_cast(parentOp)) { + if (funcOp->hasAttr("target_gate")) { + return true; + } + } + op = parentOp; + } + return false; +} + struct DLCustomOpPattern : public OpRewritePattern { private: const llvm::StringMap &decompositionRegistry; @@ -54,18 +103,24 @@ struct DLCustomOpPattern : public OpRewritePattern { return failure(); } - // Find the corresponding decomposition function for the op + // do not nest decomposition rules, they're applied greedily and this can lead to + // cycles/identity rules + if (isInDecompRule(op)) { + return failure(); + } + + // Find the corresponding decomposition rule for the op auto it = decompositionRegistry.find(gateName); if (it == decompositionRegistry.end()) { return failure(); } - func::FuncOp decompFunc = it->second; + func::FuncOp rule = it->second; // For null decomp rules, the signature will not have any quantum values // This is a deviation from the standard decomp func signature, so we deal with it // separately - if (!llvm::any_of(llvm::concat(decompFunc.getFunctionType().getInputs(), - decompFunc.getFunctionType().getResults()), + if (!llvm::any_of(llvm::concat(rule.getFunctionType().getInputs(), + rule.getFunctionType().getResults()), [](const mlir::Type t) { return isa(t); })) { @@ -76,32 +131,32 @@ struct DLCustomOpPattern : public OpRewritePattern { return success(); } - // Here is the assumption that the decomposition function must have at least one input and + // Here is the assumption that the decomposition rule must have at least one input and // one result - assert(decompFunc.getFunctionType().getNumInputs() > 0 && + assert(rule.getFunctionType().getNumInputs() > 0 && "Decomposition function must have at least one input"); - assert(decompFunc.getFunctionType().getNumResults() >= 1 && + assert(rule.getFunctionType().getNumResults() >= 1 && "Decomposition function must have at least one result"); rewriter.setInsertionPointAfter(op); - auto enableQreg = llvm::any_of(decompFunc.getFunctionType().getInputs(), + auto enableQreg = llvm::any_of(rule.getFunctionType().getInputs(), [](mlir::Type t) { return isa(t); }); auto analyzer = CustomOpSignatureAnalyzer(op, enableQreg); assert(analyzer && "Analyzer should be valid"); - auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); - auto callOp = - func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(), - decompFunc.getSymName(), callOperands); + auto operands = analyzer.prepareOperands(rule, rewriter, op.getLoc()); + SmallVector inlinedFunctionResults = getDecompRuleResults(rule, operands, rewriter); - // Replace the op with the call op and adjust the insert ops for the qreg mode - if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { - auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + // Replace the op with the inlined function and adjust the insert ops for the qreg mode + if (inlinedFunctionResults.size() == 1 && + isa(inlinedFunctionResults.front().getType())) { + auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(), + op.getLoc(), rewriter); rewriter.replaceOp(op, results); } else { - rewriter.replaceOp(op, callOp->getResults()); + rewriter.replaceOp(op, inlinedFunctionResults); } return success(); @@ -130,6 +185,12 @@ struct DLMultiRZOpPattern : public OpRewritePattern { return failure(); } + // do not nest decomposition rules, they're applied greedily and this can lead to + // cycles/identity rules + if (isInDecompRule(op)) { + return failure(); + } + // Find the corresponding decomposition function for the op auto numQubits = op.getInQubits().size(); auto MRZNameWithQubits = gateName + "_" + std::to_string(numQubits); @@ -165,18 +226,19 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto analyzer = MultiRZOpSignatureAnalyzer(op, enableQreg); assert(analyzer && "Analyzer should be valid"); - auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); - auto callOp = - func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(), - decompFunc.getSymName(), callOperands); + auto operands = analyzer.prepareOperands(decompFunc, rewriter, op.getLoc()); + SmallVector inlinedFunctionResults = + getDecompRuleResults(decompFunc, operands, rewriter); - // Replace the op with the call op and adjust the insert ops for the qreg mode - if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { - auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + // Replace the op with the inlined function and adjust the insert ops for the qreg mode + if (inlinedFunctionResults.size() == 1 && + isa(inlinedFunctionResults.front().getType())) { + auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(), + op.getLoc(), rewriter); rewriter.replaceOp(op, results); } else { - rewriter.replaceOp(op, callOp->getResults()); + rewriter.replaceOp(op, inlinedFunctionResults); } return success(); @@ -205,6 +267,12 @@ struct DLPauliRotOpPattern : public OpRewritePattern { return failure(); } + // do not nest decomposition rules, they're applied greedily and this can lead to + // cycles/identity rules + if (isInDecompRule(op)) { + return failure(); + } + // Find the corresponding decomposition function for the op auto it = decompositionRegistry.find(gateName); if (it == decompositionRegistry.end()) { @@ -226,18 +294,19 @@ struct DLPauliRotOpPattern : public OpRewritePattern { auto analyzer = PauliRotOpSignatureAnalyzer(op, enableQreg); assert(analyzer && "Analyzer should be valid"); - auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); - auto callOp = - func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(), - decompFunc.getSymName(), callOperands); + auto operands = analyzer.prepareOperands(decompFunc, rewriter, op.getLoc()); + SmallVector inlinedFunctionResults = + getDecompRuleResults(decompFunc, operands, rewriter); - // Replace the op with the call op and adjust the insert ops for the qreg mode - if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { - auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + // Replace the op with the inlined results and adjust the insert ops for the qreg mode + if (inlinedFunctionResults.size() == 1 && + isa(inlinedFunctionResults.front().getType())) { + auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(), + op.getLoc(), rewriter); rewriter.replaceOp(op, results); } else { - rewriter.replaceOp(op, callOp->getResults()); + rewriter.replaceOp(op, inlinedFunctionResults); } return success(); diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp index 522ed966b8..b1be26406c 100644 --- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -12,34 +12,47 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define DEBUG_TYPE "decompose-lowering" +#include +#include +#include -// When we read the decomposition rules module from file, -// StablehloDialect may not be registered from start. #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/AllocatorBase.h" +#include "llvm/Support/DebugLog.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/WalkResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Inliner.h" #include "mlir/Transforms/Passes.h" -#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/StablehloOps.h" // When we read the decomposition rules module from file, StablehloDialect may not be registered from start. +#include "Quantum/IR/QuantumDialect.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Patterns.h" +#define DEBUG_TYPE "decompose-lowering" + using namespace mlir; using namespace catalyst::quantum; namespace catalyst { namespace quantum { + #define GEN_PASS_DEF_DECOMPOSELOWERINGPASS #define GEN_PASS_DECL_DECOMPOSELOWERINGPASS #include "Quantum/Transforms/Passes.h.inc" @@ -96,9 +109,15 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase &decompositionRegistry) + llvm::StringMap &decompositionRegistry, + llvm::StringSet<> targetRules) { module.walk([&](func::FuncOp func) { + // if targetRules is provided, only add requested rules + if (!targetRules.empty() && !targetRules.contains(func.getName())) { + return WalkResult::skip(); + } + if (StringRef targetOp = DecompUtils::getTargetGateName(func); !targetOp.empty()) { removeUnusedFuncArgs(func); if (targetOp == "MultiRZ") { @@ -150,82 +169,33 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBasegetOperandTypes())); } - // Remove unused decomposition functions: - // Since the decomposition functions are marked as public from the frontend, - // there is no way to remove them with any DCE pass automatically. - // So we need to manually remove them from the module - void removeDecompositionFunctions(ModuleOp module, - llvm::StringMap &decompositionRegistry) - { - llvm::DenseSet usedDecompositionFunctions; - - module.walk([&](func::CallOp callOp) { - if (auto targetFunc = module.lookupSymbol(callOp.getCallee())) { - if (DecompUtils::isDecompositionFunction(targetFunc)) { - usedDecompositionFunctions.insert(targetFunc); - } - } - }); - - // remove unused decomposition functions - module.walk([&](func::FuncOp func) { - if (DecompUtils::isDecompositionFunction(func) && - !usedDecompositionFunctions.contains(func)) { - func.erase(); - } - return WalkResult::skip(); - }); - } - public: void runOnOperation() final { ModuleOp module = cast(getOperation()); // Step 1: Discover and register all decomposition functions in the module - discoverAndRegisterDecompositions(module, decompositionRegistry); + llvm::StringSet<> targetRules; + for (auto rule : targetRulesOption) { + targetRules.insert(rule); + } + discoverAndRegisterDecompositions(module, decompositionRegistry, targetRules); if (decompositionRegistry.empty()) { return; } - // Step 1.1: Find the target gate set + // Step 2: Find the target gate set findTargetGateSet(module, targetGateSet); - // Step 2: Canonicalize the module - RewritePatternSet patternsCanonicalization(&getContext()); - catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization, - &getContext()); - if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) { - return signalPassFailure(); - } - - // Step 3: Apply the decomposition patterns + // Step 3: Apply the decomposition patterns, canonicalizing the insert/extract pairs RewritePatternSet decompositionPatterns(&getContext()); populateDecomposeLoweringPatterns(decompositionPatterns, decompositionRegistry, targetGateSet); - if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) { - return signalPassFailure(); - } - - // Step 4: Inline and canonicalize/CSE the module again - PassManager pm(&getContext()); - pm.addPass(createInlinerPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - if (failed(pm.run(module))) { - return signalPassFailure(); - } - - // Step 5. Remove redundant decomposition functions - removeDecompositionFunctions(module, decompositionRegistry); - - // Step 6. Canonicalize the extract/insert pair - RewritePatternSet patternsInsertExtract(&getContext()); - catalyst::quantum::InsertOp::getCanonicalizationPatterns(patternsInsertExtract, + catalyst::quantum::InsertOp::getCanonicalizationPatterns(decompositionPatterns, &getContext()); - catalyst::quantum::ExtractOp::getCanonicalizationPatterns(patternsInsertExtract, + catalyst::quantum::ExtractOp::getCanonicalizationPatterns(decompositionPatterns, &getContext()); - if (failed(applyPatternsGreedily(module, std::move(patternsInsertExtract)))) { + if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) { return signalPassFailure(); } } diff --git a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp index a8bf13b12f..5a9425e9cb 100644 --- a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp +++ b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp @@ -12,27 +12,47 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define DEBUG_TYPE "graph-decomposition" - +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/WalkResult.h" #include "stablehlo/dialect/StablehloOps.h" #include "Catalyst/Analysis/ResourceAnalysis.h" +#include "Catalyst/Analysis/ResourceResult.h" #include "Catalyst/Transforms/Passes.h" #include "QRef/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" +#include "Quantum/IR/QuantumInterfaces.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Passes.h" #include "Quantum/Transforms/QPDLoader.h" @@ -41,6 +61,8 @@ #include "DGSolver.hpp" #include "DGTypes.hpp" +#define DEBUG_TYPE "graph-decomposition" + using namespace mlir; using namespace catalyst::quantum; using namespace DecompGraph::Core; @@ -135,9 +157,12 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasepush_back(rule.release()); + if (!symbolTable.lookup(rule->getName())) { + module.getBody()->push_back(rule.release()); + } } } diff --git a/mlir/test/Quantum/DecomposeLoweringTargetRulesTest.mlir b/mlir/test/Quantum/DecomposeLoweringTargetRulesTest.mlir new file mode 100644 index 0000000000..6c261d07dc --- /dev/null +++ b/mlir/test/Quantum/DecomposeLoweringTargetRulesTest.mlir @@ -0,0 +1,51 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt --pass-pipeline='builtin.module(decompose-lowering{target-rules=my_X_decomp,my_Z_decomp})' --split-input-file -verify-diagnostics %s | FileCheck %s + +// Test that decompose-lowering only applies the rules requested by the `target-rules` option when present + +module @test_module { + // CHECK: func.func private @my_X_decomp + func.func private @my_X_decomp(%q: !quantum.bit) -> !quantum.bit attributes {target_gate="X"} { + %angle = arith.constant 1.57 : f64 + %out = quantum.custom "RX"(%angle) %q : !quantum.bit + return %out : !quantum.bit + } + + // CHECK: func.func private @my_Y_decomp + func.func private @my_Y_decomp(%q: !quantum.bit) -> !quantum.bit attributes {target_gate="Y"} { + %angle = arith.constant 1.57 : f64 + %out = quantum.custom "RY"(%angle) %q : !quantum.bit + return %out : !quantum.bit + } + + // CHECK: func.func private @my_Z_decomp + func.func private @my_Z_decomp(%q: !quantum.bit) -> !quantum.bit attributes {target_gate="Z"} { + %angle = arith.constant 1.57 : f64 + %out = quantum.custom "RZ"(%angle) %q : !quantum.bit + return %out : !quantum.bit + } + + // CHECK: [[q:%.+]] = quantum.alloc_qb + // CHECK: [[x_out:%.+]] = quantum.custom "RX"(%{{.+}}) [[q]] + // CHECK: [[y_out:%.+]] = quantum.custom "Y"() [[x_out]] + // CHECK: [[z_out:%.+]] = quantum.custom "RZ"(%{{.+}}) [[y_out]] + // CHECK: quantum.dealloc_qb [[z_out]] + %0 = quantum.alloc_qb : !quantum.bit + %1 = quantum.custom "X"() %0 : !quantum.bit + %2 = quantum.custom "Y"() %1 : !quantum.bit + %3 = quantum.custom "Z"() %2 : !quantum.bit + quantum.dealloc_qb %3 : !quantum.bit +} diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index b7abe6f203..7a6d9d3566 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -41,8 +41,8 @@ module @two_hadamards { return %4 : tensor<4xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + // Decomposition function should be retained for future passes + // CHECK: func.func private @Hadamard_to_RY_decomp func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { %cst = arith.constant 3.1415926535897931 : f64 %cst_0 = arith.constant 1.5707963267948966 : f64 @@ -55,6 +55,8 @@ module @two_hadamards { // ----- // Test single Hadamard decomposition + +// CHECK-LABEL: module @single_hadamard module @single_hadamard { func.func @test_single_hadamard() -> !quantum.bit { // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 @@ -73,8 +75,8 @@ module @single_hadamard { return %2 : !quantum.bit } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + // Decomposition function should be retained for future passes + // CHECK: func.func private @Hadamard_to_RY_decomp func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { %cst = arith.constant 3.1415926535897931 : f64 %cst_0 = arith.constant 1.5707963267948966 : f64 @@ -85,6 +87,8 @@ module @single_hadamard { } // ----- + +// CHECK-LABEL: module @recursive module @recursive { func.func public @test_recursive() -> tensor<4xf64> attributes {quantum.node} { %0 = quantum.alloc( 2) : !quantum.reg @@ -112,15 +116,15 @@ module @recursive { return %4 : tensor<4xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + // Decomposition function should be retained for future passes + // CHECK: func.func private @Hadamard_to_RY_decomp func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit return %out_qubits_0 : !quantum.bit } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @RZRY_decomp + // Decomposition function should be retained for future passes + // CHECK: func.func private @RZRY_decomp func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} { %cst = arith.constant 3.1415926535897931 : f64 %cst_0 = arith.constant 1.5707963267948966 : f64 @@ -131,6 +135,8 @@ module @recursive { } // ----- + +// CHECK-LABEL: module @recursive module @recursive { func.func public @test_recursive() -> tensor<4xf64> attributes {quantum.node} { %0 = quantum.alloc( 2) : !quantum.reg @@ -158,15 +164,15 @@ module @recursive { return %4 : tensor<4xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + // Decomposition function should be retained for future passes + // CHECK: func.func private @Hadamard_to_RY_decomp func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit return %out_qubits_0 : !quantum.bit } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @RZRY_decomp + // Decomposition function should be retained for future passes + // CHECK: func.func private @RZRY_decomp func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} { %cst = arith.constant 3.1415926535897931 : f64 %cst_0 = arith.constant 1.5707963267948966 : f64 @@ -179,6 +185,8 @@ module @recursive { // ----- // Test parametric gates and wires + +// CHECK-LABEL: module @param_rxry module @param_rxry { func.func public @test_param_rxry(%arg0: tensor, %arg1: tensor) -> tensor<2xf64> attributes {quantum.node} { %c0_i64 = arith.constant 0 : i64 @@ -209,7 +217,7 @@ module @param_rxry { } // Decomposition function expects tensor while operation provides f64 - // CHECK-NOT: func.func private @ParametrizedRX_decomp + // CHECK: func.func private @ParametrizedRXRY_decomp func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: !quantum.bit) -> !quantum.bit attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} { %extracted = tensor.extract %arg0[] : tensor @@ -219,92 +227,60 @@ module @param_rxry { return %out_qubits_1 : !quantum.bit } } -// ----- - -// Test parametric gates and wires -module @param_rxry_2 { - func.func public @test_param_rxry_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xf64> attributes {quantum.node} { - %c0_i64 = arith.constant 0 : i64 - - // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg - %0 = quantum.alloc( 1) : !quantum.reg - - // CHECK: [[WIRE:%.+]] = tensor.extract %arg2[] : tensor - %extracted = tensor.extract %arg2[] : tensor - - // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit - %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit - - // CHECK: [[PARAM_0:%.+]] = tensor.extract %arg0[] : tensor - %param_0 = tensor.extract %arg0[] : tensor - - // CHECK: [[PARAM_1:%.+]] = tensor.extract %arg1[] : tensor - %param_1 = tensor.extract %arg1[] : tensor - - // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM_0]]) [[QUBIT]] : !quantum.bit - // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM_1]]) [[QUBIT1]] : !quantum.bit - // CHECK-NOT: quantum.custom "ParametrizedRXRY" - %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0, %param_1) %1 : !quantum.bit - - // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit - %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit - %3 = quantum.compbasis qreg %2 : !quantum.obs - %4 = quantum.probs %3 : tensor<2xf64> - quantum.dealloc %2 : !quantum.reg - return %4 : tensor<2xf64> - } - // Decomposition function expects tensor while operation provides f64 - // CHECK-NOT: func.func private @ParametrizedRX_decomp - func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: tensor, %arg2: !quantum.bit) -> !quantum.bit - attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} { - %extracted_param_0 = tensor.extract %arg0[] : tensor - %out_qubits = quantum.custom "RX"(%extracted_param_0) %arg2 : !quantum.bit - %extracted_param_1 = tensor.extract %arg1[] : tensor - %out_qubits_1 = quantum.custom "RY"(%extracted_param_1) %out_qubits : !quantum.bit - return %out_qubits_1 : !quantum.bit - } -} // ----- // Test recursive and qreg-based gate decomposition + +// CHECK-LABEL: module @qreg_base_circuit module @qreg_base_circuit { func.func public @test_qreg_base_circuit() -> tensor<2xf64> attributes {quantum.node} { - // CHECK: [[CST:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK-DAG: [[cmp_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: [[test_angle:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK-DAG: [[index_tensor:%.+]] = arith.constant dense<0> : tensor<1xi64> + // CHECK-DAG: [[cmp_1:%.+]] = arith.constant dense<1.000000e+00> : tensor + // CHECK: [[reg0:%.+]] = quantum.alloc( 1) : !quantum.reg %cst = arith.constant 1.000000e+00 : f64 - - // CHECK: [[CST_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor - // CHECK: [[CST_1:%.+]] = arith.constant dense<0> : tensor<1xi64> - // CHECK: [[CST_2:%.+]] = arith.constant dense<1.000000e+00> : tensor - // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg %0 = quantum.alloc( 1) : !quantum.reg - // CHECK: [[EXTRACT_QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[MRES:%.+]], [[OUT_QUBIT:%.+]] = quantum.measure [[EXTRACT_QUBIT]] : i1, !quantum.bit - // CHECK: [[REG1:%.+]] = quantum.insert [[REG]][ 0], [[OUT_QUBIT]] : !quantum.reg, !quantum.bit - // CHECK: [[COMPARE:%.+]] = stablehlo.compare NE, [[CST_2]], [[CST_0]], FLOAT : (tensor, tensor) -> tensor - // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[COMPARE]][] : tensor - // CHECK: [[CONDITIONAL:%.+]] = scf.if [[EXTRACTED]] -> (!quantum.reg) { - // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST_1]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor - // CHECK: [[EXTRACTED_3:%.+]] = tensor.extract [[RESHAPE1]][] : tensor - // CHECK: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[EXTRACTED_3]] : tensor<1xi64> - // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[FROM_ELEMENTS]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor - // CHECK: [[EXTRACTED_4:%.+]] = tensor.extract [[RESHAPE2]][] : tensor - // CHECK: [[EXTRACT1:%.+]] = quantum.extract [[REG1]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit - // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT1]] : !quantum.bit - // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_4]]], [[RZ1]] : !quantum.reg, !quantum.bit - // CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit - // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_3]]], [[EXTRACT2]] : !quantum.reg, !quantum.bit - // CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit - // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT3]] : !quantum.bit - // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_4]]], [[RZ2]] : !quantum.reg, !quantum.bit - // CHECK: [[EXTRACT4:%.+]] = quantum.extract [[INSERT3]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit - // CHECK: [[INSERT4:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_3]]], [[EXTRACT4]] : !quantum.reg, !quantum.bit - // CHECK: scf.yield [[INSERT4]] : !quantum.reg + + // CHECK: [[q0:%.+]] = quantum.extract [[reg0]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[meas:%.+]], [[q1:%.+]] = quantum.measure [[q0]] : i1, !quantum.bit + // CHECK: [[reg1:%.+]] = quantum.insert [[reg0]][ 0], [[q1]] : !quantum.reg, !quantum.bit + // CHECK: [[cmp:%.+]] = stablehlo.compare NE, [[cmp_1]], [[cmp_0]], FLOAT : (tensor, tensor) -> tensor + // CHECK: [[cond:%.+]] = tensor.extract [[cmp]][] : tensor + // CHECK: [[condresult:%.+]] = scf.if [[cond]] -> (!quantum.reg) { + // CHECK: [[slice0:%.+]] = stablehlo.slice [[index_tensor]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[reshape0:%.+]] = stablehlo.reshape [[slice0]] : (tensor<1xi64>) -> tensor + // CHECK: [[index0:%.+]] = tensor.extract [[reshape0]][] : tensor + // CHECK: [[fromelements0:%.+]] = tensor.from_elements [[index0]] : tensor<1xi64> + // CHECK: [[slice1:%.+]] = stablehlo.slice [[fromelements0]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[reshape1:%.+]] = stablehlo.reshape [[slice1]] : (tensor<1xi64>) -> tensor + // CHECK: [[index1:%.+]] = tensor.extract [[reshape1]][] : tensor + // CHECK: [[q2:%.+]] = quantum.extract [[reg1]][[[index1]]] : !quantum.reg -> !quantum.bit + // CHECK: [[q3:%.+]] = quantum.custom "RZ"([[test_angle]]) [[q2]] : !quantum.bit + // CHECK: [[index2:%.+]] = tensor.extract [[reshape1]][] + // CHECK: [[reg2:%.+]] = quantum.insert [[reg1]][[[index2]]], [[q3]] : !quantum.reg, !quantum.bit + // CHECK: [[q4:%.+]] = quantum.extract [[reg2]][[[index0]]] : !quantum.reg -> !quantum.bit + // CHECK: [[slice3:%.+]] = stablehlo.slice [[index_tensor]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[reshape3:%.+]] = stablehlo.reshape [[slice3]] : (tensor<1xi64>) -> tensor + // CHECK: [[extract6:%.+]] = tensor.extract [[reshape0]][] + // CHECK: [[reg3:%.+]] = quantum.insert [[reg1]][[[extract6]]], [[q4]] : !quantum.reg, !quantum.bit + // CHECK: [[index3:%.+]] = tensor.extract [[reshape3]][] + // CHECK: [[fromelements1:%.+]] = tensor.from_elements [[index3]] + // CHECK: [[slice4:%.+]] = stablehlo.slice [[fromelements1]] [0:1] + // CHECK: [[reshape4:%.+]] = stablehlo.reshape [[slice4]] + // CHECK: [[index4:%.+]] = tensor.extract [[reshape4]][] + // CHECK: [[q5:%.+]] = quantum.extract [[reg3]][[[index4]]] : !quantum.reg -> !quantum.bit + // CHECK: [[q6:%.+]] = quantum.custom "RZ"([[test_angle]]) [[q5]] : !quantum.bit + // CHECK: [[index5:%.+]] = tensor.extract [[reshape4]][] + // CHECK: [[reg4:%.+]] = quantum.insert [[reg3]][[[index5]]], [[q6]] : !quantum.reg, !quantum.bit + // CHECK: [[q7:%.+]] = quantum.extract [[reg4]][[[index3]]] : !quantum.reg -> !quantum.bit + // CHECK: [[index6:%.+]] = tensor.extract [[reshape3]][] + // CHECK: [[out:%.+]] = quantum.insert [[reg3]][[[index6]]], [[q7]] : !quantum.reg, !quantum.bit + // CHECK: scf.yield [[out]] : !quantum.reg // CHECK: } else { - // CHECK: scf.yield [[REG1]] : !quantum.reg + // CHECK: scf.yield [[reg1]] : !quantum.reg // CHECK: } // CHECK-NOT: quantum.custom "Test" %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit @@ -318,8 +294,8 @@ module @qreg_base_circuit { return %4 : tensor<2xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @Test_rule_1 + // Decomposition function should be retained for future passes + // CHECK: func.func private @Test_rule_1 func.func private @Test_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {target_gate = "Test", llvm.linkage = #llvm.linkage} { %cst = stablehlo.constant dense<0.000000e+00> : tensor @@ -352,8 +328,8 @@ module @qreg_base_circuit { return %1 : !quantum.reg } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @RzDecomp_rule_1 + // Decomposition function should be retained for future passes + // CHECK: func.func private @RzDecomp_rule_1 func.func private @RzDecomp_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {target_gate = "RzDecomp", llvm.linkage = #llvm.linkage} { %0 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> @@ -370,12 +346,12 @@ module @qreg_base_circuit { // ----- +// CHECK-LABEL: module @multi_wire_cnot_decomposition module @multi_wire_cnot_decomposition { func.func public @test_cnot_decomposition() -> tensor<4xf64> attributes {quantum.node} { %0 = quantum.alloc( 2) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit - // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 // CHECK: [[WIRE_TENSOR:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64> @@ -388,13 +364,22 @@ module @multi_wire_cnot_decomposition { // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit + // CHECK: [[index:%.+]] = tensor.extract [[RESHAPE2]][] + // CHECK: [[INSERT_TARGET:%.+]] = quantum.insert [[REG]][[[index]]], [[RY1]] : !quantum.reg, !quantum.bit // CHECK: [[EXTRACTED2:%.+]] = tensor.extract [[RESHAPE1]][] : tensor - // CHECK: [[QUBIT0:%.+]] = quantum.extract [[REG]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit - // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[RY1]] : !quantum.bit, !quantum.bit - // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit - // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[QUBIT0:%.+]] = quantum.extract [[INSERT_TARGET]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit + // CHECK: [[index2:%.+]] = tensor.extract [[RESHAPE2]][] + // CHECK: [[QUBIT1_UPDATED:%.+]] = quantum.extract [[INSERT_TARGET]][[[index2]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[QUBIT1_UPDATED]] : !quantum.bit, !quantum.bit + // CHECK: [[index3:%.+]] = tensor.extract [[RESHAPE1]][] + // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT_TARGET]][[[index3]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[index4:%.+]] = tensor.extract [[RESHAPE2]][] + // CHECK: [[INSERT_CZ1:%.+]] = quantum.insert [[INSERT2]][[[index4]]], [[CZ_RESULT]]#1 : !quantum.reg, !quantum.bit + // CHECK: [[TARGET_AFTER_CZ:%.+]] = quantum.extract [[INSERT_CZ1]][{{%.+}}] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[TARGET_AFTER_CZ]] : !quantum.bit // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit - // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[RY2]] : !quantum.reg, !quantum.bit + // CHECK: [[index5:%.+]] = tensor.extract [[RESHAPE2]][] + // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT_CZ1]][[[index5]]], [[RY2]] : !quantum.reg, !quantum.bit // CHECK: [[FINAL_QUBIT0:%.+]] = quantum.extract [[INSERT3]][ 0] : !quantum.reg -> !quantum.bit // CHECK: [[FINAL_QUBIT1:%.+]] = quantum.extract [[INSERT3]][ 1] : !quantum.reg -> !quantum.bit // CHECK-NOT: quantum.custom "CNOT" @@ -410,8 +395,8 @@ module @multi_wire_cnot_decomposition { return %8 : tensor<4xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @CNOT_rule_cz_rz_ry + // Decomposition function should be retained for future passes + // CHECK: func.func private @CNOT_rule_cz_rz_ry func.func private @CNOT_rule_cz_rz_ry(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { // CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H) %cst = arith.constant 1.5707963267948966 : f64 @@ -456,6 +441,7 @@ module @multi_wire_cnot_decomposition { // ----- +// CHECK-LABEL: module @cnot_alternative_decomposition module @cnot_alternative_decomposition { func.func public @test_cnot_alternative_decomposition() -> tensor<4xf64> attributes {quantum.node} { %0 = quantum.alloc( 2) : !quantum.reg @@ -485,8 +471,8 @@ module @cnot_alternative_decomposition { return %8 : tensor<4xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func private @CNOT_rule_h_cnot_h + // Decomposition function should be retained for future passes + // CHECK: func.func private @CNOT_rule_h_cnot_h func.func private @CNOT_rule_h_cnot_h(%arg0: !quantum.bit, %arg1: !quantum.bit) -> (!quantum.bit, !quantum.bit) attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { // CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H) %cst = arith.constant 1.5707963267948966 : f64 @@ -509,6 +495,7 @@ module @cnot_alternative_decomposition { // ----- +// CHECK-LABEL: module @mcm_example module @mcm_example { func.func public @test_mcm_hadamard() -> tensor<2xf64> attributes {quantum.node} { %0 = quantum.alloc( 1) : !quantum.reg @@ -516,9 +503,9 @@ module @mcm_example { %mres, %out_qubit = quantum.measure %1 : i1, !quantum.bit %2 = quantum.insert %0[ 0], %out_qubit : !quantum.reg, !quantum.bit - // CHECK: [[RZ_QUBIT:%.+]] = quantum.custom "RZ"([[CST_0:%.+]]) - // CHECK: [[RY_QUBIT:%.+]] = quantum.custom "RY"([[CST_1:%.+]]) [[RZ_QUBIT]] : !quantum.bit - // CHECK: [[REG_1:%.+]] = quantum.insert [[REG:%.+]][[[EXTRACTED:%.+]]], [[RY_QUBIT]] : !quantum.reg, !quantum.bit + // CHECK: quantum.custom "RZ" + // CHECK: quantum.custom "RY" + // CHECK-NOT: quantum.custom "Hadamard" %3 = quantum.extract %2[ 0] : !quantum.reg -> !quantum.bit %out_qubits = quantum.custom "Hadamard"() %3 : !quantum.bit @@ -530,8 +517,8 @@ module @mcm_example { return %6 : tensor<2xf64> } - // Decomposition function should be applied and removed from the module - // CHECK-NOT: func.func public @rz_ry + // Decomposition function should be retained for future passes + // CHECK: func.func public @rz_ry func.func public @rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} { %cst = arith.constant 3.1415926535897931 : f64 %cst_0 = arith.constant 1.5707963267948966 : f64 @@ -555,15 +542,16 @@ module @mcm_example { // ----- +// CHECK-LABEL: module @circuit_with_multirz module @circuit_with_multirz { func.func public @test_with_multirz() -> tensor<4xf64> attributes {quantum.node} { %0 = quantum.alloc( 2) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit // CHECK: func.func public @test_with_multirz() -> tensor<4xf64> - // CHECK: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64 - // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 - // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 - // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK-DAG: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK-DAG: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK-DAG: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64 + // CHECK-DAG: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_RZ]]) {{%.+}} : !quantum.bit // CHECK-NOT: quantum.multirz @@ -584,7 +572,7 @@ module @circuit_with_multirz { return %4 : tensor<4xf64> } - // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + // CHECK: func.func private @Hadamard_to_RY_decomp func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { %cst = arith.constant 3.1415926535897931 : f64 %cst_0 = arith.constant 1.5707963267948966 : f64 @@ -593,8 +581,8 @@ module @circuit_with_multirz { return %out_qubits_1 : !quantum.bit } - // CHECK-NOT: func.func private @_multi_rz_decomposition_wires_1 - func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} { + // CHECK: func.func private @_multi_rz_decomposition_wires_1 + func.func private @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} { %0 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor %extracted = tensor.extract %1[] : tensor @@ -610,6 +598,7 @@ module @circuit_with_multirz { // ----- +// CHECK-LABEL: module @qreg_at_not_first_arg module @qreg_at_not_first_arg { func.func public @test_qreg_at_not_first_arg() attributes {quantum.node} { // CHECK: [[wire_tensor:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64> @@ -635,7 +624,7 @@ module @qreg_at_not_first_arg { return } - // CHECK-NOT: func.func private @my_cnot + // CHECK: func.func private @my_cnot func.func private @my_cnot(%arg0: tensor<2xi64>, %arg1: !quantum.reg) -> !quantum.reg attributes {target_gate = "CNOT"} { %0 = stablehlo.slice %arg0 [0:1] : (tensor<2xi64>) -> tensor<1xi64> %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor @@ -654,6 +643,7 @@ module @qreg_at_not_first_arg { // ----- +// CHECK-LABEL: module @test_paulirot module @test_paulirot { func.func @test() attributes {quantum.node} { %pi = arith.constant 3.1 : f64 @@ -678,7 +668,7 @@ module @test_paulirot { return } - // CHECK-NOT: my_paulirot_decomp + // CHECK: my_paulirot_decomp func.func private @my_paulirot_decomp(%inreg : !quantum.reg, %angle_tensor : tensor, %q_tensor : tensor<3xi64>) -> !quantum.reg attributes {target_gate = "paulirotZXY"} { %pi_by_2 = arith.constant 1.57 : f64 %m_pi_by_2 = arith.constant -1.57 : f64 @@ -717,6 +707,7 @@ module @test_paulirot { // ----- +// CHECK-LABEL: module @null_decomp_rule module @null_decomp_rule{ func.func public @test_null_decomp_rule() attributes {quantum.node} { // CHECK: [[reg:%.+]] = quantum.alloc( 1) @@ -733,7 +724,7 @@ module @null_decomp_rule{ return } - // CHECK-NOT: func.func private @null_decomp + // CHECK: func.func private @null_decomp func.func private @null_decomp() attributes {target_gate = "PauliX"} { return } @@ -741,6 +732,7 @@ module @null_decomp_rule{ // ----- +// CHECK-LABEL: module @different_qreg_values module @different_qreg_values{ func.func public @circuit() attributes {quantum.node} { // CHECK: [[wire_tensor:%.+]] = arith.constant dense<[2, 1]> : tensor<2xi64> @@ -776,7 +768,7 @@ module @different_qreg_values{ return } - // CHECK-NOT: func.func private @my_cnot + // CHECK: func.func private @my_cnot func.func private @my_cnot(%arg0: tensor<2xi64>, %arg1: !quantum.reg) -> !quantum.reg attributes {target_gate = "CNOT"} { %0 = stablehlo.slice %arg0 [0:1] : (tensor<2xi64>) -> tensor<1xi64> %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor