Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@

<h3>Internal changes ⚙️</h3>

* The `graph-decomposition` pass now performs far less IR manipulation.
[(#2977)](https://github.com/PennyLaneAI/catalyst/pull/2977)

* Update tests to not use global capture toggle where possible.
[(#2964)](https://github.com/PennyLaneAI/catalyst/pull/2964)

Expand Down
264 changes: 132 additions & 132 deletions mlir/lib/Quantum/Transforms/graph_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
// Step 1: Gather inputs for graph
std::vector<OperatorNode> setOfOps;
std::vector<RuleNode> setOfRules;
llvm::StringMap<mlir::OwningOpRef<func::FuncOp>> ruleNameToFuncOp;
llvm::StringSet<> userRuleNames;
llvm::SmallVector<mlir::OwningOpRef<func::FuncOp>>
allUserRules; // includes rules unused in this decomp
llvm::StringMap<std::string> opToFixedDecompName;
llvm::StringMap<llvm::SmallVector<std::string>> opToAltDecompNames;
WeightedGateset targetGateSet;
Expand All @@ -128,7 +125,9 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec

// NOTE: getOperators must be after getRuleNodes, which removes user rules from the module.
// This prevents operators in user rules from being added to the graph.
getRuleNodes(bytecodeRulesFile, setOfRules, userRuleNames, allUserRules, ruleNameToFuncOp);
if (failed(getRuleNodes(bytecodeRulesFile, setOfRules, userRuleNames))) {
return signalPassFailure();
}
getOperators(setOfOps);

///////////////////////////
Expand All @@ -139,31 +138,24 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
std::move(altDecomps));
DecompositionSolver solver(graph);
auto solution = solver.solve();
///////////////////////////
// Step 3: Insert decomposition rules picked by the graph solver (solution) into the
// module
insertChosenRules(solution, ruleNameToFuncOp);

///////////////////////////
// Step 4: Convert python-decompositions from reference to value semantics and run
// decompose-lowering to apply the chosen decomposition rules
ModuleOp module = getOperation();
OpPassManager pm("builtin.module");

DecomposeLoweringPassOptions dlOptions;
for (auto &[op, chosenRule] : solution) {
dlOptions.targetRulesOption.push_back(chosenRule.ruleName);
}

pm.addPass(qref::createValueSemanticsConversionPass());
pm.addPass(createDecomposeLoweringPass());
pm.addPass(createDecomposeLoweringPass(dlOptions));

if (failed(runPipeline(pm, module))) {
return signalPassFailure();
}

///////////////////////////
// Step 5: Re-introduce any missing user rules for future decompositions
SymbolTable symbolTable(module);
for (auto &rule : allUserRules) {
if (!symbolTable.lookup<func::FuncOp>(rule->getName())) {
module.getBody()->push_back(rule.release());
}
}
}

private:
Expand Down Expand Up @@ -236,66 +228,111 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
return success();
}

void loadBuiltInDecompositionRules(
llvm::StringRef filename,
llvm::SmallVector<mlir::OwningOpRef<mlir::func::FuncOp>> &ruleRegistry)
LogicalResult addRuleNode(mlir::func::FuncOp rule, std::vector<RuleNode> &ruleNodes)
{
llvm::StringRef ruleName = rule.getName();

// 1. Mandatory Attribute Check (Target Gate and Resources)
auto targetGateAttr = rule->getAttrOfType<StringAttr>("target_gate");
auto resourcesAttr = rule->getAttrOfType<DictionaryAttr>("resources");
if (!targetGateAttr) {
llvm::errs() << "Cannot parse decomposition rule " << ruleName
<< " without the `target_gate` attribute.\n";
LDBG() << rule;
return failure();
}

// Ensure resources
if (!resourcesAttr) {
ResourceAnalysis analysis(rule);
if (const ResourceResult *flat = analysis.getFlattenedResource(rule.getName())) {
rule->setAttr("resources", buildResourceDict(&getContext(), *flat));
}
resourcesAttr = rule->getAttrOfType<DictionaryAttr>("resources");
}

// 2. Extract 'operations' dictionary from resources
auto operations = mlir::dyn_cast_or_null<DictionaryAttr>(resourcesAttr.get("operations"));
if (!operations) {
llvm::errs() << "Cannot parse resource for decomposition rule " << ruleName
<< " without `operations` attribute.\n";
LDBG() << rule;
return failure();
}

// 3. Populate RuleNode
RuleNode ruleNode;
ruleNode.name = ruleName.str();
ruleNode.output = parseOperator(targetGateAttr.getValue());

for (const auto &namedAttr : operations) {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(namedAttr.getValue())) {
ruleNode.inputs.push_back({parseOperator(namedAttr.getName().strref()),
static_cast<uint32_t>(intAttr.getInt())});
}
}

// 4. Add RuleNode
ruleNodes.push_back(std::move(ruleNode));
return success();
}

