diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index a968dbad21..b82bea5bc7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -15,6 +15,11 @@ * The `ResourceAnalysis` pass now supports IR in reference semantics natively, rather than requiring a conversion step. [(#2923)](https://github.com/PennyLaneAI/catalyst/pull/2923) + +* The `--adjoint-lowering` pass no longer turns statically bounded for loops into + dynamically bounded ones. In this way they remain analyzable by functionality like `qp.specs`. + [(#2959)](https://github.com/PennyLaneAI/catalyst/issues/2959) + * The `--decompose-lowering` pass can now handle decomposition rule functions whose quantum register argument is at an arbitrary position in the argument list. [(#2836)](https://github.com/PennyLaneAI/catalyst/pull/2836) diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index 779ac29d12..e19da00451 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -33,7 +33,6 @@ mlir::Value getGlobalString(mlir::Location loc, mlir::OpBuilder &rewriter, mlir: void populateGridsynthPatterns(mlir::RewritePatternSet &patterns, double epsilon, bool pprBasis); void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &, bool); -void populateAdjointPatterns(mlir::RewritePatternSet &); void populateCancelInversesPatterns(mlir::RewritePatternSet &); void populateMergeRotationsPatterns(mlir::RewritePatternSet &); void populateIonsDecompositionPatterns(mlir::RewritePatternSet &); diff --git a/mlir/include/Quantum/Utils/QuantumSplitting.h b/mlir/include/Quantum/Utils/QuantumSplitting.h deleted file mode 100644 index 79b6003894..0000000000 --- a/mlir/include/Quantum/Utils/QuantumSplitting.h +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2023 Xanadu Quantum Technologies Inc. 
- -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/IR/Value.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/IR/CatalystOps.h" -#include "Quantum/IR/QuantumInterfaces.h" - -namespace catalyst { -namespace quantum { - -/// A collection of the data required to reconstruct a deterministic hybrid quantum program with -/// classical preprocessing and arbitrary classical control flow. -struct QuantumCache { - mlir::TypedValue paramVector; - mlir::TypedValue wireVector; - /// For every structured control flow op, store the values required for it to execute. - /// Specifically: store the conditions for scf.if ops, the start/stop/step of scf.for ops, and - /// the number of iterations for scf.while ops. - mlir::DenseMap> controlFlowTapes; - - /// Initialize the quantum cache to traverse and store the necessary parameters for the given - /// `topLevelRegion`. 
- static QuantumCache initialize(mlir::Region &topLevelRegion, mlir::OpBuilder &builder, - mlir::Location loc); - - void emitDealloc(mlir::OpBuilder &builder, mlir::Location loc); -}; - -class AugmentedCircuitGenerator { - public: - AugmentedCircuitGenerator(mlir::IRMapping &oldToCloned, QuantumCache &cache) - : oldToCloned(oldToCloned), cache(cache) - { - } - - /// Given a `region` containing classical preprocessing and quantum operations, generate an - /// augmented version that caches all the parameters required to deterministically re-execute - /// the circuit (gate params, classical control flow, and dynamic wires). - void generate(mlir::Region ®ion, mlir::OpBuilder &builder); - - private: - mlir::IRMapping &oldToCloned; - QuantumCache &cache; - - void visitOperation(mlir::scf::ForOp forOp, mlir::OpBuilder &builder); - void visitOperation(mlir::scf::WhileOp forOp, mlir::OpBuilder &builder); - void visitOperation(mlir::scf::IfOp forOp, mlir::OpBuilder &builder); - void visitOperation(mlir::scf::IndexSwitchOp indexSwitchOp, mlir::OpBuilder &builder); - - void cloneTerminatorClassicalOperands(mlir::Operation *terminator, mlir::OpBuilder &builder); - - /// Update the internal mapping of the results of `oldOp` to the results of `clonedOp` using the - /// given result remapping. - void mapResults(mlir::Operation *oldOp, mlir::Operation *clonedOp, - const mlir::DenseMap &argIdxMapping); - - // Emit an operation to cache a dynamic wire for quantum.insert/extract ops. 
- template void cacheDynamicWire(IndexingOp op, mlir::OpBuilder &builder) - { - if (!op.getIdxAttr().has_value()) { - ListPushOp::create(builder, op.getLoc(), oldToCloned.lookupOrDefault(op.getIdx()), - cache.wireVector); - } - } - - void cacheGate(quantum::ParametrizedGate, mlir::OpBuilder &builder); -}; - -void verifyTypeIsCacheable(mlir::Type ty, mlir::Operation *gate); - -} // namespace quantum -} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/AdjointLowering/AdjointLowering.cpp b/mlir/lib/Quantum/Transforms/AdjointLowering/AdjointLowering.cpp new file mode 100644 index 0000000000..02ec3c6f0a --- /dev/null +++ b/mlir/lib/Quantum/Transforms/AdjointLowering/AdjointLowering.cpp @@ -0,0 +1,103 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "AdjointLowering.hpp" + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "Quantum/IR/QuantumOps.h" + +#include "QuantumCache.hpp" + +using namespace mlir; +using namespace catalyst::quantum; + +namespace { + +/// Orchestrates the adjoint lowering of a single `quantum.adjoint` op by running the forward pass +/// (recording the augmented circuit) followed by the reverse pass (emitting the adjoint circuit). +struct AdjointSingleOpRewritePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + /// We build a map from values mentioned in the source data flow to the values of + /// the program where quantum control flow is reversed. Most of the time, there is a 1-to-1 + /// correspondence with a notable exception caused by `insert`/`extract` API asymmetry. + LogicalResult matchAndRewrite(AdjointOp adjoint, PatternRewriter &rewriter) const override + { + QuantumCache cache = + QuantumCache::initialize(adjoint.getRegion(), rewriter, adjoint.getLoc()); + + // Forward pass: copy the classical computations to the target insertion point and record + // the values needed to replay the circuit in reverse. + IRMapping oldToCloned; + generateAdjointForwardPass(adjoint.getRegion(), rewriter, oldToCloned, cache); + + // Seed the reverse pass with the operands of the quantum.yield. 
+ auto yieldOp = cast(adjoint.getRegion().front().getTerminator()); + for (auto [yieldVal, adjointOperand] : + llvm::zip_equal(yieldOp.getOperands(), adjoint.getArgs())) { + oldToCloned.map(yieldVal, adjointOperand); + } + + // Reverse pass: emit the adjoint quantum operations and reversed control flow, using the + // cached values. + if (failed(generateAdjointReversePass(adjoint.getRegion(), rewriter, oldToCloned, cache))) { + return failure(); + } + + // Explicitly free the memory of the caches. + cache.emitDealloc(rewriter, adjoint.getLoc()); + // The final quantum values are the re-mapped region arguments of the original adjoint op. + SmallVector reversedOutputs; + for (BlockArgument arg : adjoint.getRegion().getArguments()) { + reversedOutputs.push_back(oldToCloned.lookup(arg)); + } + rewriter.replaceOp(adjoint, reversedOutputs); + return success(); + } +}; + +} // namespace + +namespace catalyst { +namespace quantum { + +#define GEN_PASS_DEF_ADJOINTLOWERINGPASS +#include "Quantum/Transforms/Passes.h.inc" + +struct AdjointLoweringPass : impl::AdjointLoweringPassBase { + using AdjointLoweringPassBase::AdjointLoweringPassBase; + + void runOnOperation() final + { + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext(), 1); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/AdjointLowering/AdjointLowering.hpp b/mlir/lib/Quantum/Transforms/AdjointLowering/AdjointLowering.hpp new file mode 100644 index 0000000000..cda9af1f8a --- /dev/null +++ b/mlir/lib/Quantum/Transforms/AdjointLowering/AdjointLowering.hpp @@ -0,0 +1,45 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/IRMapping.h" + +#include "QuantumCache.hpp" + +namespace catalyst { +namespace quantum { + +/// Generate the forward pass of the adjoint region, i.e. the classical portion (forwards). +/// +/// This generates the forward computation: Classical preprocessing is cloned to the current +/// insertion point, and the information needed to deterministically replay the circuit in +/// reverse (gate parameters, dynamic wires, and control-flow structure) is recorded into the +/// `QuantumCache`. The reverse pass later consumes this cache. +void generateAdjointForwardPass(mlir::Region &region, mlir::OpBuilder &builder, + mlir::IRMapping &oldToCloned, QuantumCache &cache); + +/// Generate the reverse pass of the adjoint region, i.e. the quantum portion (in reverse). +/// +/// This generates the reverse computation: the quantum operations are cloned in reverse order +/// with the adjoint attribute applied, with any control flow reversed, and using recorded +/// values from the cache. +/// The `IRMapping` helps associate values in the original program with their counterparts in the +/// reversed program. It must be seeded with the adjoint operands before calling. 
+mlir::LogicalResult generateAdjointReversePass(mlir::Region &region, mlir::OpBuilder &builder, + mlir::IRMapping &remappedValues, + QuantumCache &cache); + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Utils/QuantumSplitting.cpp b/mlir/lib/Quantum/Transforms/AdjointLowering/ForwardPass.cpp similarity index 85% rename from mlir/lib/Quantum/Utils/QuantumSplitting.cpp rename to mlir/lib/Quantum/Transforms/AdjointLowering/ForwardPass.cpp index 234bf79ecd..8d650da338 100644 --- a/mlir/lib/Quantum/Utils/QuantumSplitting.cpp +++ b/mlir/lib/Quantum/Transforms/AdjointLowering/ForwardPass.cpp @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "Quantum/Utils/QuantumSplitting.h" - #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/Transforms/DialectConversion.h" @@ -27,8 +26,12 @@ #include "PBC/IR/PBCOps.h" #include "Quantum/IR/QuantumOps.h" +#include "AdjointLowering.hpp" +#include "QuantumCache.hpp" + using namespace mlir; using namespace catalyst; +using namespace catalyst::quantum; namespace { bool isQuantumType(Type type) { return isa(type.getDialect()); } @@ -43,75 +46,48 @@ void populateArgIdxMapping(TypeRange types, DenseMap &argIdx } } -} // namespace +/// Generates the forward "augmented circuit" of the adjoint operation: classical preprocessing is +/// cloned as-is, while gate parameters, dynamic wires, and control-flow structure are recorded into +/// the cache for the reverse pass to replay. 
+class AugmentedCircuitGenerator { + public: + AugmentedCircuitGenerator(IRMapping &oldToCloned, QuantumCache &cache) + : oldToCloned(oldToCloned), cache(cache) + { + } -namespace catalyst { -namespace quantum { + /// Given a `region` containing classical preprocessing and quantum operations, generate an + /// augmented version that caches all the parameters required to deterministically re-execute + /// the circuit (gate params, classical control flow, and dynamic wires). + void generate(Region ®ion, OpBuilder &builder); -void verifyTypeIsCacheable(Type ty, Operation *op) -{ - // Sanitizing inputs. - // Technically we know for a fact that none of this will ever issue an - // error. This is because QubitUnitary is guaranteed to have a - // tensor> But this code in the future may be extended to - // support other types. Hence the sanitization. - if (ty.isF64()) { - return; - } + private: + IRMapping &oldToCloned; + QuantumCache &cache; - // TODO: Generalize to unranked tensors - if (!isa(ty)) { - op->emitOpError() << "Caching only supports tensors complex F64"; - } + void visitOperation(scf::ForOp forOp, OpBuilder &builder); + void visitOperation(scf::WhileOp whileOp, OpBuilder &builder); + void visitOperation(scf::IfOp ifOp, OpBuilder &builder); + void visitOperation(scf::IndexSwitchOp indexSwitchOp, OpBuilder &builder); - auto aTensorType = cast(ty); - ArrayRef shape = aTensorType.getShape(); + void cloneTerminatorClassicalOperands(Operation *terminator, OpBuilder &builder); - // TODO: Generalize to arbitrary dimensions - if (2 != shape.size()) { - op->emitOpError() << "Caching only supports tensors complex F64"; - } - // TODO: Generalize to other types - Type elementType = aTensorType.getElementType(); - if (!isa(elementType)) { - op->emitOpError() << "Caching only supports tensors complex F64"; - } - // TODO: Generalize to other types - Type f64 = cast(elementType).getElementType(); - if (!f64.isF64()) { - op->emitOpError() << "Caching only supports tensors 
complex F64"; - } -} + /// Update the internal mapping of the results of `oldOp` to the results of `clonedOp` using the + /// given result remapping. + void mapResults(Operation *oldOp, Operation *clonedOp, + const DenseMap &argIdxMapping); -QuantumCache QuantumCache::initialize(Region ®ion, OpBuilder &builder, Location loc) -{ - MLIRContext *ctx = builder.getContext(); - auto paramVectorType = ArrayListType::get(ctx, builder.getF64Type()); - auto wireVectorType = ArrayListType::get(ctx, builder.getI64Type()); - auto controlFlowTapeType = ArrayListType::get(ctx, builder.getIndexType()); - auto paramVector = ListInitOp::create(builder, loc, paramVectorType); - auto wireVector = ListInitOp::create(builder, loc, wireVectorType); - - // Initialize the tapes that store the structure of control flow. - DenseMap> controlFlowTapes; - region.walk([&](Operation *op) { - if (isa(op)) { - auto tape = catalyst::ListInitOp::create(builder, loc, controlFlowTapeType); - controlFlowTapes.insert({op, tape}); + // Emit an operation to cache a dynamic wire for quantum.insert/extract ops. 
+ template void cacheDynamicWire(IndexingOp op, OpBuilder &builder) + { + if (!op.getIdxAttr().has_value()) { + ListPushOp::create(builder, op.getLoc(), oldToCloned.lookupOrDefault(op.getIdx()), + cache.wireVector); } - }); - return quantum::QuantumCache{ - .paramVector = paramVector, .wireVector = wireVector, .controlFlowTapes = controlFlowTapes}; -} - -void QuantumCache::emitDealloc(OpBuilder &builder, Location loc) -{ - ListDeallocOp::create(builder, loc, paramVector); - ListDeallocOp::create(builder, loc, wireVector); - for (const auto &[_key, controlFlowTape] : controlFlowTapes) { - ListDeallocOp::create(builder, loc, controlFlowTape); } -} + + void cacheGate(quantum::ParametrizedGate gate, OpBuilder &builder); +}; void AugmentedCircuitGenerator::cacheGate(quantum::ParametrizedGate gate, OpBuilder &builder) { @@ -257,9 +233,14 @@ void AugmentedCircuitGenerator::visitOperation(scf::ForOp forOp, OpBuilder &buil } } - // Store the start, stop, and step to this op's control flow tape. + // Store the start, stop, and step to this op's control flow tape, but only when dynamic. + // Constant values can be rematerialized directly in the backward pass, which preserves the + // static info. 
Value tape = cache.controlFlowTapes.at(forOp); for (Value param : {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()}) { + if (getConstantIntValue(param).has_value()) { + continue; + } ListPushOp::create(builder, forOp.getLoc(), oldToCloned.lookupOrDefault(param), tape); } @@ -435,5 +416,17 @@ void AugmentedCircuitGenerator::mapResults(Operation *oldOp, Operation *clonedOp } } +} // namespace + +namespace catalyst { +namespace quantum { + +void generateAdjointForwardPass(Region ®ion, OpBuilder &builder, IRMapping &oldToCloned, + QuantumCache &cache) +{ + AugmentedCircuitGenerator generator{oldToCloned, cache}; + generator.generate(region, builder); +} + } // namespace quantum } // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/AdjointLowering/QuantumCache.cpp b/mlir/lib/Quantum/Transforms/AdjointLowering/QuantumCache.cpp new file mode 100644 index 0000000000..8ec4c37c52 --- /dev/null +++ b/mlir/lib/Quantum/Transforms/AdjointLowering/QuantumCache.cpp @@ -0,0 +1,93 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "QuantumCache.hpp" + +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "Catalyst/IR/CatalystOps.h" + +using namespace mlir; +using namespace catalyst; + +namespace catalyst { +namespace quantum { + +void verifyTypeIsCacheable(Type ty, Operation *op) +{ + // Sanitizing inputs. + // Technically we know for a fact that none of this will ever issue an + // error. 
This is because QubitUnitary is guaranteed to have a + // tensor> But this code in the future may be extended to + // support other types. Hence the sanitization. + if (ty.isF64()) { + return; + } + + // TODO: Generalize to unranked tensors + if (!isa(ty)) { + op->emitOpError() << "Caching only supports tensors complex F64"; + } + + auto aTensorType = cast(ty); + ArrayRef shape = aTensorType.getShape(); + + // TODO: Generalize to arbitrary dimensions + if (2 != shape.size()) { + op->emitOpError() << "Caching only supports tensors complex F64"; + } + // TODO: Generalize to other types + Type elementType = aTensorType.getElementType(); + if (!isa(elementType)) { + op->emitOpError() << "Caching only supports tensors complex F64"; + } + // TODO: Generalize to other types + Type f64 = cast(elementType).getElementType(); + if (!f64.isF64()) { + op->emitOpError() << "Caching only supports tensors complex F64"; + } +} + +QuantumCache QuantumCache::initialize(Region ®ion, OpBuilder &builder, Location loc) +{ + MLIRContext *ctx = builder.getContext(); + auto paramVectorType = ArrayListType::get(ctx, builder.getF64Type()); + auto wireVectorType = ArrayListType::get(ctx, builder.getI64Type()); + auto controlFlowTapeType = ArrayListType::get(ctx, builder.getIndexType()); + auto paramVector = ListInitOp::create(builder, loc, paramVectorType); + auto wireVector = ListInitOp::create(builder, loc, wireVectorType); + + // Initialize the tapes that store the structure of control flow. 
+ DenseMap> controlFlowTapes; + region.walk([&](Operation *op) { + if (isa(op)) { + auto tape = catalyst::ListInitOp::create(builder, loc, controlFlowTapeType); + controlFlowTapes.insert({op, tape}); + } + }); + return quantum::QuantumCache{ + .paramVector = paramVector, .wireVector = wireVector, .controlFlowTapes = controlFlowTapes}; +} + +void QuantumCache::emitDealloc(OpBuilder &builder, Location loc) +{ + ListDeallocOp::create(builder, loc, paramVector); + ListDeallocOp::create(builder, loc, wireVector); + for (const auto &[_key, controlFlowTape] : controlFlowTapes) { + ListDeallocOp::create(builder, loc, controlFlowTape); + } +} + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/AdjointLowering/QuantumCache.hpp b/mlir/lib/Quantum/Transforms/AdjointLowering/QuantumCache.hpp new file mode 100644 index 0000000000..5549cc133a --- /dev/null +++ b/mlir/lib/Quantum/Transforms/AdjointLowering/QuantumCache.hpp @@ -0,0 +1,53 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Value.h" + +#include "Catalyst/IR/CatalystDialect.h" + +namespace catalyst { +namespace quantum { + +/// A collection of the data required to reconstruct a deterministic hybrid quantum program with +/// classical preprocessing and arbitrary classical control flow. 
+/// +/// The forward pass populates these tapes (see ForwardPass.cpp) and the reverse pass consumes them +/// (see ReversePass.cpp). The push/pop order is a shared contract between the two passes: values +/// are pushed in program order during the forward pass and popped in reverse during the backward +/// pass. +struct QuantumCache { + mlir::TypedValue paramVector; + mlir::TypedValue wireVector; + /// For every structured control flow op, store the values required for it to execute. + /// Specifically: store the conditions for scf.if ops, the non-constant start/stop/step of + /// scf.for ops, and the number of iterations for scf.while ops. + mlir::DenseMap> controlFlowTapes; + + /// Initialize the quantum cache to traverse and store the necessary parameters for the given + /// `topLevelRegion`. + static QuantumCache initialize(mlir::Region &topLevelRegion, mlir::OpBuilder &builder, + mlir::Location loc); + + void emitDealloc(mlir::OpBuilder &builder, mlir::Location loc); +}; + +/// Verify that `ty` is a type the cache knows how to record (an f64 scalar or a 2D tensor of +/// complex), emitting an error on `op` otherwise. 
+void verifyTypeIsCacheable(mlir::Type ty, mlir::Operation *op); + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/AdjointPatterns.cpp b/mlir/lib/Quantum/Transforms/AdjointLowering/ReversePass.cpp similarity index 91% rename from mlir/lib/Quantum/Transforms/AdjointPatterns.cpp rename to mlir/lib/Quantum/Transforms/AdjointLowering/ReversePass.cpp index bb53104891..bee112f332 100644 --- a/mlir/lib/Quantum/Transforms/AdjointPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/AdjointLowering/ReversePass.cpp @@ -19,16 +19,14 @@ #include #include -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/Errc.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -39,10 +37,10 @@ #include "Quantum/IR/QuantumInterfaces.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/IR/QuantumTypes.h" -#include "Quantum/Transforms/Patterns.h" -#include "Quantum/Utils/QuantumSplitting.h" -using llvm::dbgs; +#include "AdjointLowering.hpp" +#include "QuantumCache.hpp" + using namespace mlir; using namespace catalyst; using namespace catalyst::quantum; @@ -106,7 +104,6 @@ class AdjointGenerator { "Expected only structured control flow (each region should have a single block)"); for (Operation &op : llvm::reverse(region.front().without_terminator())) { - LLVM_DEBUG(dbgs() << "generating adjoint for: " << op << "\n"); if (auto callOp = dyn_cast(op)) { visitOperation(callOp, builder); } @@ -495,11 +492,16 @@ class AdjointGenerator { } Value tape = cache.controlFlowTapes.at(forOp); - // Popping the start, stop, and step implies that these are backwards 
relative to - // the order they were pushed. - Value step = ListPopOp::create(builder, forOp.getLoc(), tape); - Value stop = ListPopOp::create(builder, forOp.getLoc(), tape); - Value start = ListPopOp::create(builder, forOp.getLoc(), tape); + // Re-materialize constant loop bounds directly to preserve static information. + auto recoverBound = [&](Value original) -> Value { + if (std::optional constant = getConstantIntValue(original)) { + return index::ConstantOp::create(builder, forOp.getLoc(), *constant); + } + return ListPopOp::create(builder, forOp.getLoc(), tape); + }; + Value step = recoverBound(forOp.getStep()); + Value stop = recoverBound(forOp.getUpperBound()); + Value start = recoverBound(forOp.getLowerBound()); SmallVector reversedResults; for (auto v : getQuantumValues(forOp.getResults())) { @@ -726,54 +728,16 @@ class AdjointGenerator { bool generationFailed = false; }; -struct AdjointSingleOpRewritePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - /// We build a map from values mentioned in the source data flow to the values of - /// the program where quantum control flow is reversed. Most of the time, there is a 1-to-1 - /// correspondence with a notable exception caused by `insert`/`extract` API asymmetry. - LogicalResult matchAndRewrite(AdjointOp adjoint, PatternRewriter &rewriter) const override - { - LLVM_DEBUG(dbgs() << "Adjointing the following:\n" << adjoint << "\n"); - auto cache = QuantumCache::initialize(adjoint.getRegion(), rewriter, adjoint.getLoc()); - // First, copy the classical computations directly to the target insertion point. 
- IRMapping oldToCloned; - AugmentedCircuitGenerator augmentedGenerator{oldToCloned, cache}; - augmentedGenerator.generate(adjoint.getRegion(), rewriter); - - // Initialize the backward pass with the operand of the quantum.yield - auto yieldOp = cast(adjoint.getRegion().front().getTerminator()); - for (auto [yieldVal, adjointOperand] : - llvm::zip_equal(yieldOp.getOperands(), adjoint.getArgs())) { - oldToCloned.map(yieldVal, adjointOperand); - } - - // Emit the adjoint quantum operations and reversed control flow, using cached values. - AdjointGenerator adjointGenerator{oldToCloned, cache}; - if (failed(adjointGenerator.generate(adjoint.getRegion(), rewriter))) { - return failure(); - } - - // Explicitly free the memory of the caches. - cache.emitDealloc(rewriter, adjoint.getLoc()); - // The final quantum values are the re-mapped region arguments of the original adjoint op. - SmallVector reversedOutputs; - for (BlockArgument arg : adjoint.getRegion().getArguments()) { - reversedOutputs.push_back(oldToCloned.lookup(arg)); - } - rewriter.replaceOp(adjoint, reversedOutputs); - return success(); - } -}; - } // namespace namespace catalyst { namespace quantum { -void populateAdjointPatterns(RewritePatternSet &patterns) +LogicalResult generateAdjointReversePass(Region ®ion, OpBuilder &builder, + IRMapping &remappedValues, QuantumCache &cache) { - patterns.add(patterns.getContext(), 1); + AdjointGenerator generator{remappedValues, cache}; + return generator.generate(region, builder); } } // namespace quantum diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index ea9215b673..8db4192e30 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -3,20 +3,22 @@ set(LIBRARY_NAME quantum-transforms) add_subdirectory(DecompGraphSolver) file(GLOB SRC - AdjointPatterns.cpp + AdjointLowering/AdjointLowering.cpp + AdjointLowering/ForwardPass.cpp + AdjointLowering/QuantumCache.cpp + 
AdjointLowering/ReversePass.cpp BufferizableOpInterfaceImpl.cpp CancelInversesPatterns.cpp ConversionPatterns.cpp DecomposeLoweringPatterns.cpp - QPDLoader.cpp DisentangleCNOT.cpp DisentangleSWAP.cpp GridsynthPatterns.cpp IonsDecompositionPatterns.cpp LoopBoundaryOptimizationPatterns.cpp MergeRotationsPatterns.cpp + QPDLoader.cpp SplitMultipleTapes.cpp - adjoint_lowering.cpp cancel_inverses.cpp cp_global_buffers.cpp decompose_lowering.cpp @@ -58,8 +60,9 @@ add_mlir_library(${LIBRARY_NAME} STATIC ) target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) target_include_directories(${LIBRARY_NAME} PUBLIC - . - ${PROJECT_SOURCE_DIR}/include - ${CMAKE_BINARY_DIR}/include - DecompGraphSolver/ - ) + . + ${PROJECT_SOURCE_DIR}/include + ${CMAKE_BINARY_DIR}/include + AdjointLowering/ + DecompGraphSolver/ +) diff --git a/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp b/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp deleted file mode 100644 index c0a32aa7cf..0000000000 --- a/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2023 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#define DEBUG_TYPE "adjoint" - -#include -#include - -#include "llvm/Support/Debug.h" -#include "llvm/Support/Errc.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Quantum/IR/QuantumOps.h" -#include "Quantum/Transforms/Patterns.h" - -using namespace llvm; -using namespace mlir; -using namespace catalyst::quantum; - -namespace catalyst { -namespace quantum { - -#define GEN_PASS_DEF_ADJOINTLOWERINGPASS -#include "Quantum/Transforms/Passes.h.inc" - -struct AdjointLoweringPass : impl::AdjointLoweringPassBase { - using AdjointLoweringPassBase::AdjointLoweringPassBase; - - void runOnOperation() final - { - LLVM_DEBUG(dbgs() << "adjoint lowering pass" - << "\n"); - - RewritePatternSet patterns(&getContext()); - populateAdjointPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace quantum -} // namespace catalyst diff --git a/mlir/lib/Quantum/Utils/CMakeLists.txt b/mlir/lib/Quantum/Utils/CMakeLists.txt index 4a2073137b..2f8d375a25 100644 --- a/mlir/lib/Quantum/Utils/CMakeLists.txt +++ b/mlir/lib/Quantum/Utils/CMakeLists.txt @@ -1,4 +1,3 @@ add_mlir_library(QuantumUtils - QuantumSplitting.cpp RemoveQuantum.cpp ) diff --git a/mlir/test/Quantum/AdjointTest.mlir b/mlir/test/Quantum/AdjointTest.mlir index 794c502cd7..216dd46da6 100644 --- a/mlir/test/Quantum/AdjointTest.mlir +++ b/mlir/test/Quantum/AdjointTest.mlir @@ -354,46 +354,90 @@ func.func private 
@param_ordering(%0: !quantum.reg) -> !quantum.reg { // Test adjoint of scf.for - func.func public @adjoint_for_loop() { - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index +func.func public @adjoint_for_loop_static() { + // CHECK-DAG: [[start:%.+]] = index.constant 0 + // CHECK-DAG: [[stop:%.+]] = index.constant 4 + // CHECK-DAG: [[step:%.+]] = index.constant 1 + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[q0:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit - // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg - // CHECK: [[q0:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q1:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit - %0 = quantum.alloc( 2) : !quantum.reg - %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit - %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + // CHECK-NOT: quantum.adjoint + %3:2 = quantum.adjoint(%1, %2) : !quantum.bit, !quantum.bit { + ^bb0(%arg1: !quantum.bit, %arg2: !quantum.bit): + + // CHECK-NOT: catalyst.list_push + // CHECK-NOT: catalyst.list_pop + // CHECK: [[for_out:%.+]]:2 = scf.for {{%.+}} = [[start]] to [[stop]] step [[step]] + // CHECK-SAME: iter_args(%arg1 = [[q0]], %arg2 = [[q1]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[gate2:%.+]]:2 = quantum.custom "gate2"() %arg1, %arg2 adj : !quantum.bit, !quantum.bit + // CHECK: [[gate1:%.+]]:2 = quantum.custom "gate1"() [[gate2]]#0, [[gate2]]#1 adj : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[gate1]]#0, [[gate1]]#1 : !quantum.bit, !quantum.bit + // CHECK: } + %8:2 = 
scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg5 = %arg1, %arg6 = %arg2) -> (!quantum.bit, !quantum.bit) { + %out_qubits:2 = quantum.custom "gate1"() %arg5, %arg6 : !quantum.bit, !quantum.bit + %out_qubits_0:2 = quantum.custom "gate2"() %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit + scf.yield %out_qubits_0#0, %out_qubits_0#1 : !quantum.bit, !quantum.bit + } - // CHECK-NOT: quantum.adjoint - %3:2 = quantum.adjoint(%1, %2) : !quantum.bit, !quantum.bit { - ^bb0(%arg1: !quantum.bit, %arg2: !quantum.bit): + quantum.yield %8#0, %8#1 : !quantum.bit, !quantum.bit + } - // CHECK: [[for_out:%.+]]:2 = scf.for - // CHECK-SAME: iter_args(%arg1 = [[q0]], %arg2 = [[q1]]) -> (!quantum.bit, !quantum.bit) { - // CHECK: [[gate2:%.+]]:2 = quantum.custom "gate2"() %arg1, %arg2 adj : !quantum.bit, !quantum.bit - // CHECK: [[gate1:%.+]]:2 = quantum.custom "gate1"() [[gate2]]#0, [[gate2]]#1 adj : !quantum.bit, !quantum.bit - // CHECK: scf.yield [[gate1]]#0, [[gate1]]#1 : !quantum.bit, !quantum.bit - // CHECK: } - %8:2 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg5 = %arg1, %arg6 = %arg2) -> (!quantum.bit, !quantum.bit) { - %out_qubits:2 = quantum.custom "gate1"() %arg5, %arg6 : !quantum.bit, !quantum.bit - %out_qubits_0:2 = quantum.custom "gate2"() %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit - scf.yield %out_qubits_0#0, %out_qubits_0#1 : !quantum.bit, !quantum.bit - } + // CHECK: [[insert0:%.+]] = quantum.insert [[reg]][ 0], [[for_out]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[for_out]]#1 : !quantum.reg, !quantum.bit + // CHECK: quantum.dealloc [[insert1]] + %4 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit + %5 = quantum.insert %4[ 1], %3#1 : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + return +} - quantum.yield %8#0, %8#1 : !quantum.bit, !quantum.bit +// ----- + +// Test adjoint of scf.for with dynamic bounds: only the non-constant bound is cached, while +// 
constant bounds (start, step) are still rematerialized. +func.func public @adjoint_for_loop_mixed_static_dynamic(%stop: index) { + // CHECK-DAG: [[start:%.+]] = index.constant 0 + // CHECK-DAG: [[step:%.+]] = index.constant 1 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[q0:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK-NOT: quantum.adjoint + %3:2 = quantum.adjoint(%1, %2) : !quantum.bit, !quantum.bit { + ^bb0(%arg1: !quantum.bit, %arg2: !quantum.bit): + + // CHECK: catalyst.list_push %arg0 + // CHECK: [[stop:%.+]] = catalyst.list_pop + // CHECK: [[for_out:%.+]]:2 = scf.for {{%.+}} = [[start]] to [[stop]] step [[step]] + %8:2 = scf.for %arg3 = %c0 to %stop step %c1 iter_args(%arg5 = %arg1, %arg6 = %arg2) -> (!quantum.bit, !quantum.bit) { + %out_qubits:2 = quantum.custom "gate1"() %arg5, %arg6 : !quantum.bit, !quantum.bit + %out_qubits_0:2 = quantum.custom "gate2"() %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit + scf.yield %out_qubits_0#0, %out_qubits_0#1 : !quantum.bit, !quantum.bit } - // CHECK: [[insert0:%.+]] = quantum.insert [[reg]][ 0], [[for_out]]#0 : !quantum.reg, !quantum.bit - // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[for_out]]#1 : !quantum.reg, !quantum.bit - // CHECK: quantum.dealloc [[insert1]] - %4 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit - %5 = quantum.insert %4[ 1], %3#1 : !quantum.reg, !quantum.bit - quantum.dealloc %5 : !quantum.reg - return + quantum.yield %8#0, %8#1 : !quantum.bit, !quantum.bit } + %4 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit + %5 = quantum.insert %4[ 1], %3#1 : !quantum.reg, !quantum.bit + 
quantum.dealloc %5 : !quantum.reg + return +} + // ----- // Test adjoint of scf.while