Skip to content

Commit e595087

Browse files
authored
🐛 Fix conversions of jeff.switch and jeff.for (#1776)
## Description This PR fixes the conversions of `jeff.switch` and `jeff.for` be ensuring that the created regions are isolated from above. Fixes #1725 ## Checklist - [x] The pull request only contains commits that are focused and relevant to this change. - [x] I have added appropriate tests that cover the new/changed functionality. - [x] ~~I have updated the documentation to reflect these changes.~~ - [x] I have added entries to the changelog for any noteworthy additions, changes, fixes, or removals. - [x] ~~I have added migration instructions to the upgrade guide (if needed).~~ - [x] The changes follow the project's style guidelines and introduce no new warnings. - [x] The changes are fully tested and pass the CI checks. - [x] I have reviewed my own code changes.
1 parent 7082086 commit e595087

9 files changed

Lines changed: 237 additions & 57 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel
1515
- ✨ Add a `quantum-loop-unroll` pass for unrolling for-loop operations containing quantum operations ([#1718]) ([**@MatthiasReumann**])
1616
- ✨ Add a `hadamard-lifting` pass for lifting Hadamard gates above Pauli gates ([#1605]) ([**@lirem101**], [**@burgholzer**])
1717
- ✨ Add a `merge-single-qubit-rotation-gates` pass for merging consecutive rotation gates using quaternions ([#1407], [#1674]) ([**@J4MMlE**], [**@denialhaag**], [**@MatthiasReumann**])
18-
- ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637], [#1676], [#1706]) ([**@denialhaag**], [**@burgholzer**])
18+
- ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637], [#1676], [#1706], [#1776]) ([**@denialhaag**], [**@burgholzer**])
1919
- ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588], [#1600], [#1664], [#1709], [#1716], [#1748]) ([**@MatthiasReumann**], [**@burgholzer**])
2020
- ✨ Add initial infrastructure for new QC and QCO MLIR dialects
2121
([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], [#1774])
@@ -402,6 +402,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool
402402

403403
<!-- PR links -->
404404

405+
[#1776]: https://github.com/munich-quantum-toolkit/core/pull/1776
405406
[#1774]: https://github.com/munich-quantum-toolkit/core/pull/1774
406407
[#1765]: https://github.com/munich-quantum-toolkit/core/pull/1765
407408
[#1762]: https://github.com/munich-quantum-toolkit/core/pull/1762

cmake/ExternalDependencies.cmake

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ if(BUILD_MQT_CORE_MLIR)
2727
FetchContent_Declare(
2828
jeff-mlir
2929
GIT_REPOSITORY https://github.com/PennyLaneAI/jeff-mlir.git
30-
GIT_TAG v0.2.0)
30+
# Pinned to an unreleased commit until v0.3.0 is released. jeff-mlir's SCF operations are
31+
# already marked as IsolatedFromAbove in the pinned version.
32+
GIT_TAG 3f34dc3e2865ceaffb8003b2410404306a49f0ab)
3133
list(APPEND FETCH_PACKAGES jeff-mlir)
3234
endif()
3335

mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,21 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder {
105105
*/
106106
Value intConstant(int64_t value);
107107

108+
/**
109+
* @brief Create a constant float value
110+
* @param value The value to store in the constant
111+
* @return The value produced by the constant operation
112+
*
113+
* @par Example:
114+
* ```c++
115+
* auto c = builder.floatConstant(0.123);
116+
* ```
117+
* ```mlir
118+
* %c = arith.constant 0.123 : f64
119+
* ```
120+
*/
121+
Value floatConstant(double value);
122+
108123
//===--------------------------------------------------------------------===//
109124
// Memory Management
110125
//===--------------------------------------------------------------------===//

mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -861,18 +861,52 @@ struct ConvertJeffSwitchOpToQCO final : OpConversionPattern<jeff::SwitchOp> {
861861
op, "qco.if requires exactly two branches");
862862
}
863863

864-
auto qcoIf = IfOp::create(rewriter, op.getLoc(), adaptor.getSelection(),
865-
adaptor.getInValues());
864+
auto isLinearType = [](Type t) {
865+
return isa<jeff::QubitType, jeff::QuregType>(t);
866+
};
867+
868+
auto inValues = adaptor.getInValues();
869+
870+
SmallVector<Value> qubits;
871+
for (auto [value, adapted] : llvm::zip(op.getInValues(), inValues)) {
872+
if (isLinearType(value.getType())) {
873+
qubits.push_back(adapted);
874+
}
875+
}
876+
877+
auto qcoIf =
878+
IfOp::create(rewriter, op.getLoc(), adaptor.getSelection(), qubits);
866879

867880
auto moveRegion = [&](Region& source, Region& dest) -> LogicalResult {
868-
rewriter.inlineRegionBefore(source, dest, dest.end());
869-
Block* block = &dest.front();
870-
TypeConverter::SignatureConversion sc(block->getNumArguments());
871-
if (failed(getTypeConverter()->convertSignatureArgs(
872-
block->getArgumentTypes(), sc))) {
873-
return failure();
881+
auto* oldBlock = &source.back();
882+
auto* newBlock = &dest.emplaceBlock();
883+
rewriter.setInsertionPointToEnd(newBlock);
884+
885+
IRMapping mapping;
886+
for (auto [oldArg, adapted] :
887+
llvm::zip(oldBlock->getArguments(), inValues)) {
888+
if (isLinearType(oldArg.getType())) {
889+
auto newArg = newBlock->addArgument(
890+
typeConverter->convertType(oldArg.getType()), oldArg.getLoc());
891+
mapping.map(oldArg, newArg);
892+
} else {
893+
mapping.map(oldArg, adapted);
894+
}
895+
}
896+
897+
for (auto& op : oldBlock->without_terminator()) {
898+
rewriter.clone(op, mapping);
899+
}
900+
901+
auto* oldTerminator = oldBlock->getTerminator();
902+
SmallVector<Value> yields;
903+
for (auto value : oldTerminator->getOperands()) {
904+
if (isLinearType(value.getType())) {
905+
yields.push_back(rewriter.getRemappedValue(mapping.lookup(value)));
906+
}
874907
}
875-
rewriter.applySignatureConversion(block, sc);
908+
rewriter.replaceOpWithNewOp<YieldOp>(oldTerminator, yields);
909+
876910
return success();
877911
};
878912

@@ -883,7 +917,15 @@ struct ConvertJeffSwitchOpToQCO final : OpConversionPattern<jeff::SwitchOp> {
883917
return failure();
884918
}
885919

886-
rewriter.replaceOp(op, qcoIf.getResults());
920+
SmallVector<Value> results;
921+
size_t index = 0;
922+
for (auto [value, adapted] : llvm::zip(op.getResults(), inValues)) {
923+
results.push_back(isLinearType(value.getType())
924+
? qcoIf.getResults()[index++]
925+
: adapted);
926+
}
927+
rewriter.replaceOp(op, results);
928+
887929
return success();
888930
}
889931
};
@@ -934,8 +976,8 @@ struct ConvertJeffForOpToQCO final : OpConversionPattern<jeff::ForOp> {
934976
auto scfFor = scf::ForOp::create(rewriter, loc, start, stop, step,
935977
adaptor.getInValues());
936978

937-
Block* jeffBody = &op.getBody().front();
938-
Block* scfBody = scfFor.getBody();
979+
auto* jeffBody = &op.getBody().front();
980+
auto* scfBody = scfFor.getBody();
939981

940982
OpBuilder::InsertionGuard guard(rewriter);
941983
rewriter.setInsertionPointToStart(scfBody);
@@ -944,7 +986,7 @@ struct ConvertJeffForOpToQCO final : OpConversionPattern<jeff::ForOp> {
944986
jeffBody->getArgument(0).getType(),
945987
scfFor.getInductionVar());
946988
SmallVector<Value> args = {iv.getResult()};
947-
for (Value arg : scfFor.getRegionIterArgs()) {
989+
for (auto arg : scfFor.getRegionIterArgs()) {
948990
args.push_back(arg);
949991
}
950992

@@ -964,11 +1006,6 @@ struct ConvertJeffYieldOpToQCO final : OpConversionPattern<jeff::YieldOp> {
9641006
LogicalResult
9651007
matchAndRewrite(jeff::YieldOp op, OpAdaptor adaptor,
9661008
ConversionPatternRewriter& rewriter) const override {
967-
if (isa<IfOp>(op->getParentOp())) {
968-
rewriter.replaceOpWithNewOp<YieldOp>(op, adaptor.getOperands());
969-
return success();
970-
}
971-
9721009
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
9731010
return success();
9741011
}

mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "mlir/Conversion/QCOToJeff/QCOToJeff.h"
1212

13-
#include "mlir/Conversion/ConversionUtils.h"
1413
#include "mlir/Dialect/QCO/IR/QCODialect.h"
1514
#include "mlir/Dialect/QCO/IR/QCOOps.h"
1615
#include "mlir/Dialect/QTensor/IR/QTensorDialect.h"
@@ -39,6 +38,7 @@
3938
#include <mlir/Support/LLVM.h>
4039
#include <mlir/Support/LogicalResult.h>
4140
#include <mlir/Transforms/DialectConversion.h>
41+
#include <mlir/Transforms/RegionUtils.h>
4242

4343
#include <cassert>
4444
#include <cstddef>
@@ -342,6 +342,45 @@ static LogicalResult cleanUp(Operation* op, LoweringState& state) {
342342
return success();
343343
}
344344

345+
/**
346+
* @brief Move a region from QCO/SCF operation to a jeff operation
347+
*/
348+
static LogicalResult moveRegion(Region& source, Region& dest,
349+
ConversionPatternRewriter& rewriter,
350+
const TypeConverter* typeConverter,
351+
const SetVector<Value>& aboveValues) {
352+
auto* oldBlock = &source.back();
353+
auto* newBlock = &dest.emplaceBlock();
354+
rewriter.setInsertionPointToEnd(newBlock);
355+
356+
IRMapping mapping;
357+
for (auto oldArg : oldBlock->getArguments()) {
358+
auto newArg = newBlock->addArgument(
359+
typeConverter->convertType(oldArg.getType()), oldArg.getLoc());
360+
mapping.map(oldArg, newArg);
361+
}
362+
for (auto value : aboveValues) {
363+
auto newArg = newBlock->addArgument(
364+
typeConverter->convertType(value.getType()), value.getLoc());
365+
mapping.map(value, newArg);
366+
}
367+
368+
for (auto& op : oldBlock->without_terminator()) {
369+
rewriter.clone(op, mapping);
370+
}
371+
372+
auto* oldTerminator = oldBlock->getTerminator();
373+
SmallVector<Value> yields;
374+
for (auto value : oldTerminator->getOperands()) {
375+
yields.push_back(rewriter.getRemappedValue(mapping.lookup(value)));
376+
}
377+
llvm::append_range(yields,
378+
newBlock->getArguments().take_back(aboveValues.size()));
379+
rewriter.replaceOpWithNewOp<jeff::YieldOp>(oldTerminator, yields);
380+
381+
return success();
382+
}
383+
345384
namespace {
346385

347386
/**
@@ -963,13 +1002,8 @@ struct ConvertQCOYieldOpToJeff final : StatefulOpConversionPattern<YieldOp> {
9631002
using StatefulOpConversionPattern::StatefulOpConversionPattern;
9641003

9651004
LogicalResult
966-
matchAndRewrite(YieldOp op, OpAdaptor adaptor,
1005+
matchAndRewrite(YieldOp op, OpAdaptor /*adaptor*/,
9671006
ConversionPatternRewriter& rewriter) const override {
968-
if (isa<jeff::SwitchOp>(op->getParentOp())) {
969-
rewriter.replaceOpWithNewOp<jeff::YieldOp>(op, adaptor.getOperands());
970-
return success();
971-
}
972-
9731007
auto& state = getState();
9741008

9751009
if (state.inInvOp) {
@@ -1036,37 +1070,55 @@ struct ConvertQCOIfOpToJeff final : StatefulOpConversionPattern<IfOp> {
10361070
matchAndRewrite(IfOp op, OpAdaptor adaptor,
10371071
ConversionPatternRewriter& rewriter) const override {
10381072
auto loc = op.getLoc();
1073+
1074+
SetVector<Value> aboveValues;
1075+
getUsedValuesDefinedAbove(op.getElseRegion(), aboveValues);
1076+
getUsedValuesDefinedAbove(op.getThenRegion(), aboveValues);
1077+
1078+
SmallVector<Value> initArgs;
1079+
llvm::append_range(initArgs, adaptor.getQubits());
1080+
10391081
SmallVector<Type> outTypes;
10401082
if (failed(
10411083
getTypeConverter()->convertTypes(op.getResultTypes(), outTypes))) {
10421084
return failure();
10431085
}
10441086

1045-
auto jeffIf =
1046-
jeff::SwitchOp::create(rewriter, loc, outTypes, adaptor.getCondition(),
1047-
adaptor.getQubits(), 2);
1087+
for (auto value : aboveValues) {
1088+
auto remappedValue = rewriter.getRemappedValue(value);
1089+
initArgs.push_back(remappedValue);
1090+
outTypes.push_back(remappedValue.getType());
1091+
}
1092+
1093+
auto jeffSwitch = jeff::SwitchOp::create(
1094+
rewriter, loc, outTypes, adaptor.getCondition(), initArgs, 2);
10481095

1049-
if (failed(moveRegion(op.getElseRegion(), jeffIf.getBranches()[0], rewriter,
1050-
getTypeConverter()))) {
1096+
if (failed(moveRegion(op.getElseRegion(), jeffSwitch.getBranches()[0],
1097+
rewriter, getTypeConverter(), aboveValues))) {
10511098
return failure();
10521099
}
1053-
if (failed(moveRegion(op.getThenRegion(), jeffIf.getBranches()[1], rewriter,
1054-
getTypeConverter()))) {
1100+
if (failed(moveRegion(op.getThenRegion(), jeffSwitch.getBranches()[1],
1101+
rewriter, getTypeConverter(), aboveValues))) {
10551102
return failure();
10561103
}
10571104

10581105
// Add trivial default case
10591106
{
1060-
auto* block = &jeffIf.getDefault().emplaceBlock();
1107+
auto* block = &jeffSwitch.getDefault().emplaceBlock();
10611108
for (auto value : adaptor.getQubits()) {
10621109
block->addArgument(value.getType(), loc);
10631110
}
1111+
for (auto value : aboveValues) {
1112+
block->addArgument(typeConverter->convertType(value.getType()), loc);
1113+
}
10641114
OpBuilder::InsertionGuard guard(rewriter);
10651115
rewriter.setInsertionPointToStart(block);
10661116
jeff::YieldOp::create(rewriter, loc, block->getArguments());
10671117
}
10681118

1069-
rewriter.replaceOp(op, jeffIf.getResults());
1119+
rewriter.replaceOp(op,
1120+
jeffSwitch.getResults().take_front(op.getNumResults()));
1121+
10701122
return success();
10711123
}
10721124
};
@@ -1104,34 +1156,35 @@ struct ConvertSCFForOpToJeff final : StatefulOpConversionPattern<scf::ForOp> {
11041156
LogicalResult
11051157
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
11061158
ConversionPatternRewriter& rewriter) const override {
1159+
SetVector<Value> aboveValues;
1160+
getUsedValuesDefinedAbove(op.getRegion(), aboveValues);
1161+
1162+
SmallVector<Value> initArgs;
1163+
llvm::append_range(initArgs, adaptor.getInitArgs());
1164+
11071165
SmallVector<Type> outTypes;
11081166
if (failed(
11091167
getTypeConverter()->convertTypes(op.getResultTypes(), outTypes))) {
11101168
return failure();
11111169
}
11121170

1171+
for (auto value : aboveValues) {
1172+
auto remappedValue = rewriter.getRemappedValue(value);
1173+
initArgs.push_back(remappedValue);
1174+
outTypes.push_back(remappedValue.getType());
1175+
}
1176+
11131177
auto jeffFor = jeff::ForOp::create(
11141178
rewriter, op.getLoc(), outTypes, adaptor.getLowerBound(),
1115-
adaptor.getUpperBound(), adaptor.getStep(), adaptor.getInitArgs());
1179+
adaptor.getUpperBound(), adaptor.getStep(), initArgs);
11161180

11171181
if (failed(moveRegion(op.getRegion(), jeffFor.getRegion(), rewriter,
1118-
getTypeConverter()))) {
1182+
getTypeConverter(), aboveValues))) {
11191183
return failure();
11201184
}
11211185

1122-
rewriter.replaceOp(op, jeffFor.getResults());
1123-
return success();
1124-
}
1125-
};
1126-
1127-
struct ConvertSCFYieldOpToJeff final
1128-
: StatefulOpConversionPattern<scf::YieldOp> {
1129-
using StatefulOpConversionPattern::StatefulOpConversionPattern;
1186+
rewriter.replaceOp(op, jeffFor.getResults().take_front(op.getNumResults()));
11301187

1131-
LogicalResult
1132-
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
1133-
ConversionPatternRewriter& rewriter) const override {
1134-
rewriter.replaceOpWithNewOp<jeff::YieldOp>(op, adaptor.getResults());
11351188
return success();
11361189
}
11371190
};
@@ -1414,11 +1467,11 @@ struct QCOToJeff final : impl::QCOToJeffBase<QCOToJeff> {
14141467
addQCOToJeffGatePattern<JK::Custom, 2, 2, XXMinusYYOp, void, false>(
14151468
patterns, typeConverter, context, state, "xx_minus_yy");
14161469

1417-
patterns.add<ConvertQCOBarrierOpToJeff, ConvertQCOCtrlOpToJeff,
1418-
ConvertQCOInvOpToJeff, ConvertQCOYieldOpToJeff,
1419-
ConvertQCOIfOpToJeff, ConvertSCFForOpToJeff,
1420-
ConvertSCFYieldOpToJeff, ConvertQCOMainToJeff>(
1421-
typeConverter, context, &state);
1470+
patterns
1471+
.add<ConvertQCOBarrierOpToJeff, ConvertQCOCtrlOpToJeff,
1472+
ConvertQCOInvOpToJeff, ConvertQCOYieldOpToJeff,
1473+
ConvertQCOIfOpToJeff, ConvertSCFForOpToJeff, ConvertQCOMainToJeff>(
1474+
typeConverter, context, &state);
14221475

14231476
// Apply the conversion
14241477
if (applyPartialConversion(module, target, std::move(patterns)).failed()) {

mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ Value QCOProgramBuilder::intConstant(const int64_t value) {
7474
return arith::ConstantOp::create(*this, getI64IntegerAttr(value)).getResult();
7575
}
7676

77+
Value QCOProgramBuilder::floatConstant(const double value) {
78+
checkFinalized();
79+
return arith::ConstantOp::create(*this, getF64FloatAttr(value)).getResult();
80+
}
81+
7782
Value& QCOProgramBuilder::QubitRegister::operator[](const size_t index) {
7883
if (index >= qubits.size()) {
7984
llvm::reportFatalUsageError("Qubit index out of bounds");

0 commit comments

Comments
 (0)