diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index f5b0045e82..ce4b2ed980 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -5,6 +5,13 @@
Improvements ðŸ›
+* The `decompose-lowering` pass now supports applying a selection of the available decomposition rules via the `target_rules` parameter.
+ The pass also no longer applies the `inline`, `cse` and `canonicalize` passes to avoid unnecessary IR mutations.
+ Instead, decomposition rules are deterministically inlined by a custom function (`inline` is non-deterministic, using an estimated benefit and threshold as criteria for inlining).
+ Decomposition rules are no longer removed after the `decompose-lowering` pass, which allows them to be used by subsequent passes, namely `graph-decomposition`.
+ Instead, rules are removed by the `symbol-dce` pass at the end of the `QuantumCompilationStage`.
+ [(#2973)](https://github.com/PennyLaneAI/catalyst/pull/2973)
+
* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators
without non-lowerable arguments.
[(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/)
diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td
index e8da695ae2..16740f815c 100644
--- a/mlir/include/Quantum/Transforms/Passes.td
+++ b/mlir/include/Quantum/Transforms/Passes.td
@@ -189,6 +189,14 @@ def GraphDecompositionPass : Pass<"graph-decomposition", "mlir::ModuleOp"> {
def DecomposeLoweringPass : Pass<"decompose-lowering"> {
let summary = "Replace quantum operations with compiled decomposition rules.";
+
+ let options = [
+ ListOption<
+ /*C++ name*/"targetRulesOption",
+ /*CLI name*/"target-rules",
+ /*Type*/"std::string",
+ /*Description*/"The set of decomposition rules to apply. If empty, applies all rules.">
+ ];
}
def DisentangleCNOTPass : Pass<"disentangle-cnot"> {
diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp
index 4fdc0d421c..ce8d37d971 100644
--- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp
+++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp
@@ -13,15 +13,26 @@
// limitations under the License.
#include // std::move_backward
+#include
+#include
+#include
+#include
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/IR/QuantumTypes.h"
@@ -34,8 +45,8 @@ namespace catalyst {
namespace quantum {
// The goal of this class is to analyze the signature of a custom operation to get the enough
-// information to prepare the call operands and results for replacing the op to calling the
-// decomposition function.
+// information to prepare the operands and results for replacing the op with the decomposition
+// function.
class BaseSignatureAnalyzer {
protected:
bool isValid = true;
@@ -177,7 +188,7 @@ class BaseSignatureAnalyzer {
return latestQreg;
}
- // Prepare the operands for calling the decomposition function
+ // Prepare the operands for the decomposition function
// There are two cases:
// 1. The first input is a qreg, which means the decomposition function is a qreg mode function
// 2. Otherwise, the decomposition function is a qubit mode function
@@ -187,10 +198,10 @@ class BaseSignatureAnalyzer {
// - func(qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg
// 2. qubit mode:
// - func(param*, inQubits*, inCtrlQubits*?, inCtrlValues*?) -> outQubits*
- llvm::SmallVector prepareCallOperands(func::FuncOp decompFunc, PatternRewriter &rewriter,
- Location loc)
+ llvm::SmallVector prepareOperands(func::FuncOp rule, PatternRewriter &rewriter,
+ Location loc)
{
- auto funcType = decompFunc.getFunctionType();
+ auto funcType = rule.getFunctionType();
auto funcInputs = funcType.getInputs();
SmallVector funcInputsNoQreg;
@@ -202,10 +213,10 @@ class BaseSignatureAnalyzer {
SmallVector operands(funcInputs.size());
- auto qregIt = llvm::find_if(decompFunc.getFunctionType().getInputs(),
+ auto qregIt = llvm::find_if(rule.getFunctionType().getInputs(),
[](mlir::Type t) { return isa(t); });
- int qregIdx = std::distance(decompFunc.getFunctionType().getInputs().begin(), qregIt);
- bool hasQreg = (qregIt != decompFunc.getFunctionType().getInputs().end());
+ int qregIdx = std::distance(rule.getFunctionType().getInputs().begin(), qregIt);
+ bool hasQreg = (qregIt != rule.getFunctionType().getInputs().end());
int operandIdx = 0;
if (!signature.params.empty()) {
@@ -264,22 +275,18 @@ class BaseSignatureAnalyzer {
return operands;
}
- // Prepare the results for the call operation
- SmallVector prepareCallResultForQreg(func::CallOp callOp, PatternRewriter &rewriter)
+ // Prepare the results produced by a qreg-mode decomposition rule
+ SmallVector prepareResultsForQreg(Value qreg, Location loc, PatternRewriter &rewriter)
{
- assert(callOp.getNumResults() == 1 && "only one qreg result for qreg mode is allowed");
-
- auto qreg = callOp.getResult(0);
assert(isa(qreg.getType()) && "only allow to have qreg result");
SmallVector newResults;
- rewriter.setInsertionPointAfter(callOp);
for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) {
for (const auto &index : indices) {
auto extractOp = quantum::ExtractOp::create(
- rewriter, callOp.getLoc(), rewriter.getType(), qreg,
- index.getValue(), index.getAttr());
+ rewriter, loc, rewriter.getType(), qreg, index.getValue(),
+ index.getAttr());
newResults.emplace_back(extractOp.getResult());
}
}
@@ -325,7 +332,7 @@ class BaseSignatureAnalyzer {
return {startIdx, paramTypeEnd};
}
- // generate params for calling the decomposition function based on function type requirements
+ // generate params for the decomposition function based on function type requirements
SmallVector generateParams(ValueRange signatureParams, ArrayRef funcParamTypes,
PatternRewriter &rewriter, Location loc)
{
diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
index a4f0fdc30d..5f7c88878c 100644
--- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
+++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
@@ -12,13 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#define DEBUG_TYPE "decompose-lowering"
+#include
+#include
+#include
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/AllocatorBase.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
@@ -26,12 +40,47 @@
#include "DecomposeLoweringImpl.hpp"
+#define DEBUG_TYPE "decompose-lowering"
+
using namespace mlir;
using namespace catalyst::quantum;
namespace catalyst {
namespace quantum {
+SmallVector getDecompRuleResults(func::FuncOp rule, ValueRange operands,
+ PatternRewriter &rewriter)
+{
+ Block &body = rule.front();
+ auto returnOp = cast(body.getTerminator());
+
+ IRMapping mapping;
+ mapping.map(body.getArguments(), operands);
+
+ for (Operation &op : body.without_terminator()) {
+ rewriter.clone(op, mapping);
+ }
+
+ SmallVector results;
+ for (Value operand : returnOp.getOperands()) {
+ results.push_back(mapping.lookupOrDefault(operand));
+ }
+ return results;
+}
+
+bool isInDecompRule(Operation *op)
+{
+ while (auto parentOp = op->getParentOp()) {
+ if (auto funcOp = dyn_cast(parentOp)) {
+ if (funcOp->hasAttr("target_gate")) {
+ return true;
+ }
+ }
+ op = parentOp;
+ }
+ return false;
+}
+
struct DLCustomOpPattern : public OpRewritePattern {
private:
const llvm::StringMap &decompositionRegistry;
@@ -54,18 +103,24 @@ struct DLCustomOpPattern : public OpRewritePattern {
return failure();
}
- // Find the corresponding decomposition function for the op
+ // do not nest decomposition rules, they're applied greedily and this can lead to
+ // cycles/identity rules
+ if (isInDecompRule(op)) {
+ return failure();
+ }
+
+ // Find the corresponding decomposition rule for the op
auto it = decompositionRegistry.find(gateName);
if (it == decompositionRegistry.end()) {
return failure();
}
- func::FuncOp decompFunc = it->second;
+ func::FuncOp rule = it->second;
// For null decomp rules, the signature will not have any quantum values
// This is a deviation from the standard decomp func signature, so we deal with it
// separately
- if (!llvm::any_of(llvm::concat(decompFunc.getFunctionType().getInputs(),
- decompFunc.getFunctionType().getResults()),
+ if (!llvm::any_of(llvm::concat(rule.getFunctionType().getInputs(),
+ rule.getFunctionType().getResults()),
[](const mlir::Type t) {
return isa(t);
})) {
@@ -76,32 +131,32 @@ struct DLCustomOpPattern : public OpRewritePattern {
return success();
}
- // Here is the assumption that the decomposition function must have at least one input and
+ // Here is the assumption that the decomposition rule must have at least one input and
// one result
- assert(decompFunc.getFunctionType().getNumInputs() > 0 &&
+ assert(rule.getFunctionType().getNumInputs() > 0 &&
"Decomposition function must have at least one input");
- assert(decompFunc.getFunctionType().getNumResults() >= 1 &&
+ assert(rule.getFunctionType().getNumResults() >= 1 &&
"Decomposition function must have at least one result");
rewriter.setInsertionPointAfter(op);
- auto enableQreg = llvm::any_of(decompFunc.getFunctionType().getInputs(),
+ auto enableQreg = llvm::any_of(rule.getFunctionType().getInputs(),
[](mlir::Type t) { return isa(t); });
auto analyzer = CustomOpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");
- auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
- auto callOp =
- func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(),
- decompFunc.getSymName(), callOperands);
+ auto operands = analyzer.prepareOperands(rule, rewriter, op.getLoc());
+ SmallVector inlinedFunctionResults = getDecompRuleResults(rule, operands, rewriter);
- // Replace the op with the call op and adjust the insert ops for the qreg mode
- if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) {
- auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
+ // Replace the op with the inlined function and adjust the insert ops for the qreg mode
+ if (inlinedFunctionResults.size() == 1 &&
+ isa(inlinedFunctionResults.front().getType())) {
+ auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(),
+ op.getLoc(), rewriter);
rewriter.replaceOp(op, results);
}
else {
- rewriter.replaceOp(op, callOp->getResults());
+ rewriter.replaceOp(op, inlinedFunctionResults);
}
return success();
@@ -130,6 +185,12 @@ struct DLMultiRZOpPattern : public OpRewritePattern {
return failure();
}
+ // do not nest decomposition rules, they're applied greedily and this can lead to
+ // cycles/identity rules
+ if (isInDecompRule(op)) {
+ return failure();
+ }
+
// Find the corresponding decomposition function for the op
auto numQubits = op.getInQubits().size();
auto MRZNameWithQubits = gateName + "_" + std::to_string(numQubits);
@@ -165,18 +226,19 @@ struct DLMultiRZOpPattern : public OpRewritePattern {
auto analyzer = MultiRZOpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");
- auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
- auto callOp =
- func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(),
- decompFunc.getSymName(), callOperands);
+ auto operands = analyzer.prepareOperands(decompFunc, rewriter, op.getLoc());
+ SmallVector inlinedFunctionResults =
+ getDecompRuleResults(decompFunc, operands, rewriter);
- // Replace the op with the call op and adjust the insert ops for the qreg mode
- if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) {
- auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
+ // Replace the op with the inlined function and adjust the insert ops for the qreg mode
+ if (inlinedFunctionResults.size() == 1 &&
+ isa(inlinedFunctionResults.front().getType())) {
+ auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(),
+ op.getLoc(), rewriter);
rewriter.replaceOp(op, results);
}
else {
- rewriter.replaceOp(op, callOp->getResults());
+ rewriter.replaceOp(op, inlinedFunctionResults);
}
return success();
@@ -205,6 +267,12 @@ struct DLPauliRotOpPattern : public OpRewritePattern {
return failure();
}
+ // do not nest decomposition rules, they're applied greedily and this can lead to
+ // cycles/identity rules
+ if (isInDecompRule(op)) {
+ return failure();
+ }
+
// Find the corresponding decomposition function for the op
auto it = decompositionRegistry.find(gateName);
if (it == decompositionRegistry.end()) {
@@ -226,18 +294,19 @@ struct DLPauliRotOpPattern : public OpRewritePattern {
auto analyzer = PauliRotOpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");
- auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
- auto callOp =
- func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(),
- decompFunc.getSymName(), callOperands);
+ auto operands = analyzer.prepareOperands(decompFunc, rewriter, op.getLoc());
+ SmallVector inlinedFunctionResults =
+ getDecompRuleResults(decompFunc, operands, rewriter);
- // Replace the op with the call op and adjust the insert ops for the qreg mode
- if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) {
- auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
+ // Replace the op with the inlined results and adjust the insert ops for the qreg mode
+ if (inlinedFunctionResults.size() == 1 &&
+ isa(inlinedFunctionResults.front().getType())) {
+ auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(),
+ op.getLoc(), rewriter);
rewriter.replaceOp(op, results);
}
else {
- rewriter.replaceOp(op, callOp->getResults());
+ rewriter.replaceOp(op, inlinedFunctionResults);
}
return success();
diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp
index 522ed966b8..b1be26406c 100644
--- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp
+++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp
@@ -12,34 +12,47 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#define DEBUG_TYPE "decompose-lowering"
+#include
+#include
+#include
-// When we read the decomposition rules module from file,
-// StablehloDialect may not be registered from start.
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/AllocatorBase.h"
+#include "llvm/Support/DebugLog.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/WalkResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Inliner.h"
#include "mlir/Transforms/Passes.h"
-#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/StablehloOps.h" // When we read the decomposition rules module from file, StablehloDialect may not be registered from start.
+#include "Quantum/IR/QuantumDialect.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
+#define DEBUG_TYPE "decompose-lowering"
+
using namespace mlir;
using namespace catalyst::quantum;
namespace catalyst {
namespace quantum {
+
#define GEN_PASS_DEF_DECOMPOSELOWERINGPASS
#define GEN_PASS_DECL_DECOMPOSELOWERINGPASS
#include "Quantum/Transforms/Passes.h.inc"
@@ -96,9 +109,15 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase &decompositionRegistry)
+ llvm::StringMap &decompositionRegistry,
+ llvm::StringSet<> targetRules)
{
module.walk([&](func::FuncOp func) {
+ // if targetRules is provided, only add requested rules
+ if (!targetRules.empty() && !targetRules.contains(func.getName())) {
+ return WalkResult::skip();
+ }
+
if (StringRef targetOp = DecompUtils::getTargetGateName(func); !targetOp.empty()) {
removeUnusedFuncArgs(func);
if (targetOp == "MultiRZ") {
@@ -150,82 +169,33 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBasegetOperandTypes()));
}
- // Remove unused decomposition functions:
- // Since the decomposition functions are marked as public from the frontend,
- // there is no way to remove them with any DCE pass automatically.
- // So we need to manually remove them from the module
- void removeDecompositionFunctions(ModuleOp module,
- llvm::StringMap &decompositionRegistry)
- {
- llvm::DenseSet usedDecompositionFunctions;
-
- module.walk([&](func::CallOp callOp) {
- if (auto targetFunc = module.lookupSymbol(callOp.getCallee())) {
- if (DecompUtils::isDecompositionFunction(targetFunc)) {
- usedDecompositionFunctions.insert(targetFunc);
- }
- }
- });
-
- // remove unused decomposition functions
- module.walk([&](func::FuncOp func) {
- if (DecompUtils::isDecompositionFunction(func) &&
- !usedDecompositionFunctions.contains(func)) {
- func.erase();
- }
- return WalkResult::skip();
- });
- }
-
public:
void runOnOperation() final
{
ModuleOp module = cast(getOperation());
// Step 1: Discover and register all decomposition functions in the module
- discoverAndRegisterDecompositions(module, decompositionRegistry);
+ llvm::StringSet<> targetRules;
+ for (auto rule : targetRulesOption) {
+ targetRules.insert(rule);
+ }
+ discoverAndRegisterDecompositions(module, decompositionRegistry, targetRules);
if (decompositionRegistry.empty()) {
return;
}
- // Step 1.1: Find the target gate set
+ // Step 2: Find the target gate set
findTargetGateSet(module, targetGateSet);
- // Step 2: Canonicalize the module
- RewritePatternSet patternsCanonicalization(&getContext());
- catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization,
- &getContext());
- if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) {
- return signalPassFailure();
- }
-
- // Step 3: Apply the decomposition patterns
+ // Step 3: Apply the decomposition patterns, canonicalizing the insert/extract pairs
RewritePatternSet decompositionPatterns(&getContext());
populateDecomposeLoweringPatterns(decompositionPatterns, decompositionRegistry,
targetGateSet);
- if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) {
- return signalPassFailure();
- }
-
- // Step 4: Inline and canonicalize/CSE the module again
- PassManager pm(&getContext());
- pm.addPass(createInlinerPass());
- pm.addPass(createCanonicalizerPass());
- pm.addPass(createCSEPass());
- if (failed(pm.run(module))) {
- return signalPassFailure();
- }
-
- // Step 5. Remove redundant decomposition functions
- removeDecompositionFunctions(module, decompositionRegistry);
-
- // Step 6. Canonicalize the extract/insert pair
- RewritePatternSet patternsInsertExtract(&getContext());
- catalyst::quantum::InsertOp::getCanonicalizationPatterns(patternsInsertExtract,
+ catalyst::quantum::InsertOp::getCanonicalizationPatterns(decompositionPatterns,
&getContext());
- catalyst::quantum::ExtractOp::getCanonicalizationPatterns(patternsInsertExtract,
+ catalyst::quantum::ExtractOp::getCanonicalizationPatterns(decompositionPatterns,
&getContext());
- if (failed(applyPatternsGreedily(module, std::move(patternsInsertExtract)))) {
+ if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) {
return signalPassFailure();
}
}
diff --git a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp
index a8bf13b12f..5a9425e9cb 100644
--- a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp
+++ b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp
@@ -12,27 +12,47 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#define DEBUG_TYPE "graph-decomposition"
-
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/WalkResult.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "Catalyst/Analysis/ResourceAnalysis.h"
+#include "Catalyst/Analysis/ResourceResult.h"
#include "Catalyst/Transforms/Passes.h"
#include "QRef/Transforms/Passes.h"
#include "Quantum/IR/QuantumDialect.h"
+#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Passes.h"
#include "Quantum/Transforms/QPDLoader.h"
@@ -41,6 +61,8 @@
#include "DGSolver.hpp"
#include "DGTypes.hpp"
+#define DEBUG_TYPE "graph-decomposition"
+
using namespace mlir;
using namespace catalyst::quantum;
using namespace DecompGraph::Core;
@@ -135,9 +157,12 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasepush_back(rule.release());
+ if (!symbolTable.lookup(rule->getName())) {
+ module.getBody()->push_back(rule.release());
+ }
}
}
diff --git a/mlir/test/Quantum/DecomposeLoweringTargetRulesTest.mlir b/mlir/test/Quantum/DecomposeLoweringTargetRulesTest.mlir
new file mode 100644
index 0000000000..6c261d07dc
--- /dev/null
+++ b/mlir/test/Quantum/DecomposeLoweringTargetRulesTest.mlir
@@ -0,0 +1,51 @@
+// Copyright 2026 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RUN: quantum-opt --pass-pipeline='builtin.module(decompose-lowering{target-rules=my_X_decomp,my_Z_decomp})' --split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Test that decompose-lowering only applies the rules requested by the `target-rules` option when present
+
+module @test_module {
+ // CHECK: func.func private @my_X_decomp
+ func.func private @my_X_decomp(%q: !quantum.bit) -> !quantum.bit attributes {target_gate="X"} {
+ %angle = arith.constant 1.57 : f64
+ %out = quantum.custom "RX"(%angle) %q : !quantum.bit
+ return %out : !quantum.bit
+ }
+
+ // CHECK: func.func private @my_Y_decomp
+ func.func private @my_Y_decomp(%q: !quantum.bit) -> !quantum.bit attributes {target_gate="Y"} {
+ %angle = arith.constant 1.57 : f64
+ %out = quantum.custom "RY"(%angle) %q : !quantum.bit
+ return %out : !quantum.bit
+ }
+
+ // CHECK: func.func private @my_Z_decomp
+ func.func private @my_Z_decomp(%q: !quantum.bit) -> !quantum.bit attributes {target_gate="Z"} {
+ %angle = arith.constant 1.57 : f64
+ %out = quantum.custom "RZ"(%angle) %q : !quantum.bit
+ return %out : !quantum.bit
+ }
+
+ // CHECK: [[q:%.+]] = quantum.alloc_qb
+ // CHECK: [[x_out:%.+]] = quantum.custom "RX"(%{{.+}}) [[q]]
+ // CHECK: [[y_out:%.+]] = quantum.custom "Y"() [[x_out]]
+ // CHECK: [[z_out:%.+]] = quantum.custom "RZ"(%{{.+}}) [[y_out]]
+ // CHECK: quantum.dealloc_qb [[z_out]]
+ %0 = quantum.alloc_qb : !quantum.bit
+ %1 = quantum.custom "X"() %0 : !quantum.bit
+ %2 = quantum.custom "Y"() %1 : !quantum.bit
+ %3 = quantum.custom "Z"() %2 : !quantum.bit
+ quantum.dealloc_qb %3 : !quantum.bit
+}
diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir
index b7abe6f203..7a6d9d3566 100644
--- a/mlir/test/Quantum/DecomposeLoweringTest.mlir
+++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir
@@ -41,8 +41,8 @@ module @two_hadamards {
return %4 : tensor<4xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @Hadamard_to_RY_decomp
func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
@@ -55,6 +55,8 @@ module @two_hadamards {
// -----
// Test single Hadamard decomposition
+
+// CHECK-LABEL: module @single_hadamard
module @single_hadamard {
func.func @test_single_hadamard() -> !quantum.bit {
// CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
@@ -73,8 +75,8 @@ module @single_hadamard {
return %2 : !quantum.bit
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @Hadamard_to_RY_decomp
func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
@@ -85,6 +87,8 @@ module @single_hadamard {
}
// -----
+
+// CHECK-LABEL: module @recursive
module @recursive {
func.func public @test_recursive() -> tensor<4xf64> attributes {quantum.node} {
%0 = quantum.alloc( 2) : !quantum.reg
@@ -112,15 +116,15 @@ module @recursive {
return %4 : tensor<4xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @Hadamard_to_RY_decomp
func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
%out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit
return %out_qubits_0 : !quantum.bit
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @RZRY_decomp
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @RZRY_decomp
func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
@@ -131,6 +135,8 @@ module @recursive {
}
// -----
+
+// CHECK-LABEL: module @recursive
module @recursive {
func.func public @test_recursive() -> tensor<4xf64> attributes {quantum.node} {
%0 = quantum.alloc( 2) : !quantum.reg
@@ -158,15 +164,15 @@ module @recursive {
return %4 : tensor<4xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @Hadamard_to_RY_decomp
func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
%out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit
return %out_qubits_0 : !quantum.bit
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @RZRY_decomp
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @RZRY_decomp
func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
@@ -179,6 +185,8 @@ module @recursive {
// -----
// Test parametric gates and wires
+
+// CHECK-LABEL: module @param_rxry
module @param_rxry {
func.func public @test_param_rxry(%arg0: tensor, %arg1: tensor) -> tensor<2xf64> attributes {quantum.node} {
%c0_i64 = arith.constant 0 : i64
@@ -209,7 +217,7 @@ module @param_rxry {
}
// Decomposition function expects tensor while operation provides f64
- // CHECK-NOT: func.func private @ParametrizedRX_decomp
+ // CHECK: func.func private @ParametrizedRXRY_decomp
func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: !quantum.bit) -> !quantum.bit
attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} {
%extracted = tensor.extract %arg0[] : tensor
@@ -219,92 +227,60 @@ module @param_rxry {
return %out_qubits_1 : !quantum.bit
}
}
-// -----
-
-// Test parametric gates and wires
-module @param_rxry_2 {
- func.func public @test_param_rxry_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xf64> attributes {quantum.node} {
- %c0_i64 = arith.constant 0 : i64
-
- // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg
- %0 = quantum.alloc( 1) : !quantum.reg
-
- // CHECK: [[WIRE:%.+]] = tensor.extract %arg2[] : tensor
- %extracted = tensor.extract %arg2[] : tensor
-
- // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit
- %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit
-
- // CHECK: [[PARAM_0:%.+]] = tensor.extract %arg0[] : tensor
- %param_0 = tensor.extract %arg0[] : tensor
-
- // CHECK: [[PARAM_1:%.+]] = tensor.extract %arg1[] : tensor
- %param_1 = tensor.extract %arg1[] : tensor
-
- // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM_0]]) [[QUBIT]] : !quantum.bit
- // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM_1]]) [[QUBIT1]] : !quantum.bit
- // CHECK-NOT: quantum.custom "ParametrizedRXRY"
- %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0, %param_1) %1 : !quantum.bit
-
- // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit
- %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
- %3 = quantum.compbasis qreg %2 : !quantum.obs
- %4 = quantum.probs %3 : tensor<2xf64>
- quantum.dealloc %2 : !quantum.reg
- return %4 : tensor<2xf64>
- }
- // Decomposition function expects tensor while operation provides f64
- // CHECK-NOT: func.func private @ParametrizedRX_decomp
- func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: tensor, %arg2: !quantum.bit) -> !quantum.bit
- attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} {
- %extracted_param_0 = tensor.extract %arg0[] : tensor
- %out_qubits = quantum.custom "RX"(%extracted_param_0) %arg2 : !quantum.bit
- %extracted_param_1 = tensor.extract %arg1[] : tensor
- %out_qubits_1 = quantum.custom "RY"(%extracted_param_1) %out_qubits : !quantum.bit
- return %out_qubits_1 : !quantum.bit
- }
-}
// -----
// Test recursive and qreg-based gate decomposition
+
+// CHECK-LABEL: module @qreg_base_circuit
module @qreg_base_circuit {
func.func public @test_qreg_base_circuit() -> tensor<2xf64> attributes {quantum.node} {
- // CHECK: [[CST:%.+]] = arith.constant 1.000000e+00 : f64
+ // CHECK-DAG: [[cmp_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor
+ // CHECK-DAG: [[test_angle:%.+]] = arith.constant 1.000000e+00 : f64
+ // CHECK-DAG: [[index_tensor:%.+]] = arith.constant dense<0> : tensor<1xi64>
+ // CHECK-DAG: [[cmp_1:%.+]] = arith.constant dense<1.000000e+00> : tensor
+ // CHECK: [[reg0:%.+]] = quantum.alloc( 1) : !quantum.reg
%cst = arith.constant 1.000000e+00 : f64
-
- // CHECK: [[CST_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor
- // CHECK: [[CST_1:%.+]] = arith.constant dense<0> : tensor<1xi64>
- // CHECK: [[CST_2:%.+]] = arith.constant dense<1.000000e+00> : tensor
- // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg
%0 = quantum.alloc( 1) : !quantum.reg
- // CHECK: [[EXTRACT_QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit
- // CHECK: [[MRES:%.+]], [[OUT_QUBIT:%.+]] = quantum.measure [[EXTRACT_QUBIT]] : i1, !quantum.bit
- // CHECK: [[REG1:%.+]] = quantum.insert [[REG]][ 0], [[OUT_QUBIT]] : !quantum.reg, !quantum.bit
- // CHECK: [[COMPARE:%.+]] = stablehlo.compare NE, [[CST_2]], [[CST_0]], FLOAT : (tensor, tensor) -> tensor
- // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[COMPARE]][] : tensor
- // CHECK: [[CONDITIONAL:%.+]] = scf.if [[EXTRACTED]] -> (!quantum.reg) {
- // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST_1]] [0:1] : (tensor<1xi64>) -> tensor<1xi64>
- // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor
- // CHECK: [[EXTRACTED_3:%.+]] = tensor.extract [[RESHAPE1]][] : tensor
- // CHECK: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[EXTRACTED_3]] : tensor<1xi64>
- // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[FROM_ELEMENTS]] [0:1] : (tensor<1xi64>) -> tensor<1xi64>
- // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor
- // CHECK: [[EXTRACTED_4:%.+]] = tensor.extract [[RESHAPE2]][] : tensor
- // CHECK: [[EXTRACT1:%.+]] = quantum.extract [[REG1]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit
- // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT1]] : !quantum.bit
- // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_4]]], [[RZ1]] : !quantum.reg, !quantum.bit
- // CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit
- // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_3]]], [[EXTRACT2]] : !quantum.reg, !quantum.bit
- // CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit
- // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT3]] : !quantum.bit
- // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_4]]], [[RZ2]] : !quantum.reg, !quantum.bit
- // CHECK: [[EXTRACT4:%.+]] = quantum.extract [[INSERT3]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit
- // CHECK: [[INSERT4:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_3]]], [[EXTRACT4]] : !quantum.reg, !quantum.bit
- // CHECK: scf.yield [[INSERT4]] : !quantum.reg
+
+ // CHECK: [[q0:%.+]] = quantum.extract [[reg0]][ 0] : !quantum.reg -> !quantum.bit
+ // CHECK: [[meas:%.+]], [[q1:%.+]] = quantum.measure [[q0]] : i1, !quantum.bit
+ // CHECK: [[reg1:%.+]] = quantum.insert [[reg0]][ 0], [[q1]] : !quantum.reg, !quantum.bit
+ // CHECK: [[cmp:%.+]] = stablehlo.compare NE, [[cmp_1]], [[cmp_0]], FLOAT : (tensor, tensor) -> tensor
+ // CHECK: [[cond:%.+]] = tensor.extract [[cmp]][] : tensor
+ // CHECK: [[condresult:%.+]] = scf.if [[cond]] -> (!quantum.reg) {
+ // CHECK: [[slice0:%.+]] = stablehlo.slice [[index_tensor]] [0:1] : (tensor<1xi64>) -> tensor<1xi64>
+ // CHECK: [[reshape0:%.+]] = stablehlo.reshape [[slice0]] : (tensor<1xi64>) -> tensor
+ // CHECK: [[index0:%.+]] = tensor.extract [[reshape0]][] : tensor
+ // CHECK: [[fromelements0:%.+]] = tensor.from_elements [[index0]] : tensor<1xi64>
+ // CHECK: [[slice1:%.+]] = stablehlo.slice [[fromelements0]] [0:1] : (tensor<1xi64>) -> tensor<1xi64>
+ // CHECK: [[reshape1:%.+]] = stablehlo.reshape [[slice1]] : (tensor<1xi64>) -> tensor
+ // CHECK: [[index1:%.+]] = tensor.extract [[reshape1]][] : tensor
+ // CHECK: [[q2:%.+]] = quantum.extract [[reg1]][[[index1]]] : !quantum.reg -> !quantum.bit
+ // CHECK: [[q3:%.+]] = quantum.custom "RZ"([[test_angle]]) [[q2]] : !quantum.bit
+ // CHECK: [[index2:%.+]] = tensor.extract [[reshape1]][]
+ // CHECK: [[reg2:%.+]] = quantum.insert [[reg1]][[[index2]]], [[q3]] : !quantum.reg, !quantum.bit
+ // CHECK: [[q4:%.+]] = quantum.extract [[reg2]][[[index0]]] : !quantum.reg -> !quantum.bit
+ // CHECK: [[slice3:%.+]] = stablehlo.slice [[index_tensor]] [0:1] : (tensor<1xi64>) -> tensor<1xi64>
+ // CHECK: [[reshape3:%.+]] = stablehlo.reshape [[slice3]] : (tensor<1xi64>) -> tensor
+ // CHECK: [[extract6:%.+]] = tensor.extract [[reshape0]][]
+ // CHECK: [[reg3:%.+]] = quantum.insert [[reg1]][[[extract6]]], [[q4]] : !quantum.reg, !quantum.bit
+ // CHECK: [[index3:%.+]] = tensor.extract [[reshape3]][]
+ // CHECK: [[fromelements1:%.+]] = tensor.from_elements [[index3]]
+ // CHECK: [[slice4:%.+]] = stablehlo.slice [[fromelements1]] [0:1]
+ // CHECK: [[reshape4:%.+]] = stablehlo.reshape [[slice4]]
+ // CHECK: [[index4:%.+]] = tensor.extract [[reshape4]][]
+ // CHECK: [[q5:%.+]] = quantum.extract [[reg3]][[[index4]]] : !quantum.reg -> !quantum.bit
+ // CHECK: [[q6:%.+]] = quantum.custom "RZ"([[test_angle]]) [[q5]] : !quantum.bit
+ // CHECK: [[index5:%.+]] = tensor.extract [[reshape4]][]
+ // CHECK: [[reg4:%.+]] = quantum.insert [[reg3]][[[index5]]], [[q6]] : !quantum.reg, !quantum.bit
+ // CHECK: [[q7:%.+]] = quantum.extract [[reg4]][[[index3]]] : !quantum.reg -> !quantum.bit
+ // CHECK: [[index6:%.+]] = tensor.extract [[reshape3]][]
+ // CHECK: [[out:%.+]] = quantum.insert [[reg3]][[[index6]]], [[q7]] : !quantum.reg, !quantum.bit
+ // CHECK: scf.yield [[out]] : !quantum.reg
// CHECK: } else {
- // CHECK: scf.yield [[REG1]] : !quantum.reg
+ // CHECK: scf.yield [[reg1]] : !quantum.reg
// CHECK: }
// CHECK-NOT: quantum.custom "Test"
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
@@ -318,8 +294,8 @@ module @qreg_base_circuit {
return %4 : tensor<2xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @Test_rule_1
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @Test_rule_1
func.func private @Test_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg
attributes {target_gate = "Test", llvm.linkage = #llvm.linkage} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor
@@ -352,8 +328,8 @@ module @qreg_base_circuit {
return %1 : !quantum.reg
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @RzDecomp_rule_1
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @RzDecomp_rule_1
func.func private @RzDecomp_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg
attributes {target_gate = "RzDecomp", llvm.linkage = #llvm.linkage} {
%0 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64>
@@ -370,12 +346,12 @@ module @qreg_base_circuit {
// -----
+// CHECK-LABEL: module @multi_wire_cnot_decomposition
module @multi_wire_cnot_decomposition {
func.func public @test_cnot_decomposition() -> tensor<4xf64> attributes {quantum.node} {
%0 = quantum.alloc( 2) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
-
// CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
// CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
// CHECK: [[WIRE_TENSOR:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64>
@@ -388,13 +364,22 @@ module @multi_wire_cnot_decomposition {
// CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit
// CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit
// CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit
+ // CHECK: [[index:%.+]] = tensor.extract [[RESHAPE2]][]
+ // CHECK: [[INSERT_TARGET:%.+]] = quantum.insert [[REG]][[[index]]], [[RY1]] : !quantum.reg, !quantum.bit
// CHECK: [[EXTRACTED2:%.+]] = tensor.extract [[RESHAPE1]][] : tensor
- // CHECK: [[QUBIT0:%.+]] = quantum.extract [[REG]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit
- // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[RY1]] : !quantum.bit, !quantum.bit
- // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit
- // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit
+ // CHECK: [[QUBIT0:%.+]] = quantum.extract [[INSERT_TARGET]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit
+ // CHECK: [[index2:%.+]] = tensor.extract [[RESHAPE2]][]
+ // CHECK: [[QUBIT1_UPDATED:%.+]] = quantum.extract [[INSERT_TARGET]][[[index2]]] : !quantum.reg -> !quantum.bit
+ // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[QUBIT1_UPDATED]] : !quantum.bit, !quantum.bit
+ // CHECK: [[index3:%.+]] = tensor.extract [[RESHAPE1]][]
+ // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT_TARGET]][[[index3]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit
+ // CHECK: [[index4:%.+]] = tensor.extract [[RESHAPE2]][]
+ // CHECK: [[INSERT_CZ1:%.+]] = quantum.insert [[INSERT2]][[[index4]]], [[CZ_RESULT]]#1 : !quantum.reg, !quantum.bit
+ // CHECK: [[TARGET_AFTER_CZ:%.+]] = quantum.extract [[INSERT_CZ1]][{{%.+}}] : !quantum.reg -> !quantum.bit
+ // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[TARGET_AFTER_CZ]] : !quantum.bit
// CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit
- // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[RY2]] : !quantum.reg, !quantum.bit
+ // CHECK: [[index5:%.+]] = tensor.extract [[RESHAPE2]][]
+ // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT_CZ1]][[[index5]]], [[RY2]] : !quantum.reg, !quantum.bit
// CHECK: [[FINAL_QUBIT0:%.+]] = quantum.extract [[INSERT3]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[FINAL_QUBIT1:%.+]] = quantum.extract [[INSERT3]][ 1] : !quantum.reg -> !quantum.bit
// CHECK-NOT: quantum.custom "CNOT"
@@ -410,8 +395,8 @@ module @multi_wire_cnot_decomposition {
return %8 : tensor<4xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @CNOT_rule_cz_rz_ry
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @CNOT_rule_cz_rz_ry
func.func private @CNOT_rule_cz_rz_ry(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} {
// CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H)
%cst = arith.constant 1.5707963267948966 : f64
@@ -456,6 +441,7 @@ module @multi_wire_cnot_decomposition {
// -----
+// CHECK-LABEL: module @cnot_alternative_decomposition
module @cnot_alternative_decomposition {
func.func public @test_cnot_alternative_decomposition() -> tensor<4xf64> attributes {quantum.node} {
%0 = quantum.alloc( 2) : !quantum.reg
@@ -485,8 +471,8 @@ module @cnot_alternative_decomposition {
return %8 : tensor<4xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func private @CNOT_rule_h_cnot_h
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func private @CNOT_rule_h_cnot_h
func.func private @CNOT_rule_h_cnot_h(%arg0: !quantum.bit, %arg1: !quantum.bit) -> (!quantum.bit, !quantum.bit) attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} {
// CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H)
%cst = arith.constant 1.5707963267948966 : f64
@@ -509,6 +495,7 @@ module @cnot_alternative_decomposition {
// -----
+// CHECK-LABEL: module @mcm_example
module @mcm_example {
func.func public @test_mcm_hadamard() -> tensor<2xf64> attributes {quantum.node} {
%0 = quantum.alloc( 1) : !quantum.reg
@@ -516,9 +503,9 @@ module @mcm_example {
%mres, %out_qubit = quantum.measure %1 : i1, !quantum.bit
%2 = quantum.insert %0[ 0], %out_qubit : !quantum.reg, !quantum.bit
- // CHECK: [[RZ_QUBIT:%.+]] = quantum.custom "RZ"([[CST_0:%.+]])
- // CHECK: [[RY_QUBIT:%.+]] = quantum.custom "RY"([[CST_1:%.+]]) [[RZ_QUBIT]] : !quantum.bit
- // CHECK: [[REG_1:%.+]] = quantum.insert [[REG:%.+]][[[EXTRACTED:%.+]]], [[RY_QUBIT]] : !quantum.reg, !quantum.bit
+ // CHECK: quantum.custom "RZ"
+ // CHECK: quantum.custom "RY"
+
// CHECK-NOT: quantum.custom "Hadamard"
%3 = quantum.extract %2[ 0] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "Hadamard"() %3 : !quantum.bit
@@ -530,8 +517,8 @@ module @mcm_example {
return %6 : tensor<2xf64>
}
- // Decomposition function should be applied and removed from the module
- // CHECK-NOT: func.func public @rz_ry
+ // Decomposition function should be retained for future passes
+ // CHECK: func.func public @rz_ry
func.func public @rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
@@ -555,15 +542,16 @@ module @mcm_example {
// -----
+// CHECK-LABEL: module @circuit_with_multirz
module @circuit_with_multirz {
func.func public @test_with_multirz() -> tensor<4xf64> attributes {quantum.node} {
%0 = quantum.alloc( 2) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
// CHECK: func.func public @test_with_multirz() -> tensor<4xf64>
- // CHECK: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64
- // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
- // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
- // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg
+ // CHECK-DAG: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
+ // CHECK-DAG: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
+ // CHECK-DAG: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64
+ // CHECK-DAG: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg
// CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_RZ]]) {{%.+}} : !quantum.bit
// CHECK-NOT: quantum.multirz
@@ -584,7 +572,7 @@ module @circuit_with_multirz {
return %4 : tensor<4xf64>
}
- // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ // CHECK: func.func private @Hadamard_to_RY_decomp
func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
@@ -593,8 +581,8 @@ module @circuit_with_multirz {
return %out_qubits_1 : !quantum.bit
}
- // CHECK-NOT: func.func private @_multi_rz_decomposition_wires_1
- func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} {
+ // CHECK: func.func private @_multi_rz_decomposition_wires_1
+ func.func private @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} {
%0 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64>
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor
%extracted = tensor.extract %1[] : tensor
@@ -610,6 +598,7 @@ module @circuit_with_multirz {
// -----
+// CHECK-LABEL: module @qreg_at_not_first_arg
module @qreg_at_not_first_arg {
func.func public @test_qreg_at_not_first_arg() attributes {quantum.node} {
// CHECK: [[wire_tensor:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64>
@@ -635,7 +624,7 @@ module @qreg_at_not_first_arg {
return
}
- // CHECK-NOT: func.func private @my_cnot
+ // CHECK: func.func private @my_cnot
func.func private @my_cnot(%arg0: tensor<2xi64>, %arg1: !quantum.reg) -> !quantum.reg attributes {target_gate = "CNOT"} {
%0 = stablehlo.slice %arg0 [0:1] : (tensor<2xi64>) -> tensor<1xi64>
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor
@@ -654,6 +643,7 @@ module @qreg_at_not_first_arg {
// -----
+// CHECK-LABEL: module @test_paulirot
module @test_paulirot {
func.func @test() attributes {quantum.node} {
%pi = arith.constant 3.1 : f64
@@ -678,7 +668,7 @@ module @test_paulirot {
return
}
- // CHECK-NOT: my_paulirot_decomp
+ // CHECK: my_paulirot_decomp
func.func private @my_paulirot_decomp(%inreg : !quantum.reg, %angle_tensor : tensor, %q_tensor : tensor<3xi64>) -> !quantum.reg attributes {target_gate = "paulirotZXY"} {
%pi_by_2 = arith.constant 1.57 : f64
%m_pi_by_2 = arith.constant -1.57 : f64
@@ -717,6 +707,7 @@ module @test_paulirot {
// -----
+// CHECK-LABEL: module @null_decomp_rule
module @null_decomp_rule{
func.func public @test_null_decomp_rule() attributes {quantum.node} {
// CHECK: [[reg:%.+]] = quantum.alloc( 1)
@@ -733,7 +724,7 @@ module @null_decomp_rule{
return
}
- // CHECK-NOT: func.func private @null_decomp
+ // CHECK: func.func private @null_decomp
func.func private @null_decomp() attributes {target_gate = "PauliX"} {
return
}
@@ -741,6 +732,7 @@ module @null_decomp_rule{
// -----
+// CHECK-LABEL: module @different_qreg_values
module @different_qreg_values{
func.func public @circuit() attributes {quantum.node} {
// CHECK: [[wire_tensor:%.+]] = arith.constant dense<[2, 1]> : tensor<2xi64>
@@ -776,7 +768,7 @@ module @different_qreg_values{
return
}
- // CHECK-NOT: func.func private @my_cnot
+ // CHECK: func.func private @my_cnot
func.func private @my_cnot(%arg0: tensor<2xi64>, %arg1: !quantum.reg) -> !quantum.reg attributes {target_gate = "CNOT"} {
%0 = stablehlo.slice %arg0 [0:1] : (tensor<2xi64>) -> tensor<1xi64>
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor