From 3648480198859f05155d6e49a97db1c3939ea90c Mon Sep 17 00:00:00 2001 From: River McCubbin Date: Thu, 25 Jun 2026 13:12:49 -0400 Subject: [PATCH 1/4] update pass --- .../Transforms/graph_decomposition.cpp | 237 +++++++++--------- 1 file changed, 122 insertions(+), 115 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp index 5a9425e9cb..351c492537 100644 --- a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp +++ b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp @@ -128,7 +128,9 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase(rule->getName())) { - module.getBody()->push_back(rule.release()); - } - } } private: @@ -236,36 +231,79 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase> &ruleRegistry) + LogicalResult addRuleNode(mlir::func::FuncOp rule, std::vector &ruleNodes) + { + llvm::StringRef ruleName = rule.getName(); + + // 1. Mandatory Attribute Check (Target Gate and Resources) + auto targetGateAttr = rule->getAttrOfType("target_gate"); + auto resourcesAttr = rule->getAttrOfType("resources"); + if (!targetGateAttr || !resourcesAttr) { + llvm::errs() << "Cannot parse decomposition rule without `target_gate` and `resources` " + "attributes.\n"; + return failure(); + } + + // 2. Extract 'operations' dictionary from resources + auto operations = mlir::dyn_cast_or_null(resourcesAttr.get("operations")); + if (!operations) { + llvm::errs() + << "Cannot parse decomposition rule resources without `operations` attribute.\n"; + return failure(); + } + + // 3. Populate RuleNode + RuleNode ruleNode; + ruleNode.name = ruleName.str(); + ruleNode.output = parseOperator(targetGateAttr.getValue()); + + for (const auto &namedAttr : operations) { + if (auto intAttr = mlir::dyn_cast(namedAttr.getValue())) { + ruleNode.inputs.push_back({parseOperator(namedAttr.getName().strref()), + static_cast(intAttr.getInt())}); + } + } + + // 4. Add RuleNode + ruleNodes.push_back(std::move(ruleNode)); + return success(); + } + + LogicalResult loadBuiltInDecompositionRules(llvm::StringRef filename, + std::vector &ruleNodes) { mlir::MLIRContext *context = &getContext(); + mlir::ModuleOp module = getOperation(); mlir::ParserConfig config(context); mlir::OwningOpRef moduleOp = mlir::parseSourceFile(filename, config); + SymbolTable symbolTable(module); + if (!moduleOp) { - mlir::emitError(mlir::UnknownLoc::get(context)) - << "failed to load built-in decomposition rules from '" << filename - << "': the rules file could not be parsed"; - return; + llvm::errs() << "failed to load built-in decomposition rules from '" << filename + << "': the rules file could not be parsed\n"; + return failure(); } for (auto rule : llvm::make_early_inc_range(moduleOp.get().getOps())) { - rule->remove(); - ruleRegistry.push_back(std::move(rule)); + if (failed(addRuleNode(rule, ruleNodes))) { + return failure(); + } + // avoid double-insertion + if (!symbolTable.lookup(rule.getName())) { + rule->remove(); + module.push_back(std::move(rule)); + } } - return; + return success(); } /** - * @brief Remove user rules from the module, loading into + * @brief Load the listed user rules into the registry. */ - LogicalResult - loadUserDecompositionRules(llvm::StringSet<> &userRuleNames, - llvm::SmallVector> &graphRules, - llvm::SmallVector> &rules) + LogicalResult loadUserDecompositionRules(llvm::StringSet<> &userRuleNames, + std::vector &ruleNodes) { mlir::ModuleOp module = getOperation(); if (userRuleNames.empty()) { @@ -280,22 +318,21 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase userRules; - - module.walk([&](mlir::func::FuncOp func) { + WalkResult walkResult = module.walk([&](mlir::func::FuncOp func) { if (func->hasAttr("target_gate")) { - userRules.push_back(func); if (userRuleNames.contains(func.getName())) { - graphRules.push_back(mlir::OwningOpRef(func.clone())); + if (failed(addRuleNode(func, ruleNodes))) { + return WalkResult::interrupt(); + } } } return WalkResult::skip(); }); - for (auto rule : llvm::make_early_inc_range(userRules)) { - rule->remove(); - rules.push_back(std::move(rule)); + if (walkResult.wasInterrupted()) { + return failure(); } + return success(); } @@ -305,9 +342,9 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase> &ruleRegistry) + mlir::LogicalResult loadPauliRotRules(std::vector &ruleNodes) { + LDBG() << "loading paulirot rules"; mlir::ModuleOp module = getOperation(); MLIRContext *context = &getContext(); @@ -316,10 +353,20 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase pauliRotOps; module.walk([&](quantum::PauliRotOp op) { pauliRotOps.push_back(op); }); + LDBG() << "found the following paulirots:"; + for (auto op : pauliRotOps) { + LDBG() << op; + } + if (!pauliRotOps.empty()) { - loadQPD(libQPDPath, libpythonPath); + if (!loadQPD(libQPDPath, libpythonPath)) { + llvm::errs() << "failed to load libQuantumPythonCallbacks\n"; + return failure(); + } } + LDBG() << "loaded QPD"; + for (quantum::PauliRotOp pauliRot : pauliRotOps) { std::string pauliWord = pauliRot.getPauliWord(); @@ -365,15 +412,34 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasesetAttr("resources", buildResourceDict(context, *flat)); } - ruleRegistry.push_back(std::move(outOp)); + if (failed(addRuleNode(funcOp, ruleNodes))) { + return failure(); + } + LDBG() << "adding rule " << funcOp.getName(); + module.push_back(std::move(outOp.release())); } - return success(); } + bool isInDecompRule(Operation *op) + { + while (auto parentOp = op->getParentOp()) { + if (auto funcOp = dyn_cast(parentOp)) { + if (funcOp->hasAttr("target_gate")) { + return true; + } + } + op = parentOp; + } + return false; + } + void getOperators(std::vector &operators) { getOperation().walk([&](quantum::QuantumGate op) { + if (isInDecompRule(op)) { + return; + } OperatorNode node; node.numWires = op.getNonCtrlQubitOperands().size(); node.adjoint = op.getAdjointFlag(); @@ -455,84 +521,25 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase &rules, - llvm::StringSet<> &userRuleNames, - llvm::SmallVector> &userRules, - llvm::StringMap> &ruleNameToFuncOp) + LogicalResult getRuleNodes(llvm::StringRef filename, std::vector &rules, + llvm::StringSet<> &userRuleNames) { - llvm::SmallVector> graphRules; - - // Load rules from bytecode and user-defined rules - loadBuiltInDecompositionRules(filename, graphRules); - if (failed(loadPauliRotRules(graphRules))) { - return signalPassFailure(); - } - if (failed(loadUserDecompositionRules(userRuleNames, graphRules, userRules))) { - return signalPassFailure(); - } - - for (auto &ruleOpRef : graphRules) { - mlir::func::FuncOp func = ruleOpRef.get(); - llvm::StringRef ruleName = func.getName(); - - // 1. Mandatory Attribute Check (Target Gate and Resources) - auto targetGateAttr = func->getAttrOfType("target_gate"); - auto resourcesAttr = func->getAttrOfType("resources"); - if (!targetGateAttr || !resourcesAttr) - continue; - - // 2. Extract 'operations' dictionary from resources - auto operations = - mlir::dyn_cast_or_null(resourcesAttr.get("operations")); - if (!operations) - continue; + LDBG() << "getting rule nodes"; - // 3. Populate RuleNode - RuleNode ruleNode; - ruleNode.name = ruleName.str(); - ruleNode.output = parseOperator(targetGateAttr.getValue()); - - for (const auto &namedAttr : operations) { - if (auto intAttr = mlir::dyn_cast(namedAttr.getValue())) { - ruleNode.inputs.push_back({parseOperator(namedAttr.getName().strref()), - static_cast(intAttr.getInt())}); - } - } + // Load pre-compiled rules (ignore failure, we can try to solve without) + std::ignore = loadBuiltInDecompositionRules(filename, rules); - // 4. Finalize: move the OpRef to the map to keep IR alive and store the node - ruleNameToFuncOp[ruleNode.name] = std::move(ruleOpRef); - rules.push_back(std::move(ruleNode)); + // Lower and load compile-time rules + LDBG() << "loading paulirot rules"; + if (failed(loadPauliRotRules(rules))) { + return failure(); } - } - /** - * @brief Insert the decomposition rules picked by the graph solver into the module for - * later use in the decompose-lowering patterns to apply the decomposition rules and rewrite - * the quantum operations. - * - * @param solution The chosen decomposition rules from the graph solver. - * @param ruleNameToFuncOp A mapping from rule names to their corresponding function - * operations. - */ - void insertChosenRules(GraphResult &solution, - llvm::StringMap> &ruleNameToFuncOp) - { - mlir::ModuleOp module = getOperation(); - for (const auto &[_, chosenRule] : solution) { - if (chosenRule.isBasis) { - continue; // skip basis rules as they don't correspond to actual decomposition - // functions to insert - } - auto it = ruleNameToFuncOp.find(chosenRule.ruleName); - - if (it == ruleNameToFuncOp.end() || !it->second) { - // skip if the rule is not found or - // the function op is null or - // it is already moved - continue; - } - module.push_back(it->second.release()); + // Load user-rules + if (failed(loadUserDecompositionRules(userRuleNames, rules))) { + return failure(); } + return success(); } /** From 9060f2249cd133c66059332f8855c3f5f5dc7b80 Mon Sep 17 00:00:00 2001 From: River McCubbin Date: Thu, 25 Jun 2026 13:12:56 -0400 Subject: [PATCH 2/4] update test --- mlir/test/Quantum/GraphDecomposition/TestAltDecomps.mlir | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/test/Quantum/GraphDecomposition/TestAltDecomps.mlir b/mlir/test/Quantum/GraphDecomposition/TestAltDecomps.mlir index 4ade35c784..8c2d5d2e62 100644 --- a/mlir/test/Quantum/GraphDecomposition/TestAltDecomps.mlir +++ b/mlir/test/Quantum/GraphDecomposition/TestAltDecomps.mlir @@ -27,9 +27,13 @@ func.func @circuit() -> !quantum.bit { // XZ: PauliX // XZ: PauliZ %qout = quantum.custom "PauliY"() %q : !quantum.bit + + // needed to ensure we don't match in the following decomposition rules + // CHECK: return return %qout : !quantum.bit } +// CHECK-LABEL: y_to_ry func.func @y_to_ry(%q0 : !quantum.bit) -> !quantum.bit attributes {target_gate="PauliY"} { %pi = arith.constant 3.14 : f64 %negpiby2 = arith.constant -1.57 : f64 @@ -38,6 +42,7 @@ func.func @y_to_ry(%q0 : !quantum.bit) -> !quantum.bit attributes {target_gate=" return %q1 : !quantum.bit } +// CHECK-LABEL: y_to_x_z func.func @y_to_x_z(%q0 : !quantum.bit) -> !quantum.bit attributes {target_gate="PauliY"} { %q1 = quantum.custom "PauliX"() %q0 : !quantum.bit %q2 = quantum.custom "PauliZ"() %q1 : !quantum.bit From 8e55f6afe344615b01ebd559b3fbaa6b07492df5 Mon Sep 17 00:00:00 2001 From: River McCubbin Date: Thu, 25 Jun 2026 14:42:27 -0400 Subject: [PATCH 3/4] fix paulirot rule collisions, centralize resources in addRuleNode --- .../Transforms/graph_decomposition.cpp | 61 ++++++++----------- 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp index 351c492537..c1ec170efb 100644 --- a/mlir/lib/Quantum/Transforms/graph_decomposition.cpp +++ b/mlir/lib/Quantum/Transforms/graph_decomposition.cpp @@ -104,10 +104,7 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase setOfOps; std::vector setOfRules; - llvm::StringMap> ruleNameToFuncOp; llvm::StringSet<> userRuleNames; - llvm::SmallVector> - allUserRules; // includes rules unused in this decomp llvm::StringMap opToFixedDecompName; llvm::StringMap> opToAltDecompNames; WeightedGateset targetGateSet; @@ -238,17 +235,28 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasegetAttrOfType("target_gate"); auto resourcesAttr = rule->getAttrOfType("resources"); - if (!targetGateAttr || !resourcesAttr) { - llvm::errs() << "Cannot parse decomposition rule without `target_gate` and `resources` " - "attributes.\n"; + if (!targetGateAttr) { + llvm::errs() << "Cannot parse decomposition rule " << ruleName + << " without the `target_gate` attribute.\n"; + LDBG() << rule; return failure(); } + // Ensure resources + if (!resourcesAttr) { + ResourceAnalysis analysis(rule); + if (const ResourceResult *flat = analysis.getFlattenedResource(rule.getName())) { + rule->setAttr("resources", buildResourceDict(&getContext(), *flat)); + } + resourcesAttr = rule->getAttrOfType("resources"); + } + // 2. Extract 'operations' dictionary from resources auto operations = mlir::dyn_cast_or_null(resourcesAttr.get("operations")); if (!operations) { - llvm::errs() - << "Cannot parse decomposition rule resources without `operations` attribute.\n"; + llvm::errs() << "Cannot parse resource for decomposition rule " << ruleName + << " without `operations` attribute.\n"; + LDBG() << rule; return failure(); } @@ -300,7 +308,7 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase &userRuleNames, std::vector &ruleNodes) @@ -310,14 +318,6 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasehasAttr("target_gate")) { if (userRuleNames.contains(func.getName())) { @@ -344,20 +344,22 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase &ruleNodes) { - LDBG() << "loading paulirot rules"; mlir::ModuleOp module = getOperation(); MLIRContext *context = &getContext(); llvm::StringSet<> addedWords; + // Add words from existing paulirot rules + module.walk([&](mlir::func::FuncOp func) { + if (func->hasAttr("target_gate")) { + if (func.getName().starts_with("paulirot_decomp_rule_")) { + addedWords.insert(func.getName().drop_front(21)); + } + } + }); llvm::SmallVector pauliRotOps; module.walk([&](quantum::PauliRotOp op) { pauliRotOps.push_back(op); }); - LDBG() << "found the following paulirots:"; - for (auto op : pauliRotOps) { - LDBG() << op; - } - if (!pauliRotOps.empty()) { if (!loadQPD(libQPDPath, libpythonPath)) { llvm::errs() << "failed to load libQuantumPythonCallbacks\n"; @@ -365,8 +367,6 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBasesetName((outOp->getName() + "_" + pauliWord).str()); // unique name per pauliword funcOp->setAttr("target_gate", mlir::StringAttr::get(context, "paulirot" + pauliWord)); - auto analysis = ResourceAnalysis(funcOp); - if (const ResourceResult *flat = analysis.getFlattenedResource(funcOp.getName())) { - funcOp->setAttr("resources", buildResourceDict(context, *flat)); - } - if (failed(addRuleNode(funcOp, ruleNodes))) { return failure(); } @@ -524,13 +519,10 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase &rules, llvm::StringSet<> &userRuleNames) { - LDBG() << "getting rule nodes"; - // Load pre-compiled rules (ignore failure, we can try to solve without) std::ignore = loadBuiltInDecompositionRules(filename, rules); // Lower and load compile-time rules - LDBG() << "loading paulirot rules"; if (failed(loadPauliRotRules(rules))) { return failure(); } @@ -580,7 +572,8 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase Date: Fri, 26 Jun 2026 15:25:28 -0400 Subject: [PATCH 4/4] changelog --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4f127d6e32..996b3eeea6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -201,6 +201,9 @@

Internal changes ⚙️

+* The `graph-decomposition` pass now performs far less IR manipulation. + [(#2977)](https://github.com/PennyLaneAI/catalyst/pull/2977) + * Update tests to not use global capture toggle where possible. [(#2964)](https://github.com/PennyLaneAI/catalyst/pull/2964)