LogicalResult loadBuiltInDecompositionRules(llvm::StringRef filename,
std::vector<RuleNode> &ruleNodes)
{
mlir::MLIRContext *context = &getContext();
mlir::ModuleOp module = getOperation();
mlir::ParserConfig config(context);
mlir::OwningOpRef<mlir::ModuleOp> moduleOp =
mlir::parseSourceFile<mlir::ModuleOp>(filename, config);

SymbolTable symbolTable(module);

if (!moduleOp) {
mlir::emitError(mlir::UnknownLoc::get(context))
<< "failed to load built-in decomposition rules from '" << filename
<< "': the rules file could not be parsed";
return;
llvm::errs() << "failed to load built-in decomposition rules from '" << filename
<< "': the rules file could not be parsed\n";
return failure();
}

for (auto rule : llvm::make_early_inc_range(moduleOp.get().getOps<mlir::func::FuncOp>())) {
rule->remove();
ruleRegistry.push_back(std::move(rule));
if (failed(addRuleNode(rule, ruleNodes))) {
return failure();
}
// avoid double-insertion
if (!symbolTable.lookup<mlir::func::FuncOp>(rule.getName())) {
rule->remove();
module.push_back(std::move(rule));
}
}
return;
return success();
}

/**
* @brief Remove user rules from the module, loading into
* @brief Load the listed user rules into the set of RuleNodes for the graph.
*/
LogicalResult
loadUserDecompositionRules(llvm::StringSet<> &userRuleNames,
llvm::SmallVector<mlir::OwningOpRef<mlir::func::FuncOp>> &graphRules,
llvm::SmallVector<mlir::OwningOpRef<mlir::func::FuncOp>> &rules)
LogicalResult loadUserDecompositionRules(llvm::StringSet<> &userRuleNames,
std::vector<RuleNode> &ruleNodes)
{
mlir::ModuleOp module = getOperation();
if (userRuleNames.empty()) {
return success();
}

PassManager pm(&getContext());
pm.addPass(createRegisterDecompRuleResourcePass());
if (failed(pm.run(module))) {
module.emitError() << "failed to load user decomposition rules: unable to run resource "
"annotation pass";
return failure();
}

llvm::SmallVector<mlir::func::FuncOp> userRules;

module.walk([&](mlir::func::FuncOp func) {
WalkResult walkResult = module.walk([&](mlir::func::FuncOp func) {
if (func->hasAttr("target_gate")) {
userRules.push_back(func);
if (userRuleNames.contains(func.getName())) {
graphRules.push_back(mlir::OwningOpRef<mlir::func::FuncOp>(func.clone()));
if (failed(addRuleNode(func, ruleNodes))) {
return WalkResult::interrupt();
}
}
}
return WalkResult::skip();
});

for (auto rule : llvm::make_early_inc_range(userRules)) {
rule->remove();
rules.push_back(std::move(rule));
if (walkResult.wasInterrupted()) {
return failure();
}

return success();
}

Expand All @@ -305,19 +342,29 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
* annotating the lowered decomposition rules with resources and target gates. The target gate
* for the decomposition rule associated with Pauli word `ABC` will be `paulirotABC`.
*/
mlir::LogicalResult
loadPauliRotRules(llvm::SmallVector<mlir::OwningOpRef<mlir::func::FuncOp>> &ruleRegistry)
mlir::LogicalResult loadPauliRotRules(std::vector<RuleNode> &ruleNodes)
{
mlir::ModuleOp module = getOperation();
MLIRContext *context = &getContext();

llvm::StringSet<> addedWords;
// Add words from existing paulirot rules
module.walk([&](mlir::func::FuncOp func) {
if (func->hasAttr("target_gate")) {
if (func.getName().starts_with("paulirot_decomp_rule_")) {
addedWords.insert(func.getName().drop_front(21));
}
}
});

llvm::SmallVector<quantum::PauliRotOp> pauliRotOps;
module.walk([&](quantum::PauliRotOp op) { pauliRotOps.push_back(op); });

if (!pauliRotOps.empty()) {
loadQPD(libQPDPath, libpythonPath);
if (!loadQPD(libQPDPath, libpythonPath)) {
llvm::errs() << "failed to load libQuantumPythonCallbacks\n";
return failure();
}
}

for (quantum::PauliRotOp pauliRot : pauliRotOps) {
Expand Down Expand Up @@ -360,20 +407,34 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
outOp->setName((outOp->getName() + "_" + pauliWord).str()); // unique name per pauliword
funcOp->setAttr("target_gate", mlir::StringAttr::get(context, "paulirot" + pauliWord));

auto analysis = ResourceAnalysis(funcOp);
if (const ResourceResult *flat = analysis.getFlattenedResource(funcOp.getName())) {
funcOp->setAttr("resources", buildResourceDict(context, *flat));
if (failed(addRuleNode(funcOp, ruleNodes))) {
return failure();
}

ruleRegistry.push_back(std::move(outOp));
LDBG() << "adding rule " << funcOp.getName();
module.push_back(std::move(outOp.release()));
}

return success();
}

bool isInDecompRule(Operation *op)
{
while (auto parentOp = op->getParentOp()) {
if (auto funcOp = dyn_cast<func::FuncOp>(parentOp)) {
if (funcOp->hasAttr("target_gate")) {
return true;
}
}
op = parentOp;
}
return false;
}

void getOperators(std::vector<OperatorNode> &operators)
{
getOperation().walk([&](quantum::QuantumGate op) {
if (isInDecompRule(op)) {
return;
}
OperatorNode node;
node.numWires = op.getNonCtrlQubitOperands().size();
node.adjoint = op.getAdjointFlag();
Expand Down Expand Up @@ -455,84 +516,22 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
/**
* @brief Create RuleNodes for each rule available to be used in graph decomposition.
*/
void getRuleNodes(llvm::StringRef filename, std::vector<RuleNode> &rules,
llvm::StringSet<> &userRuleNames,
llvm::SmallVector<mlir::OwningOpRef<func::FuncOp>> &userRules,
llvm::StringMap<mlir::OwningOpRef<func::FuncOp>> &ruleNameToFuncOp)
LogicalResult getRuleNodes(llvm::StringRef filename, std::vector<RuleNode> &rules,
llvm::StringSet<> &userRuleNames)
{
llvm::SmallVector<mlir::OwningOpRef<mlir::func::FuncOp>> graphRules;

// Load rules from bytecode and user-defined rules
loadBuiltInDecompositionRules(filename, graphRules);
if (failed(loadPauliRotRules(graphRules))) {
return signalPassFailure();
}
if (failed(loadUserDecompositionRules(userRuleNames, graphRules, userRules))) {
return signalPassFailure();
}

for (auto &ruleOpRef : graphRules) {
mlir::func::FuncOp func = ruleOpRef.get();
llvm::StringRef ruleName = func.getName();
// Load pre-compiled rules (ignore failure, we can try to solve without)
std::ignore = loadBuiltInDecompositionRules(filename, rules);

// 1. Mandatory Attribute Check (Target Gate and Resources)
auto targetGateAttr = func->getAttrOfType<StringAttr>("target_gate");
auto resourcesAttr = func->getAttrOfType<DictionaryAttr>("resources");
if (!targetGateAttr || !resourcesAttr)
continue;

// 2. Extract 'operations' dictionary from resources
auto operations =
mlir::dyn_cast_or_null<DictionaryAttr>(resourcesAttr.get("operations"));
if (!operations)
continue;

// 3. Populate RuleNode
RuleNode ruleNode;
ruleNode.name = ruleName.str();
ruleNode.output = parseOperator(targetGateAttr.getValue());

for (const auto &namedAttr : operations) {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(namedAttr.getValue())) {
ruleNode.inputs.push_back({parseOperator(namedAttr.getName().strref()),
static_cast<uint32_t>(intAttr.getInt())});
}
}

// 4. Finalize: move the OpRef to the map to keep IR alive and store the node
ruleNameToFuncOp[ruleNode.name] = std::move(ruleOpRef);
rules.push_back(std::move(ruleNode));
// Lower and load compile-time rules
if (failed(loadPauliRotRules(rules))) {
return failure();
}
}

/**
* @brief Insert the decomposition rules picked by the graph solver into the module for
* later use in the decompose-lowering patterns to apply the decomposition rules and rewrite
* the quantum operations.
*
* @param solution The chosen decomposition rules from the graph solver.
* @param ruleNameToFuncOp A mapping from rule names to their corresponding function
* operations.
*/
void insertChosenRules(GraphResult &solution,
llvm::StringMap<mlir::OwningOpRef<func::FuncOp>> &ruleNameToFuncOp)
{
mlir::ModuleOp module = getOperation();
for (const auto &[_, chosenRule] : solution) {
if (chosenRule.isBasis) {
continue; // skip basis rules as they don't correspond to actual decomposition
// functions to insert
}
auto it = ruleNameToFuncOp.find(chosenRule.ruleName);

if (it == ruleNameToFuncOp.end() || !it->second) {
// skip if the rule is not found or
// the function op is null or
// it is already moved
continue;
}
module.push_back(it->second.release());
// Load user-rules
if (failed(loadUserDecompositionRules(userRuleNames, rules))) {
return failure();
}
return success();
}

/**
Expand Down Expand Up @@ -573,7 +572,8 @@ struct GraphDecompositionPass : public impl::GraphDecompositionPassBase<GraphDec
* For each entry, looks up the corresponding RuleNodes in setOfRules by name.
* Individual rules not found are skipped with a diagnostic.
*
* @param opToAltDecompNames Parsed mapping from operator name to alternative-rule names.
* @param opToAltDecompNames Parsed mapping from operator name to alternative-rule
* names.
* @param setOfRules The full list of available decomposition rules.
* @return Core::AltDecomps Mapping from OperatorNode to its alternative RuleNodes.
*/
Expand Down
Loading
Loading