|
24 | 24 | #include <mlir/Dialect/Arith/IR/Arith.h> |
25 | 25 | #include <mlir/Dialect/Func/IR/FuncOps.h> |
26 | 26 | #include <mlir/Dialect/Math/IR/Math.h> |
| 27 | +#include <mlir/Dialect/SCF/IR/SCF.h> |
27 | 28 | #include <mlir/Dialect/Tensor/IR/Tensor.h> |
28 | 29 | #include <mlir/Dialect/Utils/StaticValueUtils.h> |
29 | 30 | #include <mlir/IR/Builders.h> |
30 | 31 | #include <mlir/IR/BuiltinAttributes.h> |
31 | 32 | #include <mlir/IR/BuiltinOps.h> |
32 | 33 | #include <mlir/IR/MLIRContext.h> |
33 | 34 | #include <mlir/IR/PatternMatch.h> |
| 35 | +#include <mlir/IR/Region.h> |
34 | 36 | #include <mlir/IR/Types.h> |
35 | 37 | #include <mlir/IR/ValueRange.h> |
36 | 38 | #include <mlir/Support/LLVM.h> |
@@ -813,6 +815,165 @@ struct ConvertJeffPPROpToQCO final : OpConversionPattern<jeff::PPROp> { |
813 | 815 | } |
814 | 816 | }; |
815 | 817 |
|
| 818 | +/** |
| 819 | + * @brief Converts jeff.switch to qco.if |
| 820 | + * |
| 821 | + * @par Example: |
| 822 | + * ```mlir |
| 823 | + * %q_out = jeff.switch(%condition) : i1 -> (!jeff.qubit) |
| 824 | + * case 0 args(%a = %q_in) { |
| 825 | + * %jeff.yield %a : !jeff.qubit |
| 826 | + * } |
| 827 | + * case 1 args(%a = %q_in) { |
| 828 | + * %q_res = jeff.x {is_adjoint = false, num_ctrls = 0 : i8, power = 1 : i8} %a |
| 829 | + * : !jeff.qubit |
| 830 | + * jeff.yield %q_res : !jeff.qubit |
| 831 | + * } |
| 832 | + * default args(%a = %q_in) { |
| 833 | + * jeff.yield %a : !jeff.qubit |
| 834 | + * } |
| 835 | + * ``` |
| 836 | + * is converted to |
| 837 | + * ```mlir |
| 838 | + * %q_out = qco.if %condition args(%a = %q_in) -> (!qco.qubit) { |
| 839 | + * %q_res = qco.x %a : !qco.qubit -> !qco.qubit |
| 840 | + * qco.yield %q_res : !qco.qubit |
| 841 | + * } else args(%a = %q_in) { |
| 842 | + * qco.yield %a : !qco.qubit |
| 843 | + * } |
| 844 | + * ``` |
| 845 | + */ |
| 846 | +struct ConvertJeffSwitchOpToQCO final : OpConversionPattern<jeff::SwitchOp> { |
| 847 | + using OpConversionPattern::OpConversionPattern; |
| 848 | + |
| 849 | + LogicalResult |
| 850 | + matchAndRewrite(jeff::SwitchOp op, OpAdaptor adaptor, |
| 851 | + ConversionPatternRewriter& rewriter) const override { |
| 852 | + if (!adaptor.getSelection().getType().isInteger(1)) { |
| 853 | + return rewriter.notifyMatchFailure(op, "qco.if requires an i1 selector"); |
| 854 | + } |
| 855 | + if (op.getDefault().front().getOperations().size() != 1) { |
| 856 | + return rewriter.notifyMatchFailure( |
| 857 | + op, "qco.if requires a trivial default branch"); |
| 858 | + } |
| 859 | + if (op.getBranches().size() != 2) { |
| 860 | + return rewriter.notifyMatchFailure( |
| 861 | + op, "qco.if requires exactly two branches"); |
| 862 | + } |
| 863 | + |
| 864 | + auto qcoIf = IfOp::create(rewriter, op.getLoc(), adaptor.getSelection(), |
| 865 | + adaptor.getInValues()); |
| 866 | + |
| 867 | + auto moveRegion = [&](Region& source, Region& dest) -> LogicalResult { |
| 868 | + rewriter.inlineRegionBefore(source, dest, dest.end()); |
| 869 | + Block* block = &dest.front(); |
| 870 | + TypeConverter::SignatureConversion sc(block->getNumArguments()); |
| 871 | + if (failed(getTypeConverter()->convertSignatureArgs( |
| 872 | + block->getArgumentTypes(), sc))) { |
| 873 | + return failure(); |
| 874 | + } |
| 875 | + rewriter.applySignatureConversion(block, sc); |
| 876 | + return success(); |
| 877 | + }; |
| 878 | + |
| 879 | + if (failed(moveRegion(op.getBranches()[0], qcoIf.getElseRegion()))) { |
| 880 | + return failure(); |
| 881 | + } |
| 882 | + if (failed(moveRegion(op.getBranches()[1], qcoIf.getThenRegion()))) { |
| 883 | + return failure(); |
| 884 | + } |
| 885 | + |
| 886 | + rewriter.replaceOp(op, qcoIf.getResults()); |
| 887 | + return success(); |
| 888 | + } |
| 889 | +}; |
| 890 | + |
| 891 | +/** |
| 892 | + * @brief Converts jeff.for to scf.for |
| 893 | + * |
| 894 | + * @par Example: |
| 895 | + * ```mlir |
| 896 | + * %reg_out = jeff.for %iv = %start to %stop step %step args(%a = %reg_in) -> |
| 897 | + * (!jeff.qureg<2>) : i32 { |
| 898 | + * %reg0, %q0 = jeff.qureg_extract_index(%iv) %a : (!jeff.qureg<2>, i32) -> |
| 899 | + * (!jeff.qureg<2>, !jeff.qubit) |
| 900 | + * %q1 = jeff.h {is_adjoint = false, num_ctrls = 0 : i8, power = 1 : i8} %q0 : |
| 901 | + * !jeff.qubit |
| 902 | + * %reg1 = jeff.qureg_insert_index(%iv) %reg0 %q1 : (!jeff.qureg<2>, i32, |
| 903 | + * !jeff.qubit) -> !jeff.qureg<2> |
| 904 | + * jeff.yield %reg1 : !jeff.qureg<2> |
| 905 | + * } |
| 906 | + * ``` |
| 907 | + * is converted to |
| 908 | + * ```mlir |
| 909 | + * %reg_out = scf.for %iv = %start to %stop step %step iter_args(%a = %reg_in) |
| 910 | + * -> (tensor<2x!qco.qubit>) { |
| 911 | + * %reg0, %q0 = qtensor.extract %a[%iv] : tensor<2x!qco.qubit> |
| 912 | + * %q1 = qco.h %q0 : !qco.qubit -> !qco.qubit |
| 913 | + * %reg1 = qtensor.insert %q1 into %reg0[%iv] : tensor<2x!qco.qubit> |
| 914 | + * scf.yield %reg1 : tensor<2x!qco.qubit> |
| 915 | + * } |
| 916 | + * ``` |
| 917 | + */ |
| 918 | +struct ConvertJeffForOpToQCO final : OpConversionPattern<jeff::ForOp> { |
| 919 | + using OpConversionPattern::OpConversionPattern; |
| 920 | + |
| 921 | + LogicalResult |
| 922 | + matchAndRewrite(jeff::ForOp op, OpAdaptor adaptor, |
| 923 | + ConversionPatternRewriter& rewriter) const override { |
| 924 | + auto loc = op.getLoc(); |
| 925 | + auto indexType = rewriter.getIndexType(); |
| 926 | + |
| 927 | + auto start = arith::IndexCastOp::create(rewriter, loc, indexType, |
| 928 | + adaptor.getStart()); |
| 929 | + auto stop = |
| 930 | + arith::IndexCastOp::create(rewriter, loc, indexType, adaptor.getStop()); |
| 931 | + auto step = |
| 932 | + arith::IndexCastOp::create(rewriter, loc, indexType, adaptor.getStep()); |
| 933 | + |
| 934 | + auto scfFor = scf::ForOp::create(rewriter, loc, start, stop, step, |
| 935 | + adaptor.getInValues()); |
| 936 | + |
| 937 | + Block* jeffBody = &op.getBody().front(); |
| 938 | + Block* scfBody = scfFor.getBody(); |
| 939 | + |
| 940 | + OpBuilder::InsertionGuard guard(rewriter); |
| 941 | + rewriter.setInsertionPointToStart(scfBody); |
| 942 | + |
| 943 | + auto iv = arith::IndexCastOp::create(rewriter, loc, |
| 944 | + jeffBody->getArgument(0).getType(), |
| 945 | + scfFor.getInductionVar()); |
| 946 | + SmallVector<Value> args = {iv.getResult()}; |
| 947 | + for (Value arg : scfFor.getRegionIterArgs()) { |
| 948 | + args.push_back(arg); |
| 949 | + } |
| 950 | + |
| 951 | + rewriter.mergeBlocks(jeffBody, scfBody, args); |
| 952 | + |
| 953 | + rewriter.replaceOp(op, scfFor.getResults()); |
| 954 | + return success(); |
| 955 | + } |
| 956 | +}; |
| 957 | + |
| 958 | +/** |
| 959 | + * @brief Converts jeff.yield to QCO |
| 960 | + */ |
| 961 | +struct ConvertJeffYieldOpToQCO final : OpConversionPattern<jeff::YieldOp> { |
| 962 | + using OpConversionPattern::OpConversionPattern; |
| 963 | + |
| 964 | + LogicalResult |
| 965 | + matchAndRewrite(jeff::YieldOp op, OpAdaptor adaptor, |
| 966 | + ConversionPatternRewriter& rewriter) const override { |
| 967 | + if (isa<IfOp>(op->getParentOp())) { |
| 968 | + rewriter.replaceOpWithNewOp<YieldOp>(op, adaptor.getOperands()); |
| 969 | + return success(); |
| 970 | + } |
| 971 | + |
| 972 | + rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands()); |
| 973 | + return success(); |
| 974 | + } |
| 975 | +}; |
| 976 | + |
816 | 977 | /** |
817 | 978 | * @brief Converts the jeff-style main function to a QCO-style main function |
818 | 979 | * |
@@ -914,7 +1075,7 @@ struct JeffToQCO final : impl::JeffToQCOBase<JeffToQCO> { |
914 | 1075 | target.addIllegalDialect<jeff::JeffDialect>(); |
915 | 1076 | target.addLegalDialect<QCODialect, qtensor::QTensorDialect, |
916 | 1077 | arith::ArithDialect, math::MathDialect, |
917 | | - tensor::TensorDialect>(); |
| 1078 | + tensor::TensorDialect, scf::SCFDialect>(); |
918 | 1079 |
|
919 | 1080 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
920 | 1081 | return !(op.getSymName() == getEntryPointName(module) && |
@@ -943,7 +1104,8 @@ struct JeffToQCO final : impl::JeffToQCOBase<JeffToQCO> { |
943 | 1104 | ConvertJeffOneTargetOneParameterToQCO<jeff::RzOp, RZOp>, |
944 | 1105 | ConvertJeffOneTargetOneParameterToQCO<jeff::R1Op, POp>, |
945 | 1106 | ConvertJeffUOpToQCO, ConvertJeffSwapOpToQCO, ConvertJeffCustomOpToQCO, |
946 | | - ConvertJeffPPROpToQCO, ConvertJeffMainToQCO>(typeConverter, context); |
| 1107 | + ConvertJeffPPROpToQCO, ConvertJeffSwitchOpToQCO, ConvertJeffForOpToQCO, |
| 1108 | + ConvertJeffYieldOpToQCO, ConvertJeffMainToQCO>(typeConverter, context); |
947 | 1109 |
|
948 | 1110 | // Apply the conversion |
949 | 1111 | if (applyPartialConversion(module, target, std::move(patterns)).failed()) { |
|
0 commit comments