diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index ce4b2ed980..0bc19155f7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -205,6 +205,9 @@
Internal changes ⚙️
+* The `graph-decomposition` pass now performs far less IR manipulation.
+ [(#2977)](https://github.com/PennyLaneAI/catalyst/pull/2977)
+
* Update tests to not use global capture toggle where possible.
[(#2964)](https://github.com/PennyLaneAI/catalyst/pull/2964)
diff --git a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp
index 5a9425e9cb..c1ec170efb 100644
--- a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp
+++ b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp
@@ -104,10 +104,7 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase setOfOps;
std::vector setOfRules;
- llvm::StringMap> ruleNameToFuncOp;
llvm::StringSet<> userRuleNames;
- llvm::SmallVector>
- allUserRules; // includes rules unused in this decomp
llvm::StringMap opToFixedDecompName;
llvm::StringMap> opToAltDecompNames;
WeightedGateset targetGateSet;
@@ -128,7 +125,9 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase(rule->getName())) {
- module.getBody()->push_back(rule.release());
- }
- }
}
private:
@@ -236,66 +228,111 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase> &ruleRegistry)
+ LogicalResult addRuleNode(mlir::func::FuncOp rule, std::vector &ruleNodes)
+ {
+ llvm::StringRef ruleName = rule.getName();
+
+ // 1. Mandatory Attribute Check (Target Gate and Resources)
+ auto targetGateAttr = rule->getAttrOfType("target_gate");
+ auto resourcesAttr = rule->getAttrOfType("resources");
+ if (!targetGateAttr) {
+ llvm::errs() << "Cannot parse decomposition rule " << ruleName
+ << " without the `target_gate` attribute.\n";
+ LDBG() << rule;
+ return failure();
+ }
+
+ // Ensure resources
+ if (!resourcesAttr) {
+ ResourceAnalysis analysis(rule);
+ if (const ResourceResult *flat = analysis.getFlattenedResource(rule.getName())) {
+ rule->setAttr("resources", buildResourceDict(&getContext(), *flat));
+ }
+ resourcesAttr = rule->getAttrOfType("resources");
+ }
+
+ // 2. Extract 'operations' dictionary from resources
+ auto operations = mlir::dyn_cast_or_null(resourcesAttr.get("operations"));
+ if (!operations) {
+ llvm::errs() << "Cannot parse resource for decomposition rule " << ruleName
+ << " without `operations` attribute.\n";
+ LDBG() << rule;
+ return failure();
+ }
+
+ // 3. Populate RuleNode
+ RuleNode ruleNode;
+ ruleNode.name = ruleName.str();
+ ruleNode.output = parseOperator(targetGateAttr.getValue());
+
+ for (const auto &namedAttr : operations) {
+ if (auto intAttr = mlir::dyn_cast(namedAttr.getValue())) {
+ ruleNode.inputs.push_back({parseOperator(namedAttr.getName().strref()),
+ static_cast(intAttr.getInt())});
+ }
+ }
+
+ // 4. Add RuleNode
+ ruleNodes.push_back(std::move(ruleNode));
+ return success();
+ }
+
+ LogicalResult loadBuiltInDecompositionRules(llvm::StringRef filename,
+ std::vector &ruleNodes)
{
mlir::MLIRContext *context = &getContext();
+ mlir::ModuleOp module = getOperation();
mlir::ParserConfig config(context);
mlir::OwningOpRef moduleOp =
mlir::parseSourceFile(filename, config);
+ SymbolTable symbolTable(module);
+
if (!moduleOp) {
- mlir::emitError(mlir::UnknownLoc::get(context))
- << "failed to load built-in decomposition rules from '" << filename
- << "': the rules file could not be parsed";
- return;
+ llvm::errs() << "failed to load built-in decomposition rules from '" << filename
+ << "': the rules file could not be parsed\n";
+ return failure();
}
for (auto rule : llvm::make_early_inc_range(moduleOp.get().getOps())) {
- rule->remove();
- ruleRegistry.push_back(std::move(rule));
+ if (failed(addRuleNode(rule, ruleNodes))) {
+ return failure();
+ }
+ // avoid double-insertion
+ if (!symbolTable.lookup(rule.getName())) {
+ rule->remove();
+ module.push_back(std::move(rule));
+ }
}
- return;
+ return success();
}
/**
- * @brief Remove user rules from the module, loading into
+ * @brief Load the listed user rules into the set of RuleNodes for the graph.
*/
- LogicalResult
- loadUserDecompositionRules(llvm::StringSet<> &userRuleNames,
- llvm::SmallVector> &graphRules,
- llvm::SmallVector> &rules)
+ LogicalResult loadUserDecompositionRules(llvm::StringSet<> &userRuleNames,
+ std::vector &ruleNodes)
{
mlir::ModuleOp module = getOperation();
if (userRuleNames.empty()) {
return success();
}
- PassManager pm(&getContext());
- pm.addPass(createRegisterDecompRuleResourcePass());
- if (failed(pm.run(module))) {
- module.emitError() << "failed to load user decomposition rules: unable to run resource "
- "annotation pass";
- return failure();
- }
-
- llvm::SmallVector userRules;
-
- module.walk([&](mlir::func::FuncOp func) {
+ WalkResult walkResult = module.walk([&](mlir::func::FuncOp func) {
if (func->hasAttr("target_gate")) {
- userRules.push_back(func);
if (userRuleNames.contains(func.getName())) {
- graphRules.push_back(mlir::OwningOpRef(func.clone()));
+ if (failed(addRuleNode(func, ruleNodes))) {
+ return WalkResult::interrupt();
+ }
}
}
return WalkResult::skip();
});
- for (auto rule : llvm::make_early_inc_range(userRules)) {
- rule->remove();
- rules.push_back(std::move(rule));
+ if (walkResult.wasInterrupted()) {
+ return failure();
}
+
return success();
}
@@ -305,19 +342,29 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase> &ruleRegistry)
+ mlir::LogicalResult loadPauliRotRules(std::vector &ruleNodes)
{
mlir::ModuleOp module = getOperation();
MLIRContext *context = &getContext();
llvm::StringSet<> addedWords;
+ // Add words from existing paulirot rules
+ module.walk([&](mlir::func::FuncOp func) {
+ if (func->hasAttr("target_gate")) {
+ if (func.getName().starts_with("paulirot_decomp_rule_")) {
+ addedWords.insert(func.getName().drop_front(21));
+ }
+ }
+ });
llvm::SmallVector pauliRotOps;
module.walk([&](quantum::PauliRotOp op) { pauliRotOps.push_back(op); });
if (!pauliRotOps.empty()) {
- loadQPD(libQPDPath, libpythonPath);
+ if (!loadQPD(libQPDPath, libpythonPath)) {
+ llvm::errs() << "failed to load libQuantumPythonCallbacks\n";
+ return failure();
+ }
}
for (quantum::PauliRotOp pauliRot : pauliRotOps) {
@@ -360,20 +407,34 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasesetName((outOp->getName() + "_" + pauliWord).str()); // unique name per pauliword
funcOp->setAttr("target_gate", mlir::StringAttr::get(context, "paulirot" + pauliWord));
- auto analysis = ResourceAnalysis(funcOp);
- if (const ResourceResult *flat = analysis.getFlattenedResource(funcOp.getName())) {
- funcOp->setAttr("resources", buildResourceDict(context, *flat));
+ if (failed(addRuleNode(funcOp, ruleNodes))) {
+ return failure();
}
-
- ruleRegistry.push_back(std::move(outOp));
+ LDBG() << "adding rule " << funcOp.getName();
+ module.push_back(std::move(outOp.release()));
}
-
return success();
}
+ bool isInDecompRule(Operation *op)
+ {
+ while (auto parentOp = op->getParentOp()) {
+ if (auto funcOp = dyn_cast(parentOp)) {
+ if (funcOp->hasAttr("target_gate")) {
+ return true;
+ }
+ }
+ op = parentOp;
+ }
+ return false;
+ }
+
void getOperators(std::vector &operators)
{
getOperation().walk([&](quantum::QuantumGate op) {
+ if (isInDecompRule(op)) {
+ return;
+ }
OperatorNode node;
node.numWires = op.getNonCtrlQubitOperands().size();
node.adjoint = op.getAdjointFlag();
@@ -455,84 +516,22 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase &rules,
- llvm::StringSet<> &userRuleNames,
- llvm::SmallVector> &userRules,
- llvm::StringMap> &ruleNameToFuncOp)
+ LogicalResult getRuleNodes(llvm::StringRef filename, std::vector &rules,
+ llvm::StringSet<> &userRuleNames)
{
- llvm::SmallVector> graphRules;
-
- // Load rules from bytecode and user-defined rules
- loadBuiltInDecompositionRules(filename, graphRules);
- if (failed(loadPauliRotRules(graphRules))) {
- return signalPassFailure();
- }
- if (failed(loadUserDecompositionRules(userRuleNames, graphRules, userRules))) {
- return signalPassFailure();
- }
-
- for (auto &ruleOpRef : graphRules) {
- mlir::func::FuncOp func = ruleOpRef.get();
- llvm::StringRef ruleName = func.getName();
+ // Load pre-compiled rules (ignore failure, we can try to solve without)
+ std::ignore = loadBuiltInDecompositionRules(filename, rules);
- // 1. Mandatory Attribute Check (Target Gate and Resources)
- auto targetGateAttr = func->getAttrOfType("target_gate");
- auto resourcesAttr = func->getAttrOfType("resources");
- if (!targetGateAttr || !resourcesAttr)
- continue;
-
- // 2. Extract 'operations' dictionary from resources
- auto operations =
- mlir::dyn_cast_or_null(resourcesAttr.get("operations"));
- if (!operations)
- continue;
-
- // 3. Populate RuleNode
- RuleNode ruleNode;
- ruleNode.name = ruleName.str();
- ruleNode.output = parseOperator(targetGateAttr.getValue());
-
- for (const auto &namedAttr : operations) {
- if (auto intAttr = mlir::dyn_cast(namedAttr.getValue())) {
- ruleNode.inputs.push_back({parseOperator(namedAttr.getName().strref()),
- static_cast(intAttr.getInt())});
- }
- }
-
- // 4. Finalize: move the OpRef to the map to keep IR alive and store the node
- ruleNameToFuncOp[ruleNode.name] = std::move(ruleOpRef);
- rules.push_back(std::move(ruleNode));
+ // Lower and load compile-time rules
+ if (failed(loadPauliRotRules(rules))) {
+ return failure();
}
- }
- /**
- * @brief Insert the decomposition rules picked by the graph solver into the module for
- * later use in the decompose-lowering patterns to apply the decomposition rules and rewrite
- * the quantum operations.
- *
- * @param solution The chosen decomposition rules from the graph solver.
- * @param ruleNameToFuncOp A mapping from rule names to their corresponding function
- * operations.
- */
- void insertChosenRules(GraphResult &solution,
- llvm::StringMap> &ruleNameToFuncOp)
- {
- mlir::ModuleOp module = getOperation();
- for (const auto &[_, chosenRule] : solution) {
- if (chosenRule.isBasis) {
- continue; // skip basis rules as they don't correspond to actual decomposition
- // functions to insert
- }
- auto it = ruleNameToFuncOp.find(chosenRule.ruleName);
-
- if (it == ruleNameToFuncOp.end() || !it->second) {
- // skip if the rule is not found or
- // the function op is null or
- // it is already moved
- continue;
- }
- module.push_back(it->second.release());
+ // Load user-rules
+ if (failed(loadUserDecompositionRules(userRuleNames, rules))) {
+ return failure();
}
+ return success();
}
/**
@@ -573,7 +572,8 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase !quantum.bit {
// XZ: PauliX
// XZ: PauliZ
%qout = quantum.custom "PauliY"() %q : !quantum.bit
+
+ // needed to ensure we don't match in the following decomposition rules
+ // CHECK: return
return %qout : !quantum.bit
}
+// CHECK-LABEL: y_to_ry
func.func @y_to_ry(%q0 : !quantum.bit) -> !quantum.bit attributes {target_gate="PauliY"} {
%pi = arith.constant 3.14 : f64
%negpiby2 = arith.constant -1.57 : f64
@@ -38,6 +42,7 @@ func.func @y_to_ry(%q0 : !quantum.bit) -> !quantum.bit attributes {target_gate="
return %q1 : !quantum.bit
}
+// CHECK-LABEL: y_to_x_z
func.func @y_to_x_z(%q0 : !quantum.bit) -> !quantum.bit attributes {target_gate="PauliY"} {
%q1 = quantum.custom "PauliX"() %q0 : !quantum.bit
%q2 = quantum.custom "PauliZ"() %q1 : !quantum.bit