Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@

<h3>Improvements 🛠</h3>

* The `decompose-lowering` pass now supports applying a selection of the available decomposition rules via the `target_rules` parameter.
The pass also no longer applies the `inline`, `cse` and `canonicalize` passes to avoid unnecessary IR mutations.
Instead, decomposition rules are deterministically inlined by a custom function (`inline` is non-deterministic, using an estimated benefit and threshold as criteria for inlining).
Decomposition rules are no longer removed after the `decompose-lowering` pass, which allows them to be used by subsequent passes, namely `graph-decomposition`.
Instead, rules are removed by the `symbol-dce` pass at the end of the `QuantumCompilationStage`.
[(#2973)](https://github.com/PennyLaneAI/catalyst/pull/2973)

* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators
without non-lowerable arguments.
[(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/)
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/Quantum/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def GraphDecompositionPass : Pass<"graph-decomposition", "mlir::ModuleOp"> {

def DecomposeLoweringPass : Pass<"decompose-lowering"> {
let summary = "Replace quantum operations with compiled decomposition rules.";

let options = [
ListOption<
/*C++ name*/"targetRulesOption",
/*CLI name*/"target-rules",
/*Type*/"std::string",
/*Description*/"The set of decomposition rules to apply. If empty, applies all rules.">
];
}

def DisentangleCNOTPass : Pass<"disentangle-cnot"> {
Expand Down
43 changes: 25 additions & 18 deletions mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@
// limitations under the License.

#include <algorithm> // std::move_backward
#include <cassert>
#include <cstddef>
#include <iterator>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"

#include "Quantum/IR/QuantumOps.h"
#include "Quantum/IR/QuantumTypes.h"
Expand All @@ -34,8 +45,8 @@ namespace catalyst {
namespace quantum {

// The goal of this class is to analyze the signature of a custom operation to get the enough
// information to prepare the call operands and results for replacing the op to calling the
// decomposition function.
// information to prepare the operands and results for replacing the op with the decomposition
// function.
class BaseSignatureAnalyzer {
protected:
bool isValid = true;
Expand Down Expand Up @@ -177,7 +188,7 @@ class BaseSignatureAnalyzer {
return latestQreg;
}

// Prepare the operands for calling the decomposition function
// Prepare the operands for the decomposition function
// There are two cases:
// 1. The first input is a qreg, which means the decomposition function is a qreg mode function
// 2. Otherwise, the decomposition function is a qubit mode function
Expand All @@ -187,10 +198,10 @@ class BaseSignatureAnalyzer {
// - func(qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg
// 2. qubit mode:
// - func(param*, inQubits*, inCtrlQubits*?, inCtrlValues*?) -> outQubits*
llvm::SmallVector<Value> prepareCallOperands(func::FuncOp decompFunc, PatternRewriter &rewriter,
Location loc)
llvm::SmallVector<Value> prepareOperands(func::FuncOp rule, PatternRewriter &rewriter,
Location loc)
{
auto funcType = decompFunc.getFunctionType();
auto funcType = rule.getFunctionType();
auto funcInputs = funcType.getInputs();

SmallVector<Type> funcInputsNoQreg;
Expand All @@ -202,10 +213,10 @@ class BaseSignatureAnalyzer {

SmallVector<Value> operands(funcInputs.size());

auto qregIt = llvm::find_if(decompFunc.getFunctionType().getInputs(),
auto qregIt = llvm::find_if(rule.getFunctionType().getInputs(),
[](mlir::Type t) { return isa<quantum::QuregType>(t); });
int qregIdx = std::distance(decompFunc.getFunctionType().getInputs().begin(), qregIt);
bool hasQreg = (qregIt != decompFunc.getFunctionType().getInputs().end());
int qregIdx = std::distance(rule.getFunctionType().getInputs().begin(), qregIt);
bool hasQreg = (qregIt != rule.getFunctionType().getInputs().end());

int operandIdx = 0;
if (!signature.params.empty()) {
Expand Down Expand Up @@ -264,22 +275,18 @@ class BaseSignatureAnalyzer {
return operands;
}

// Prepare the results for the call operation
SmallVector<Value> prepareCallResultForQreg(func::CallOp callOp, PatternRewriter &rewriter)
// Prepare the results produced by a qreg-mode decomposition rule
SmallVector<Value> prepareResultsForQreg(Value qreg, Location loc, PatternRewriter &rewriter)
{
assert(callOp.getNumResults() == 1 && "only one qreg result for qreg mode is allowed");

auto qreg = callOp.getResult(0);
assert(isa<quantum::QuregType>(qreg.getType()) && "only allow to have qreg result");

SmallVector<Value> newResults;
rewriter.setInsertionPointAfter(callOp);

for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) {
for (const auto &index : indices) {
auto extractOp = quantum::ExtractOp::create(
rewriter, callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg,
index.getValue(), index.getAttr());
rewriter, loc, rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());
}
}
Expand Down Expand Up @@ -325,7 +332,7 @@ class BaseSignatureAnalyzer {
return {startIdx, paramTypeEnd};
}

// generate params for calling the decomposition function based on function type requirements
// generate params for the decomposition function based on function type requirements
SmallVector<Value> generateParams(ValueRange signatureParams, ArrayRef<Type> funcParamTypes,
PatternRewriter &rewriter, Location loc)
{
Expand Down
135 changes: 102 additions & 33 deletions mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,75 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "decompose-lowering"
#include <cassert>
#include <cstddef>
#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/AllocatorBase.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/IR/QuantumTypes.h"

#include "DecomposeLoweringImpl.hpp"

#define DEBUG_TYPE "decompose-lowering"

using namespace mlir;
using namespace catalyst::quantum;

namespace catalyst {
namespace quantum {

SmallVector<Value> getDecompRuleResults(func::FuncOp rule, ValueRange operands,
PatternRewriter &rewriter)
{
Block &body = rule.front();
auto returnOp = cast<func::ReturnOp>(body.getTerminator());

IRMapping mapping;
mapping.map(body.getArguments(), operands);

for (Operation &op : body.without_terminator()) {
rewriter.clone(op, mapping);
}

SmallVector<Value> results;
for (Value operand : returnOp.getOperands()) {
results.push_back(mapping.lookupOrDefault(operand));
}
return results;
}

bool isInDecompRule(Operation *op)
{
while (auto parentOp = op->getParentOp()) {
if (auto funcOp = dyn_cast<func::FuncOp>(parentOp)) {
if (funcOp->hasAttr("target_gate")) {
return true;
}
}
op = parentOp;
}
return false;
}

struct DLCustomOpPattern : public OpRewritePattern<CustomOp> {
private:
const llvm::StringMap<func::FuncOp> &decompositionRegistry;
Expand All @@ -54,18 +103,24 @@ struct DLCustomOpPattern : public OpRewritePattern<CustomOp> {
return failure();
}

// Find the corresponding decomposition function for the op
// do not nest decomposition rules, they're applied greedily and this can lead to
// cycles/identity rules
if (isInDecompRule(op)) {
return failure();
}

// Find the corresponding decomposition rule for the op
auto it = decompositionRegistry.find(gateName);
if (it == decompositionRegistry.end()) {
return failure();
}
func::FuncOp decompFunc = it->second;
func::FuncOp rule = it->second;

// For null decomp rules, the signature will not have any quantum values
// This is a deviation from the standard decomp func signature, so we deal with it
// separately
if (!llvm::any_of(llvm::concat<const Type>(decompFunc.getFunctionType().getInputs(),
decompFunc.getFunctionType().getResults()),
if (!llvm::any_of(llvm::concat<const Type>(rule.getFunctionType().getInputs(),
rule.getFunctionType().getResults()),
[](const mlir::Type t) {
return isa<quantum::QuregType, quantum::QubitType>(t);
})) {
Expand All @@ -76,32 +131,32 @@ struct DLCustomOpPattern : public OpRewritePattern<CustomOp> {
return success();
}

// Here is the assumption that the decomposition function must have at least one input and
// Here is the assumption that the decomposition rule must have at least one input and
// one result
assert(decompFunc.getFunctionType().getNumInputs() > 0 &&
assert(rule.getFunctionType().getNumInputs() > 0 &&
"Decomposition function must have at least one input");
assert(decompFunc.getFunctionType().getNumResults() >= 1 &&
assert(rule.getFunctionType().getNumResults() >= 1 &&
"Decomposition function must have at least one result");

rewriter.setInsertionPointAfter(op);

auto enableQreg = llvm::any_of(decompFunc.getFunctionType().getInputs(),
auto enableQreg = llvm::any_of(rule.getFunctionType().getInputs(),
[](mlir::Type t) { return isa<quantum::QuregType>(t); });
auto analyzer = CustomOpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");

auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
auto callOp =
func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(),
decompFunc.getSymName(), callOperands);
auto operands = analyzer.prepareOperands(rule, rewriter, op.getLoc());
SmallVector<Value> inlinedFunctionResults = getDecompRuleResults(rule, operands, rewriter);

// Replace the op with the call op and adjust the insert ops for the qreg mode
if (callOp.getNumResults() == 1 && isa<quantum::QuregType>(callOp.getResult(0).getType())) {
auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
// Replace the op with the inlined function and adjust the insert ops for the qreg mode
if (inlinedFunctionResults.size() == 1 &&
isa<quantum::QuregType>(inlinedFunctionResults.front().getType())) {
auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(),
op.getLoc(), rewriter);
rewriter.replaceOp(op, results);
}
else {
rewriter.replaceOp(op, callOp->getResults());
rewriter.replaceOp(op, inlinedFunctionResults);
}

return success();
Expand Down Expand Up @@ -130,6 +185,12 @@ struct DLMultiRZOpPattern : public OpRewritePattern<MultiRZOp> {
return failure();
}

// do not nest decomposition rules, they're applied greedily and this can lead to
// cycles/identity rules
if (isInDecompRule(op)) {
return failure();
}

// Find the corresponding decomposition function for the op
auto numQubits = op.getInQubits().size();
auto MRZNameWithQubits = gateName + "_" + std::to_string(numQubits);
Expand Down Expand Up @@ -165,18 +226,19 @@ struct DLMultiRZOpPattern : public OpRewritePattern<MultiRZOp> {
auto analyzer = MultiRZOpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");

auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
auto callOp =
func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(),
decompFunc.getSymName(), callOperands);
auto operands = analyzer.prepareOperands(decompFunc, rewriter, op.getLoc());
SmallVector<Value> inlinedFunctionResults =
getDecompRuleResults(decompFunc, operands, rewriter);

// Replace the op with the call op and adjust the insert ops for the qreg mode
if (callOp.getNumResults() == 1 && isa<quantum::QuregType>(callOp.getResult(0).getType())) {
auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
// Replace the op with the inlined function and adjust the insert ops for the qreg mode
if (inlinedFunctionResults.size() == 1 &&
isa<quantum::QuregType>(inlinedFunctionResults.front().getType())) {
auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(),
op.getLoc(), rewriter);
rewriter.replaceOp(op, results);
}
else {
rewriter.replaceOp(op, callOp->getResults());
rewriter.replaceOp(op, inlinedFunctionResults);
}

return success();
Expand Down Expand Up @@ -205,6 +267,12 @@ struct DLPauliRotOpPattern : public OpRewritePattern<PauliRotOp> {
return failure();
}

// do not nest decomposition rules, they're applied greedily and this can lead to
// cycles/identity rules
if (isInDecompRule(op)) {
return failure();
}

// Find the corresponding decomposition function for the op
auto it = decompositionRegistry.find(gateName);
if (it == decompositionRegistry.end()) {
Expand All @@ -226,18 +294,19 @@ struct DLPauliRotOpPattern : public OpRewritePattern<PauliRotOp> {
auto analyzer = PauliRotOpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");

auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
auto callOp =
func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(),
decompFunc.getSymName(), callOperands);
auto operands = analyzer.prepareOperands(decompFunc, rewriter, op.getLoc());
SmallVector<Value> inlinedFunctionResults =
getDecompRuleResults(decompFunc, operands, rewriter);

// Replace the op with the call op and adjust the insert ops for the qreg mode
if (callOp.getNumResults() == 1 && isa<quantum::QuregType>(callOp.getResult(0).getType())) {
auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
// Replace the op with the inlined results and adjust the insert ops for the qreg mode
if (inlinedFunctionResults.size() == 1 &&
isa<quantum::QuregType>(inlinedFunctionResults.front().getType())) {
auto results = analyzer.prepareResultsForQreg(inlinedFunctionResults.front(),
op.getLoc(), rewriter);
rewriter.replaceOp(op, results);
}
else {
rewriter.replaceOp(op, callOp->getResults());
rewriter.replaceOp(op, inlinedFunctionResults);
}

return success();
Expand Down
Loading
Loading