|
10 | 10 |
|
11 | 11 | #include "mlir/Conversion/QCOToJeff/QCOToJeff.h" |
12 | 12 |
|
13 | | -#include "mlir/Conversion/ConversionUtils.h" |
14 | 13 | #include "mlir/Dialect/QCO/IR/QCODialect.h" |
15 | 14 | #include "mlir/Dialect/QCO/IR/QCOOps.h" |
16 | 15 | #include "mlir/Dialect/QTensor/IR/QTensorDialect.h" |
|
39 | 38 | #include <mlir/Support/LLVM.h> |
40 | 39 | #include <mlir/Support/LogicalResult.h> |
41 | 40 | #include <mlir/Transforms/DialectConversion.h> |
| 41 | +#include <mlir/Transforms/RegionUtils.h> |
42 | 42 |
|
43 | 43 | #include <cassert> |
44 | 44 | #include <cstddef> |
@@ -342,6 +342,45 @@ static LogicalResult cleanUp(Operation* op, LoweringState& state) { |
342 | 342 | return success(); |
343 | 343 | } |
344 | 344 |
|
| 345 | +/** |
| 346 | + * @brief Move a region from QCO/SCF operation to a jeff operation |
| 347 | + */ |
| 348 | +static LogicalResult moveRegion(Region& source, Region& dest, |
| 349 | + ConversionPatternRewriter& rewriter, |
| 350 | + const TypeConverter* typeConverter, |
| 351 | + const SetVector<Value>& aboveValues) { |
| 352 | + auto* oldBlock = &source.back(); |
| 353 | + auto* newBlock = &dest.emplaceBlock(); |
| 354 | + rewriter.setInsertionPointToEnd(newBlock); |
| 355 | + |
| 356 | + IRMapping mapping; |
| 357 | + for (auto oldArg : oldBlock->getArguments()) { |
| 358 | + auto newArg = newBlock->addArgument( |
| 359 | + typeConverter->convertType(oldArg.getType()), oldArg.getLoc()); |
| 360 | + mapping.map(oldArg, newArg); |
| 361 | + } |
| 362 | + for (auto value : aboveValues) { |
| 363 | + auto newArg = newBlock->addArgument( |
| 364 | + typeConverter->convertType(value.getType()), value.getLoc()); |
| 365 | + mapping.map(value, newArg); |
| 366 | + } |
| 367 | + |
| 368 | + for (auto& op : oldBlock->without_terminator()) { |
| 369 | + rewriter.clone(op, mapping); |
| 370 | + } |
| 371 | + |
| 372 | + auto* oldTerminator = oldBlock->getTerminator(); |
| 373 | + SmallVector<Value> yields; |
| 374 | + for (auto value : oldTerminator->getOperands()) { |
| 375 | + yields.push_back(rewriter.getRemappedValue(mapping.lookup(value))); |
| 376 | + } |
| 377 | + llvm::append_range(yields, |
| 378 | + newBlock->getArguments().take_back(aboveValues.size())); |
| 379 | + rewriter.replaceOpWithNewOp<jeff::YieldOp>(oldTerminator, yields); |
| 380 | + |
| 381 | + return success(); |
| 382 | +} |
| 383 | + |
345 | 384 | namespace { |
346 | 385 |
|
347 | 386 | /** |
@@ -963,13 +1002,8 @@ struct ConvertQCOYieldOpToJeff final : StatefulOpConversionPattern<YieldOp> { |
963 | 1002 | using StatefulOpConversionPattern::StatefulOpConversionPattern; |
964 | 1003 |
|
965 | 1004 | LogicalResult |
966 | | - matchAndRewrite(YieldOp op, OpAdaptor adaptor, |
| 1005 | + matchAndRewrite(YieldOp op, OpAdaptor /*adaptor*/, |
967 | 1006 | ConversionPatternRewriter& rewriter) const override { |
968 | | - if (isa<jeff::SwitchOp>(op->getParentOp())) { |
969 | | - rewriter.replaceOpWithNewOp<jeff::YieldOp>(op, adaptor.getOperands()); |
970 | | - return success(); |
971 | | - } |
972 | | - |
973 | 1007 | auto& state = getState(); |
974 | 1008 |
|
975 | 1009 | if (state.inInvOp) { |
@@ -1036,37 +1070,55 @@ struct ConvertQCOIfOpToJeff final : StatefulOpConversionPattern<IfOp> { |
1036 | 1070 | matchAndRewrite(IfOp op, OpAdaptor adaptor, |
1037 | 1071 | ConversionPatternRewriter& rewriter) const override { |
1038 | 1072 | auto loc = op.getLoc(); |
| 1073 | + |
| 1074 | + SetVector<Value> aboveValues; |
| 1075 | + getUsedValuesDefinedAbove(op.getElseRegion(), aboveValues); |
| 1076 | + getUsedValuesDefinedAbove(op.getThenRegion(), aboveValues); |
| 1077 | + |
| 1078 | + SmallVector<Value> initArgs; |
| 1079 | + llvm::append_range(initArgs, adaptor.getQubits()); |
| 1080 | + |
1039 | 1081 | SmallVector<Type> outTypes; |
1040 | 1082 | if (failed( |
1041 | 1083 | getTypeConverter()->convertTypes(op.getResultTypes(), outTypes))) { |
1042 | 1084 | return failure(); |
1043 | 1085 | } |
1044 | 1086 |
|
1045 | | - auto jeffIf = |
1046 | | - jeff::SwitchOp::create(rewriter, loc, outTypes, adaptor.getCondition(), |
1047 | | - adaptor.getQubits(), 2); |
| 1087 | + for (auto value : aboveValues) { |
| 1088 | + auto remappedValue = rewriter.getRemappedValue(value); |
| 1089 | + initArgs.push_back(remappedValue); |
| 1090 | + outTypes.push_back(remappedValue.getType()); |
| 1091 | + } |
| 1092 | + |
| 1093 | + auto jeffSwitch = jeff::SwitchOp::create( |
| 1094 | + rewriter, loc, outTypes, adaptor.getCondition(), initArgs, 2); |
1048 | 1095 |
|
1049 | | - if (failed(moveRegion(op.getElseRegion(), jeffIf.getBranches()[0], rewriter, |
1050 | | - getTypeConverter()))) { |
| 1096 | + if (failed(moveRegion(op.getElseRegion(), jeffSwitch.getBranches()[0], |
| 1097 | + rewriter, getTypeConverter(), aboveValues))) { |
1051 | 1098 | return failure(); |
1052 | 1099 | } |
1053 | | - if (failed(moveRegion(op.getThenRegion(), jeffIf.getBranches()[1], rewriter, |
1054 | | - getTypeConverter()))) { |
| 1100 | + if (failed(moveRegion(op.getThenRegion(), jeffSwitch.getBranches()[1], |
| 1101 | + rewriter, getTypeConverter(), aboveValues))) { |
1055 | 1102 | return failure(); |
1056 | 1103 | } |
1057 | 1104 |
|
1058 | 1105 | // Add trivial default case |
1059 | 1106 | { |
1060 | | - auto* block = &jeffIf.getDefault().emplaceBlock(); |
| 1107 | + auto* block = &jeffSwitch.getDefault().emplaceBlock(); |
1061 | 1108 | for (auto value : adaptor.getQubits()) { |
1062 | 1109 | block->addArgument(value.getType(), loc); |
1063 | 1110 | } |
| 1111 | + for (auto value : aboveValues) { |
| 1112 | + block->addArgument(typeConverter->convertType(value.getType()), loc); |
| 1113 | + } |
1064 | 1114 | OpBuilder::InsertionGuard guard(rewriter); |
1065 | 1115 | rewriter.setInsertionPointToStart(block); |
1066 | 1116 | jeff::YieldOp::create(rewriter, loc, block->getArguments()); |
1067 | 1117 | } |
1068 | 1118 |
|
1069 | | - rewriter.replaceOp(op, jeffIf.getResults()); |
| 1119 | + rewriter.replaceOp(op, |
| 1120 | + jeffSwitch.getResults().take_front(op.getNumResults())); |
| 1121 | + |
1070 | 1122 | return success(); |
1071 | 1123 | } |
1072 | 1124 | }; |
@@ -1104,34 +1156,35 @@ struct ConvertSCFForOpToJeff final : StatefulOpConversionPattern<scf::ForOp> { |
1104 | 1156 | LogicalResult |
1105 | 1157 | matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, |
1106 | 1158 | ConversionPatternRewriter& rewriter) const override { |
| 1159 | + SetVector<Value> aboveValues; |
| 1160 | + getUsedValuesDefinedAbove(op.getRegion(), aboveValues); |
| 1161 | + |
| 1162 | + SmallVector<Value> initArgs; |
| 1163 | + llvm::append_range(initArgs, adaptor.getInitArgs()); |
| 1164 | + |
1107 | 1165 | SmallVector<Type> outTypes; |
1108 | 1166 | if (failed( |
1109 | 1167 | getTypeConverter()->convertTypes(op.getResultTypes(), outTypes))) { |
1110 | 1168 | return failure(); |
1111 | 1169 | } |
1112 | 1170 |
|
| 1171 | + for (auto value : aboveValues) { |
| 1172 | + auto remappedValue = rewriter.getRemappedValue(value); |
| 1173 | + initArgs.push_back(remappedValue); |
| 1174 | + outTypes.push_back(remappedValue.getType()); |
| 1175 | + } |
| 1176 | + |
1113 | 1177 | auto jeffFor = jeff::ForOp::create( |
1114 | 1178 | rewriter, op.getLoc(), outTypes, adaptor.getLowerBound(), |
1115 | | - adaptor.getUpperBound(), adaptor.getStep(), adaptor.getInitArgs()); |
| 1179 | + adaptor.getUpperBound(), adaptor.getStep(), initArgs); |
1116 | 1180 |
|
1117 | 1181 | if (failed(moveRegion(op.getRegion(), jeffFor.getRegion(), rewriter, |
1118 | | - getTypeConverter()))) { |
| 1182 | + getTypeConverter(), aboveValues))) { |
1119 | 1183 | return failure(); |
1120 | 1184 | } |
1121 | 1185 |
|
1122 | | - rewriter.replaceOp(op, jeffFor.getResults()); |
1123 | | - return success(); |
1124 | | - } |
1125 | | -}; |
1126 | | - |
1127 | | -struct ConvertSCFYieldOpToJeff final |
1128 | | - : StatefulOpConversionPattern<scf::YieldOp> { |
1129 | | - using StatefulOpConversionPattern::StatefulOpConversionPattern; |
| 1186 | + rewriter.replaceOp(op, jeffFor.getResults().take_front(op.getNumResults())); |
1130 | 1187 |
|
1131 | | - LogicalResult |
1132 | | - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, |
1133 | | - ConversionPatternRewriter& rewriter) const override { |
1134 | | - rewriter.replaceOpWithNewOp<jeff::YieldOp>(op, adaptor.getResults()); |
1135 | 1188 | return success(); |
1136 | 1189 | } |
1137 | 1190 | }; |
@@ -1414,11 +1467,11 @@ struct QCOToJeff final : impl::QCOToJeffBase<QCOToJeff> { |
1414 | 1467 | addQCOToJeffGatePattern<JK::Custom, 2, 2, XXMinusYYOp, void, false>( |
1415 | 1468 | patterns, typeConverter, context, state, "xx_minus_yy"); |
1416 | 1469 |
|
1417 | | - patterns.add<ConvertQCOBarrierOpToJeff, ConvertQCOCtrlOpToJeff, |
1418 | | - ConvertQCOInvOpToJeff, ConvertQCOYieldOpToJeff, |
1419 | | - ConvertQCOIfOpToJeff, ConvertSCFForOpToJeff, |
1420 | | - ConvertSCFYieldOpToJeff, ConvertQCOMainToJeff>( |
1421 | | - typeConverter, context, &state); |
| 1470 | + patterns |
| 1471 | + .add<ConvertQCOBarrierOpToJeff, ConvertQCOCtrlOpToJeff, |
| 1472 | + ConvertQCOInvOpToJeff, ConvertQCOYieldOpToJeff, |
| 1473 | + ConvertQCOIfOpToJeff, ConvertSCFForOpToJeff, ConvertQCOMainToJeff>( |
| 1474 | + typeConverter, context, &state); |
1422 | 1475 |
|
1423 | 1476 | // Apply the conversion |
1424 | 1477 | if (applyPartialConversion(module, target, std::move(patterns)).failed()) { |
|
0 commit comments