diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 02a57fa..69e7210 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -23,14 +23,6 @@ jobs: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Install capnp - run: | - curl -O https://capnproto.org/capnproto-c++-1.3.0.tar.gz - tar zxf capnproto-c++-1.3.0.tar.gz - cd capnproto-c++-1.3.0 - ./configure - make -j6 check - sudo make install - name: Set up Clang ${{ env.CLANG_VERSION }} env: TEMP_DIR: ${{ runner.temp }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ee2ff7..fcebec6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,14 +19,6 @@ jobs: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Install capnp - run: | - curl -O https://capnproto.org/capnproto-c++-1.3.0.tar.gz - tar zxf capnproto-c++-1.3.0.tar.gz - cd capnproto-c++-1.3.0 - ./configure - make -j6 check - sudo make install - name: Set up MLIR uses: munich-quantum-software/setup-mlir@97da765dfa9dc8e055611ca934fe51bcade8140c # v1.3.0 with: diff --git a/cmake/ExternalDependencies.cmake b/cmake/ExternalDependencies.cmake index 52e7e35..ba00ade 100644 --- a/cmake/ExternalDependencies.cmake +++ b/cmake/ExternalDependencies.cmake @@ -1,4 +1,5 @@ include(FetchContent) + set(FETCH_PACKAGES "") if(BUILD_JEFF_MLIR_TRANSLATION) @@ -9,7 +10,23 @@ if(BUILD_JEFF_MLIR_TRANSLATION) ) list(APPEND FETCH_PACKAGES jeff) - find_package(CapnProto REQUIRED) + if(WIN32) + set(WITH_FIBERS + OFF + CACHE + BOOL + "Disable fiber support on Windows to avoid a build error due to disabled exceptions" + FORCE + ) + endif() + FetchContent_Declare( + capnproto + GIT_REPOSITORY https://github.com/capnproto/capnproto.git + GIT_TAG v1.3.0 + PATCH_COMMAND ${CMAKE_COMMAND} -E chdir patch --forward -p1 -i + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/capnproto-disable-tests.patch + ) + list(APPEND FETCH_PACKAGES capnproto) endif() if(BUILD_JEFF_MLIR_TESTS) diff --git a/cmake/patches/capnproto-disable-tests.patch b/cmake/patches/capnproto-disable-tests.patch new file mode 100644 index 0000000..5cea8cb --- /dev/null +++ b/cmake/patches/capnproto-disable-tests.patch @@ -0,0 +1,11 @@ +diff --git a/c++/src/CMakeLists.txt b/c++/src/CMakeLists.txt +--- a/c++/src/CMakeLists.txt ++++ b/c++/src/CMakeLists.txt +@@ -1,5 +1,8 @@ + ++# Disable vendored Cap'n Proto tests when consumed through FetchContent. ++set(BUILD_TESTING OFF) ++ + # Tests ======================================================================== + + if(BUILD_TESTING) diff --git a/include/jeff/Translation/Deserialize.hpp b/include/jeff/Translation/Deserialize.hpp index 92659ae..aa9b4fa 100644 --- a/include/jeff/Translation/Deserialize.hpp +++ b/include/jeff/Translation/Deserialize.hpp @@ -1,14 +1,24 @@ #pragma once -#include +#include #include #include #include +/** + * @brief Deserialize a memory buffer containing a serialized .jeff module into an MLIR module. + * @param context The MLIR context to use for the deserialization. + * @param buffer A memory buffer containing the serialized jeff module. + * @return An owning reference to the deserialized MLIR module. + */ +mlir::OwningOpRef deserialize(mlir::MLIRContext* context, + kj::ArrayPtr buffer); + /** * @brief Deserialize a .jeff file into an MLIR module. * @param context The MLIR context to use for the deserialization. * @param path The path to the .jeff file. * @return An owning reference to the deserialized MLIR module. */ -mlir::OwningOpRef deserialize(mlir::MLIRContext* context, llvm::StringRef path); +mlir::OwningOpRef deserializeFromFile(mlir::MLIRContext* context, + llvm::StringRef path); diff --git a/include/jeff/Translation/Serialize.hpp b/include/jeff/Translation/Serialize.hpp index 51ade60..c9fd014 100644 --- a/include/jeff/Translation/Serialize.hpp +++ b/include/jeff/Translation/Serialize.hpp @@ -1,10 +1,23 @@ #pragma once -#include +#include +#include #include /** - * @brief Serialize an MLIR module into a .jeff file. + * @brief Serialize an MLIR module containing a jeff program into a memory buffer. + * @param module The MLIR module to serialize. + * @return An owned memory buffer containing the serialized jeff module. + * + * @details + * Known limitations: + * + * - Only one-dimensional tensors with dynamic size are supported. + */ +kj::Array serialize(mlir::ModuleOp module); + +/** + * @brief Serialize an MLIR module containing a jeff program into a .jeff file. * @param module The MLIR module to serialize. * @param path The path to the .jeff file. * @@ -13,4 +26,4 @@ * * - Only one-dimensional tensors with dynamic size are supported. */ -void serialize(mlir::ModuleOp module, llvm::StringRef path); +void serializeToFile(mlir::ModuleOp module, llvm::StringRef path); diff --git a/lib/Translation/CMakeLists.txt b/lib/Translation/CMakeLists.txt index 8663e6a..6609d84 100644 --- a/lib/Translation/CMakeLists.txt +++ b/lib/Translation/CMakeLists.txt @@ -3,18 +3,17 @@ add_mlir_library( Deserialize.cpp Serialize.cpp LINK_LIBS + PRIVATE MLIRJeff MLIRFuncDialect + PUBLIC CapnProto::capnp DISABLE_INSTALL ) -target_compile_definitions(MLIRJeffTranslation PRIVATE ${CAPNP_DEFINITIONS}) - target_include_directories( - MLIRJeffTranslation - PRIVATE ${CAPNP_INCLUDE_DIRS} - PUBLIC ${jeff_SOURCE_DIR}/impl/cpp/src/capnp ${jeff_BINARY_DIR}/impl/cpp/src/capnp + MLIRJeffTranslation PUBLIC ${jeff_SOURCE_DIR}/impl/cpp/src/capnp + ${jeff_BINARY_DIR}/impl/cpp/src/capnp ) target_sources(MLIRJeffTranslation PUBLIC ${jeff_SOURCE_DIR}/impl/cpp/src/capnp/jeff.capnp.c++) diff --git a/lib/Translation/Deserialize.cpp b/lib/Translation/Deserialize.cpp index dca8d25..9effe70 100644 --- a/lib/Translation/Deserialize.cpp +++ b/lib/Translation/Deserialize.cpp @@ -5,16 +5,15 @@ #include #include -#include #include #include -#include +#include #include #include #include #include #include -#include +#include #include #include #include @@ -28,11 +27,11 @@ #include #include +#include #include #include #include #include -#include namespace { @@ -112,29 +111,29 @@ struct DeserializationContext { // Qubit operations //===----------------------------------------------------------------------===// -void deserializeQubitAlloc(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQubitAlloc(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { auto allocOp = mlir::jeff::QubitAllocOp::create(builder); ctx.setValue(operation.getOutputs()[0], allocOp.getResult()); } -void deserializeQubitFree(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQubitFree(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { mlir::jeff::QubitFreeOp::create(builder, ctx.getValue(operation.getInputs()[0])); } -void deserializeQubitFreeZero(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeQubitFreeZero(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { mlir::jeff::QubitFreeZeroOp::create(builder, ctx.getValue(operation.getInputs()[0])); } -void deserializeMeasure(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeMeasure(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { auto op = mlir::jeff::QubitMeasureOp::create(builder, ctx.getValue(operation.getInputs()[0])); ctx.setValue(operation.getOutputs()[0], op.getResult()); } -void deserializeMeasureNd(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeMeasureNd(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -143,7 +142,7 @@ void deserializeMeasureNd(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ctx.setValue(outputs[1], op.getResult()); } -void deserializeReset(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeReset(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -153,7 +152,8 @@ void deserializeReset(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper template void deserializeOneTargetZeroParameter(mlir::ImplicitLocOpBuilder& builder, - jeff::Op::Reader operation, DeserializationContext& ctx) { + const jeff::Op::Reader& operation, + DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); const auto gate = operation.getInstruction().getQubit().getGate(); @@ -173,7 +173,8 @@ void deserializeOneTargetZeroParameter(mlir::ImplicitLocOpBuilder& builder, template void deserializeOneTargetOneParameter(mlir::ImplicitLocOpBuilder& builder, - jeff::Op::Reader operation, DeserializationContext& ctx) { + const jeff::Op::Reader& operation, + DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); const auto gate = operation.getInstruction().getQubit().getGate(); @@ -192,7 +193,7 @@ void deserializeOneTargetOneParameter(mlir::ImplicitLocOpBuilder& builder, } } -void deserializeU(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeU(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -214,7 +215,7 @@ void deserializeU(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operatio } } -void deserializeSwap(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeSwap(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -234,7 +235,7 @@ void deserializeSwap(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader opera } } -void deserializeGPhase(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeGPhase(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -253,10 +254,9 @@ void deserializeGPhase(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ope } } -void deserializeWellKnown(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeWellKnown(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto wellKnown = operation.getInstruction().getQubit().getGate().getWellKnown(); - switch (wellKnown) { + switch (const auto wellKnown = operation.getInstruction().getQubit().getGate().getWellKnown()) { case jeff::WellKnownGate::X: deserializeOneTargetZeroParameter(builder, operation, ctx); break; @@ -306,7 +306,7 @@ void deserializeWellKnown(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader } } -void deserializeCustom(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeCustom(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -340,7 +340,7 @@ void deserializeCustom(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ope } } -void deserializePpr(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializePpr(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -374,10 +374,9 @@ void deserializePpr(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operat } } -void deserializeGate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeGate(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto gate = operation.getInstruction().getQubit().getGate(); - switch (gate.which()) { + switch (const auto gate = operation.getInstruction().getQubit().getGate(); gate.which()) { case jeff::QubitGate::WELL_KNOWN: deserializeWellKnown(builder, operation, ctx); break; @@ -388,16 +387,14 @@ void deserializeGate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader opera deserializePpr(builder, operation, ctx); break; default: - llvm::errs() << "Cannot deserialize gate instruction " << static_cast(gate.which()) - << "\n"; + llvm::errs() << "Cannot deserialize gate instruction " << gate.which() << "\n"; llvm::report_fatal_error("Unknown gate instruction"); } } -void deserializeQubit(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQubit(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto qubit = operation.getInstruction().getQubit(); - switch (qubit.which()) { + switch (const auto qubit = operation.getInstruction().getQubit(); qubit.which()) { case jeff::QubitOp::ALLOC: deserializeQubitAlloc(builder, operation, ctx); break; @@ -420,8 +417,7 @@ void deserializeQubit(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper deserializeGate(builder, operation, ctx); break; default: - llvm::errs() << "Cannot deserialize qubit instruction " << static_cast(qubit.which()) - << "\n"; + llvm::errs() << "Cannot deserialize qubit instruction " << qubit.which() << "\n"; llvm::report_fatal_error("Unknown qubit instruction"); } } @@ -430,7 +426,7 @@ void deserializeQubit(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper // Qureg operations //===----------------------------------------------------------------------===// -void deserializeQuregAlloc(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQuregAlloc(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -439,13 +435,13 @@ void deserializeQuregAlloc(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ctx.setValue(outputs[0], allocOp.getResult()); } -void deserializeQuregFreeZero(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeQuregFreeZero(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { mlir::jeff::QuregFreeZeroOp::create(builder, ctx.getValue(operation.getInputs()[0])); } -void deserializeQuregExtractIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeQuregExtractIndex(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto op = mlir::jeff::QuregExtractIndexOp::create(builder, ctx.getValue(inputs[0]), @@ -454,8 +450,8 @@ void deserializeQuregExtractIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op: ctx.setValue(outputs[1], op.getOutQubit()); } -void deserializeQuregInsertIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeQuregInsertIndex(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto op = mlir::jeff::QuregInsertIndexOp::create( @@ -463,8 +459,8 @@ void deserializeQuregInsertIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op:: ctx.setValue(outputs[0], op.getOutQreg()); } -void deserializeQuregExtractSlice(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeQuregExtractSlice(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto outQregType = mlir::jeff::QuregType::get(builder.getContext(), ctx.getLength(outputs[0])); @@ -476,8 +472,8 @@ void deserializeQuregExtractSlice(mlir::ImplicitLocOpBuilder& builder, jeff::Op: ctx.setValue(outputs[1], op.getNewQreg()); } -void deserializeQuregInsertSlice(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeQuregInsertSlice(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto op = mlir::jeff::QuregInsertSliceOp::create( @@ -485,7 +481,7 @@ void deserializeQuregInsertSlice(mlir::ImplicitLocOpBuilder& builder, jeff::Op:: ctx.setValue(outputs[0], op.getOutQreg()); } -void deserializeQuregLength(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQuregLength(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -494,7 +490,7 @@ void deserializeQuregLength(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reade ctx.setValue(outputs[1], op.getLength()); } -void deserializeQuregSplit(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQuregSplit(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -506,7 +502,7 @@ void deserializeQuregSplit(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ctx.setValue(outputs[1], op.getOutQregTwo()); } -void deserializeQuregJoin(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQuregJoin(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -516,7 +512,7 @@ void deserializeQuregJoin(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ctx.setValue(outputs[0], op.getOutQreg()); } -void deserializeQuregCreate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQuregCreate(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -530,15 +526,14 @@ void deserializeQuregCreate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reade ctx.setValue(outputs[0], op.getOutQreg()); } -void deserializeQuregFree(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQuregFree(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { mlir::jeff::QuregFreeOp::create(builder, ctx.getValue(operation.getInputs()[0])); } -void deserializeQureg(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeQureg(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto qureg = operation.getInstruction().getQureg(); - switch (qureg.which()) { + switch (const auto qureg = operation.getInstruction().getQureg(); qureg.which()) { case jeff::QuregOp::ALLOC: deserializeQuregAlloc(builder, operation, ctx); break; @@ -573,8 +568,7 @@ void deserializeQureg(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper deserializeQuregFree(builder, operation, ctx); break; default: - llvm::errs() << "Cannot deserialize qureg instruction " << static_cast(qureg.which()) - << "\n"; + llvm::errs() << "Cannot deserialize qureg instruction " << qureg.which() << "\n"; llvm::report_fatal_error("Unknown qureg instruction"); } } @@ -585,7 +579,8 @@ void deserializeQureg(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper #define DESERIALIZE_INT_CONST(BIT_WIDTH) \ void deserializeIntConst##BIT_WIDTH(mlir::ImplicitLocOpBuilder& builder, \ - jeff::Op::Reader operation, DeserializationContext& ctx) { \ + const jeff::Op::Reader& operation, \ + DeserializationContext& ctx) { \ const auto value = operation.getInstruction().getInt().getConst##BIT_WIDTH(); \ auto intType = builder.getI##BIT_WIDTH##Type(); \ auto intAttr = mlir::IntegerAttr::get(intType, value); \ @@ -601,7 +596,7 @@ DESERIALIZE_INT_CONST(64) #undef DESERIALIZE_INT_CONST -void deserializeIntUnaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeIntUnaryOp(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, mlir::jeff::IntUnaryOperation unaryOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); @@ -610,7 +605,7 @@ void deserializeIntUnaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ctx.setValue(outputs[0], op.getB()); } -void deserializeIntBinaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeIntBinaryOp(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, mlir::jeff::IntBinaryOperation binaryOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); @@ -620,7 +615,8 @@ void deserializeIntBinaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reade ctx.setValue(outputs[0], op.getC()); } -void deserializeIntComparisonOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeIntComparisonOp(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, mlir::jeff::IntComparisonOperation comparisonOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); @@ -653,10 +649,9 @@ void deserializeIntComparisonOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::R mlir::jeff::IntComparisonOperation::_##MLIR_ENUM_SUFFIX, ctx); \ break; -void deserializeInt(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeInt(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto intInstr = operation.getInstruction().getInt(); - switch (intInstr.which()) { + switch (const auto intInstr = operation.getInstruction().getInt(); intInstr.which()) { ADD_CONST_CASE(1) ADD_CONST_CASE(8) ADD_CONST_CASE(16) @@ -687,8 +682,7 @@ void deserializeInt(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operat ADD_COMPARISON_CASE(LT_U, ltU) ADD_COMPARISON_CASE(LTE_U, lteU) default: - llvm::errs() << "Cannot deserialize int instruction " << static_cast(intInstr.which()) - << "\n"; + llvm::errs() << "Cannot deserialize int instruction " << intInstr.which() << "\n"; llvm::report_fatal_error("Unknown int instruction"); } } @@ -702,8 +696,8 @@ void deserializeInt(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operat // IntArray operations //===----------------------------------------------------------------------===// -void deserializeIntArrayConst1(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeIntArrayConst1(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto outputs = operation.getOutputs(); const auto values = operation.getInstruction().getIntArray().getConst1(); llvm::SmallVector inArray; @@ -719,7 +713,7 @@ void deserializeIntArrayConst1(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Re #define DESERIALIZE_INT_ARRAY_CONST(BIT_WIDTH) \ void deserializeIntArrayConst##BIT_WIDTH(mlir::ImplicitLocOpBuilder& builder, \ - jeff::Op::Reader operation, \ + const jeff::Op::Reader& operation, \ DeserializationContext& ctx) { \ const auto outputs = operation.getOutputs(); \ const auto values = operation.getInstruction().getIntArray().getConst##BIT_WIDTH(); \ @@ -743,7 +737,7 @@ DESERIALIZE_INT_ARRAY_CONST(64) #undef DESERIALIZE_INT_ARRAY_CONST -void deserializeIntArrayZero(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeIntArrayZero(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -754,8 +748,8 @@ void deserializeIntArrayZero(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Read ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeIntArrayGetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeIntArrayGetIndex(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto tensorType = ctx.getValue(inputs[0]).getType(); @@ -765,8 +759,8 @@ void deserializeIntArrayGetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op:: ctx.setValue(outputs[0], op.getValue()); } -void deserializeIntArraySetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeIntArraySetIndex(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto tensorType = ctx.getValue(inputs[0]).getType(); @@ -776,16 +770,16 @@ void deserializeIntArraySetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op:: ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeIntArrayLength(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeIntArrayLength(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto op = mlir::jeff::IntArrayLengthOp::create(builder, ctx.getValue(inputs[0])); ctx.setValue(outputs[0], op.getLength()); } -void deserializeIntArrayCreate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeIntArrayCreate(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); llvm::SmallVector inArray; @@ -799,10 +793,9 @@ void deserializeIntArrayCreate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Re ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeIntArray(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeIntArray(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto intArray = operation.getInstruction().getIntArray(); - switch (intArray.which()) { + switch (const auto intArray = operation.getInstruction().getIntArray(); intArray.which()) { case jeff::IntArrayOp::CONST1: deserializeIntArrayConst1(builder, operation, ctx); break; @@ -834,8 +827,7 @@ void deserializeIntArray(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader o deserializeIntArrayCreate(builder, operation, ctx); break; default: - llvm::errs() << "Cannot deserialize int array instruction " - << static_cast(intArray.which()) << "\n"; + llvm::errs() << "Cannot deserialize int array instruction " << intArray.which() << "\n"; llvm::report_fatal_error("Unknown int array instruction"); } } @@ -846,7 +838,7 @@ void deserializeIntArray(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader o #define DESERIALIZE_FLOAT_CONST(BIT_WIDTH) \ void deserializeFloatConst##BIT_WIDTH(mlir::ImplicitLocOpBuilder& builder, \ - jeff::Op::Reader operation, \ + const jeff::Op::Reader& operation, \ DeserializationContext& ctx) { \ const auto value = operation.getInstruction().getFloat().getConst##BIT_WIDTH(); \ auto floatType = builder.getF##BIT_WIDTH##Type(); \ @@ -860,7 +852,7 @@ DESERIALIZE_FLOAT_CONST(64) #undef DESERIALIZE_FLOAT_CONST -void deserializeFloatUnaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFloatUnaryOp(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, mlir::jeff::FloatUnaryOperation unaryOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); @@ -869,7 +861,8 @@ void deserializeFloatUnaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Read ctx.setValue(outputs[0], op.getB()); } -void deserializeFloatBinaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFloatBinaryOp(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, mlir::jeff::FloatBinaryOperation binaryOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); @@ -879,7 +872,8 @@ void deserializeFloatBinaryOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Rea ctx.setValue(outputs[0], op.getC()); } -void deserializeFloatComparisonOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFloatComparisonOp(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, mlir::jeff::FloatComparisonOperation comparisonOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); @@ -889,7 +883,7 @@ void deserializeFloatComparisonOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op: ctx.setValue(outputs[0], op.getC()); } -void deserializeFloatIsOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFloatIsOp(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, mlir::jeff::FloatIsOperation isOperation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -926,10 +920,9 @@ void deserializeFloatIsOp(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader mlir::jeff::FloatIsOperation::_is##MLIR_ENUM_SUFFIX, ctx); \ break; -void deserializeFloat(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFloat(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto floatInstr = operation.getInstruction().getFloat(); - switch (floatInstr.which()) { + switch (const auto floatInstr = operation.getInstruction().getFloat(); floatInstr.which()) { ADD_CONST_CASE(32) ADD_CONST_CASE(64) ADD_UNARY_CASE(SQRT, sqrt) @@ -963,8 +956,7 @@ void deserializeFloat(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper ADD_IS_CASE(IS_NAN, Nan) ADD_IS_CASE(IS_INF, Inf) default: - llvm::errs() << "Cannot deserialize float instruction " - << static_cast(floatInstr.which()) << "\n"; + llvm::errs() << "Cannot deserialize float instruction " << floatInstr.which() << "\n"; llvm::report_fatal_error("Unknown float instruction"); } } @@ -979,8 +971,8 @@ void deserializeFloat(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper // FloatArray operations //===----------------------------------------------------------------------===// -void deserializeFloatArrayConst32(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArrayConst32(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto outputs = operation.getOutputs(); const auto values = operation.getInstruction().getFloatArray().getConst32(); llvm::SmallVector inArray; @@ -995,8 +987,8 @@ void deserializeFloatArrayConst32(mlir::ImplicitLocOpBuilder& builder, jeff::Op: ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeFloatArrayConst64(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArrayConst64(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto outputs = operation.getOutputs(); const auto values = operation.getInstruction().getFloatArray().getConst64(); llvm::SmallVector inArray; @@ -1011,8 +1003,8 @@ void deserializeFloatArrayConst64(mlir::ImplicitLocOpBuilder& builder, jeff::Op: ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeFloatArrayZero(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArrayZero(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); const auto zero = operation.getInstruction().getFloatArray().getZero(); @@ -1032,8 +1024,8 @@ void deserializeFloatArrayZero(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Re ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeFloatArrayGetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArrayGetIndex(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto tensorType = ctx.getValue(inputs[0]).getType(); @@ -1043,8 +1035,8 @@ void deserializeFloatArrayGetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op ctx.setValue(outputs[0], op.getValue()); } -void deserializeFloatArraySetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArraySetIndex(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto tensorType = ctx.getValue(inputs[0]).getType(); @@ -1054,16 +1046,16 @@ void deserializeFloatArraySetIndex(mlir::ImplicitLocOpBuilder& builder, jeff::Op ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeFloatArrayLength(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArrayLength(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); auto op = mlir::jeff::FloatArrayLengthOp::create(builder, ctx.getValue(inputs[0])); ctx.setValue(outputs[0], op.getLength()); } -void deserializeFloatArrayCreate(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, - DeserializationContext& ctx) { +void deserializeFloatArrayCreate(mlir::ImplicitLocOpBuilder& builder, + const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); llvm::SmallVector inArray; @@ -1077,10 +1069,10 @@ void deserializeFloatArrayCreate(mlir::ImplicitLocOpBuilder& builder, jeff::Op:: ctx.setValue(outputs[0], op.getOutArray()); } -void deserializeFloatArray(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFloatArray(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto floatArray = operation.getInstruction().getFloatArray(); - switch (floatArray.which()) { + switch (const auto floatArray = operation.getInstruction().getFloatArray(); + floatArray.which()) { case jeff::FloatArrayOp::CONST32: deserializeFloatArrayConst32(builder, operation, ctx); break; @@ -1103,8 +1095,7 @@ void deserializeFloatArray(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader deserializeFloatArrayCreate(builder, operation, ctx); break; default: - llvm::errs() << "Cannot deserialize float array instruction " - << static_cast(floatArray.which()) << "\n"; + llvm::errs() << "Cannot deserialize float array instruction " << floatArray.which() << "\n"; llvm::report_fatal_error("Unknown float array instruction"); } } @@ -1115,10 +1106,10 @@ void deserializeFloatArray(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader // Forward declaration void deserializeOperations(mlir::ImplicitLocOpBuilder& builder, - capnp::List::Reader operations, + const capnp::List::Reader& operations, DeserializationContext& ctx); -void deserializeSwitch(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeSwitch(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { auto loc = builder.getUnknownLoc(); const auto inputs = operation.getInputs(); @@ -1194,7 +1185,7 @@ void deserializeSwitch(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader ope } } -void deserializeFor(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFor(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { auto loc = builder.getUnknownLoc(); const auto inputs = operation.getInputs(); @@ -1247,7 +1238,7 @@ void deserializeFor(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operat } template -void deserializeWhile(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeWhile(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, JEFF_WHILE_OP_READER_TYPE reader, DeserializationContext& ctx) { auto loc = builder.getUnknownLoc(); const auto inputs = operation.getInputs(); @@ -1315,10 +1306,9 @@ void deserializeWhile(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader oper } } -void deserializeScf(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeScf(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { - const auto scf = operation.getInstruction().getScf(); - switch (scf.which()) { + switch (const auto scf = operation.getInstruction().getScf(); scf.which()) { case jeff::ScfOp::SWITCH: deserializeSwitch(builder, operation, ctx); break; @@ -1334,8 +1324,7 @@ void deserializeScf(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operat builder, operation, scf.getDoWhile(), ctx); break; default: - llvm::errs() << "Cannot deserialize scf instruction " << static_cast(scf.which()) - << "\n"; + llvm::errs() << "Cannot deserialize scf instruction " << scf.which() << "\n"; llvm::report_fatal_error("Unknown scf instruction"); } } @@ -1344,7 +1333,7 @@ void deserializeScf(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operat // Func operations //===----------------------------------------------------------------------===// -void deserializeFunc(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader operation, +void deserializeFunc(mlir::ImplicitLocOpBuilder& builder, const jeff::Op::Reader& operation, DeserializationContext& ctx) { const auto inputs = operation.getInputs(); const auto outputs = operation.getOutputs(); @@ -1365,7 +1354,8 @@ void deserializeFunc(mlir::ImplicitLocOpBuilder& builder, jeff::Op::Reader opera // Types //===----------------------------------------------------------------------===// -mlir::Type deserializeQuregType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Reader type) { +mlir::Type deserializeQuregType(const mlir::ImplicitLocOpBuilder& builder, + const jeff::Type::Reader& type) { const auto quregType = type.getQureg(); auto length = mlir::ShapedType::kDynamic; if (quregType.isStatic()) { @@ -1374,7 +1364,7 @@ mlir::Type deserializeQuregType(mlir::ImplicitLocOpBuilder& builder, jeff::Type: return mlir::jeff::QuregType::get(builder.getContext(), length); } -mlir::Type deserializeIntType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Reader type) { +mlir::Type deserializeIntType(mlir::ImplicitLocOpBuilder& builder, const jeff::Type::Reader& type) { switch (type.getInt()) { case 1: return builder.getI1Type(); @@ -1392,7 +1382,8 @@ mlir::Type deserializeIntType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::R } } -mlir::Type deserializeIntArrayType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Reader type) { +mlir::Type deserializeIntArrayType(mlir::ImplicitLocOpBuilder& builder, + const jeff::Type::Reader& type) { const auto intArrayType = type.getIntArray(); auto length = mlir::ShapedType::kDynamic; if (intArrayType.getLength().isStatic()) { @@ -1416,7 +1407,8 @@ mlir::Type deserializeIntArrayType(mlir::ImplicitLocOpBuilder& builder, jeff::Ty } } -mlir::FloatType deserializeFloatType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Reader type) { +mlir::FloatType deserializeFloatType(mlir::ImplicitLocOpBuilder& builder, + const jeff::Type::Reader& type) { switch (type.getFloat()) { case jeff::FloatPrecision::FLOAT32: return builder.getF32Type(); @@ -1429,7 +1421,8 @@ mlir::FloatType deserializeFloatType(mlir::ImplicitLocOpBuilder& builder, jeff:: } } -mlir::Type deserializeFloatArrayType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Reader type) { +mlir::Type deserializeFloatArrayType(mlir::ImplicitLocOpBuilder& builder, + const jeff::Type::Reader& type) { const auto floatArrayType = type.getFloatArray(); auto length = mlir::ShapedType::kDynamic; if (floatArrayType.getLength().isStatic()) { @@ -1447,7 +1440,7 @@ mlir::Type deserializeFloatArrayType(mlir::ImplicitLocOpBuilder& builder, jeff:: } } -mlir::Type deserializeType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Reader type) { +mlir::Type deserializeType(mlir::ImplicitLocOpBuilder& builder, const jeff::Type::Reader& type) { switch (type.which()) { case jeff::Type::QUBIT: return mlir::jeff::QubitType::get(builder.getContext()); @@ -1462,17 +1455,16 @@ mlir::Type deserializeType(mlir::ImplicitLocOpBuilder& builder, jeff::Type::Read case jeff::Type::FLOAT_ARRAY: return deserializeFloatArrayType(builder, type); default: - llvm::errs() << "Cannot deserialize type " << static_cast(type.which()) << "\n"; + llvm::errs() << "Cannot deserialize type " << type.which() << "\n"; llvm::report_fatal_error("Unknown type"); } } -void deserializeOperations(mlir::ImplicitLocOpBuilder& builder, - capnp::List::Reader operations, - DeserializationContext& ctx) { +auto deserializeOperations(mlir::ImplicitLocOpBuilder& builder, + const capnp::List::Reader& operations, + DeserializationContext& ctx) -> void { for (auto operation : operations) { - const auto instruction = operation.getInstruction(); - switch (instruction.which()) { + switch (const auto instruction = operation.getInstruction(); instruction.which()) { case jeff::Op::Instruction::QUBIT: deserializeQubit(builder, operation, ctx); break; @@ -1498,15 +1490,15 @@ void deserializeOperations(mlir::ImplicitLocOpBuilder& builder, deserializeFunc(builder, operation, ctx); break; default: - llvm::errs() << "Cannot deserialize instruction " - << static_cast(instruction.which()) << "\n"; + llvm::errs() << "Cannot deserialize instruction " << instruction.which() << "\n"; llvm::report_fatal_error("Unknown instruction"); } } } -void deserializeFunction(mlir::ImplicitLocOpBuilder& builder, jeff::Function::Reader function, - uint16_t functionId, DeserializationContext& ctx) { +void deserializeFunction(mlir::ImplicitLocOpBuilder& builder, + const jeff::Function::Reader& function, uint16_t functionId, + DeserializationContext& ctx) { ctx.values.clear(); // Get function definition @@ -1577,29 +1569,10 @@ void deserializeFunction(mlir::ImplicitLocOpBuilder& builder, jeff::Function::Re mlir::func::ReturnOp::create(builder, results); } -} // namespace - -mlir::OwningOpRef deserialize(mlir::MLIRContext* context, llvm::StringRef path) { +mlir::OwningOpRef deserialize(mlir::MLIRContext* context, + const jeff::Module::Reader& jeffModule) { DeserializationContext ctx; - // Get jeff module from file - llvm::sys::fs::file_t file = 0; - if (llvm::sys::fs::openFileForRead(path, file)) { - llvm::report_fatal_error("Could not open file"); - } - -#ifdef _WIN32 - kj::AutoCloseHandle autoCloseHandle(file); - kj::HandleInputStream input(std::move(autoCloseHandle)); -#else - kj::AutoCloseFd autoCloseFd(file); - kj::FdInputStream input(std::move(autoCloseFd)); -#endif - - capnp::MallocMessageBuilder message; - capnp::readMessageCopy(input, message); - jeff::Module::Reader jeffModule = message.getRoot(); - // Create MLIR builder mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(context), context); @@ -1663,3 +1636,34 @@ mlir::OwningOpRef deserialize(mlir::MLIRContext* context, llvm:: return mlirModule; } +} // namespace + +mlir::OwningOpRef deserialize(mlir::MLIRContext* context, + kj::ArrayPtr buffer) { + DeserializationContext ctx; + + capnp::FlatArrayMessageReader message(buffer); + return deserialize(context, message.getRoot()); +} + +mlir::OwningOpRef deserializeFromFile(mlir::MLIRContext* context, + llvm::StringRef path) { + auto file = llvm::MemoryBuffer::getFile(path); + if (!file) { + llvm::errs() << "Failed to open file: " << path << "\n"; + llvm::report_fatal_error("Could not open file"); + } + + // Get jeff module from buffer + const auto bytes = (*file)->getBuffer(); + assert(bytes.size() % sizeof(capnp::word) == 0 && + "Serialized module size must be a multiple of capnp::word size"); + assert(reinterpret_cast(bytes.data()) % alignof(capnp::word) == 0 && + "Serialized module buffer must be aligned to capnp::word alignment"); + const auto words = kj::ArrayPtr(reinterpret_cast(bytes.data()), + bytes.size() / sizeof(capnp::word)); + + capnp::FlatArrayMessageReader message(words); + const jeff::Module::Reader jeffModule = message.getRoot(); + return deserialize(context, jeffModule); +} diff --git a/lib/Translation/Serialize.cpp b/lib/Translation/Serialize.cpp index 87b4ad5..a8a6194 100644 --- a/lib/Translation/Serialize.cpp +++ b/lib/Translation/Serialize.cpp @@ -4,10 +4,12 @@ #include "jeff/IR/JeffInterfaces.h" #include "jeff/IR/JeffOps.h" +#include #include #include #include #include +#include #include #include #include @@ -28,7 +30,6 @@ #include #include -#include static void checkRank(mlir::RankedTensorType tensorType) { if (tensorType.getRank() != 1) { @@ -1254,8 +1255,8 @@ void serializeSwitch(jeff::Op::Builder builder, mlir::jeff::SwitchOp op, auto yieldOp = llvm::cast(block.back()); const auto numTargets = yieldOp.getNumOperands(); auto targets = branchBuilder.initTargets(numTargets); - for (size_t j = 0; j < numTargets; ++j) { - targets.set(j, ctx.getValueId(yieldOp.getOperand(j))); + for (size_t t = 0; t < numTargets; ++t) { + targets.set(t, ctx.getValueId(yieldOp.getOperand(t))); } } @@ -1283,8 +1284,8 @@ void serializeSwitch(jeff::Op::Builder builder, mlir::jeff::SwitchOp op, auto yieldOp = llvm::cast(block.back()); const auto numTargets = yieldOp.getNumOperands(); auto targets = defaultBuilder.initTargets(numTargets); - for (size_t i = 0; i < numTargets; ++i) { - targets.set(i, ctx.getValueId(yieldOp.getOperand(i))); + for (size_t t = 0; t < numTargets; ++t) { + targets.set(t, ctx.getValueId(yieldOp.getOperand(t))); } } } @@ -1326,8 +1327,8 @@ void serializeFor(jeff::Op::Builder builder, mlir::jeff::ForOp op, Serialization auto yieldOp = llvm::cast(block.back()); const auto numTargets = yieldOp.getNumOperands(); auto targets = forBuilder.initTargets(numTargets); - for (size_t i = 0; i < numTargets; ++i) { - targets.set(i, ctx.getValueId(yieldOp.getOperand(i))); + for (size_t t = 0; t < numTargets; ++t) { + targets.set(t, ctx.getValueId(yieldOp.getOperand(t))); } } @@ -1396,8 +1397,8 @@ void serializeWhile(jeff::Op::Builder builder, mlir::jeff::WhileOp op, Serializa auto yieldOp = llvm::cast(body.front().back()); const auto numTargets = yieldOp.getNumOperands(); auto targets = bodyBuilder.initTargets(numTargets); - for (size_t i = 0; i < numTargets; ++i) { - targets.set(i, ctx.getValueId(yieldOp.getOperand(i))); + for (size_t t = 0; t < numTargets; ++t) { + targets.set(t, ctx.getValueId(yieldOp.getOperand(t))); } } } @@ -1468,8 +1469,8 @@ void serializeDoWhile(jeff::Op::Builder builder, mlir::jeff::DoWhileOp op, auto yieldOp = llvm::cast(body.front().back()); const auto numTargets = yieldOp.getNumOperands(); auto targets = bodyBuilder.initTargets(numTargets); - for (size_t i = 0; i < numTargets; ++i) { - targets.set(i, ctx.getValueId(yieldOp.getOperand(i))); + for (size_t t = 0; t < numTargets; ++t) { + targets.set(t, ctx.getValueId(yieldOp.getOperand(t))); } } } @@ -1646,8 +1647,8 @@ void serializeFunction(jeff::Function::Builder functionBuilder, mlir::func::Func auto returnOp = llvm::cast(entryBlock.back()); const auto numTargets = returnOp.getNumOperands(); auto targetsBuilder = bodyBuilder.initTargets(numTargets); - for (unsigned i = 0; i < numTargets; ++i) { - targetsBuilder.set(i, ctx.getValueId(returnOp.getOperand(i))); + for (unsigned t = 0; t < numTargets; ++t) { + targetsBuilder.set(t, ctx.getValueId(returnOp.getOperand(t))); } // Build values @@ -1657,10 +1658,10 @@ void serializeFunction(jeff::Function::Builder functionBuilder, mlir::func::Func for (auto& pair : ctx.values) { values[pair.second] = pair.first; } - for (size_t i = 0, j = 0; i < numValues; ++i) { - auto valueBuilder = valuesBuilder[i]; + for (size_t v = 0; v < numValues; ++v) { + auto valueBuilder = valuesBuilder[v]; auto typeBuilder = valueBuilder.initType(); - serializeType(typeBuilder, values[i].getType()); + serializeType(typeBuilder, values[v].getType()); } } @@ -1719,22 +1720,27 @@ void writeMessage(mlir::ModuleOp module, capnp::MallocMessageBuilder& message) { } // namespace -void serialize(mlir::ModuleOp module, llvm::StringRef path) { - llvm::sys::fs::file_t file = 0; +kj::Array serialize(mlir::ModuleOp module) { + capnp::MallocMessageBuilder message; + writeMessage(module, message); + return capnp::messageToFlatArray(message); +} + +void serializeToFile(mlir::ModuleOp module, llvm::StringRef path) { + int file = 0; if (llvm::sys::fs::openFileForWrite(path, file)) { - llvm::errs() << "Failed to open file: " << path << "\n"; llvm::report_fatal_error("Could not open file"); } + auto fd = llvm::sys::fs::convertFDToNativeFile(file); + capnp::MallocMessageBuilder message; + writeMessage(module, message); #ifdef _WIN32 - kj::AutoCloseHandle autoCloseHandle(file); - kj::HandleOutputStream output(std::move(autoCloseHandle)); + kj::AutoCloseHandle handle(fd); + kj::HandleOutputStream output(kj::mv(handle)); + capnp::writeMessage(output, message); #else - kj::AutoCloseFd autoCloseFd(file); - kj::FdOutputStream output(std::move(autoCloseFd)); + const kj::AutoCloseFd autoCloseFd(fd); + capnp::writeMessageToFd(autoCloseFd, message); #endif - - capnp::MallocMessageBuilder message; - writeMessage(module, message); - capnp::writeMessage(output, message); } diff --git a/unittests/Conversion/test_native_round_trip.cpp b/unittests/Conversion/test_native_round_trip.cpp index d34f6f2..23ab59a 100644 --- a/unittests/Conversion/test_native_round_trip.cpp +++ b/unittests/Conversion/test_native_round_trip.cpp @@ -9,10 +9,9 @@ #include #include #include -#include +#include #include #include -#include #include #include #include @@ -28,7 +27,6 @@ #include #include #include -#include #include namespace fs = std::filesystem; @@ -46,24 +44,31 @@ std::ostream& operator<<(std::ostream& os, const NativeRoundTripTestCase& testCa class NativeRoundTripTest : public ::testing::Test, public ::testing::WithParamInterface {}; -kj::Array readJeffFile(llvm::StringRef path) { - llvm::sys::fs::file_t file = 0; - if (llvm::sys::fs::openFileForRead(path, file)) { +std::string readJeffFileToText(llvm::StringRef path) { + auto file = llvm::sys::fs::openNativeFileForRead(path); + if (!file) { llvm::errs() << "Failed to open file: " << path << "\n"; llvm::report_fatal_error("Could not open file"); } + capnp::MallocMessageBuilder message; #ifdef _WIN32 - kj::AutoCloseHandle autoCloseHandle(file); + kj::AutoCloseHandle autoCloseHandle(*file); kj::HandleInputStream input(std::move(autoCloseHandle)); + capnp::readMessageCopy(input, message); #else - kj::AutoCloseFd autoCloseFd(file); - kj::FdInputStream input(std::move(autoCloseFd)); + const kj::AutoCloseFd autoCloseFd(*file); + capnp::readMessageCopyFromFd(autoCloseFd, message); #endif - capnp::MallocMessageBuilder message; - capnp::readMessageCopy(input, message); - return capnp::messageToFlatArray(message); + const auto module = message.getRoot(); + return module.toString().flatten().cStr(); +} + +std::string moduleTextFromBuffer(const kj::ArrayPtr& buffer) { + capnp::FlatArrayMessageReader message(buffer); + const auto module = message.getRoot(); + return module.toString().flatten().cStr(); } mlir::LogicalResult convertJeffToNative(mlir::ModuleOp module) { @@ -135,11 +140,8 @@ TEST_P(NativeRoundTripTest, RoundTrip) { const fs::path inputsDir = TEST_INPUTS_DIR; const auto& path = inputsDir / testCase.filename; - // Load original jeff module - auto original = readJeffFile(path.string()); - // Deserialize jeff module - auto mlirModule = deserialize(&context, path.string()); + auto mlirModule = deserializeFromFile(&context, path.string()); llvm::errs() << "Input MLIR module:\n"; mlirModule->print(llvm::errs()); @@ -162,34 +164,15 @@ TEST_P(NativeRoundTripTest, RoundTrip) { mlirModule->print(llvm::errs()); llvm::errs() << "\n\n"; - // Create temporary file - llvm::SmallString<128> tempFilePath; - if (llvm::sys::fs::createTemporaryFile("test", "jeff", tempFilePath)) { - llvm::report_fatal_error("Could not create temporary file"); - } - // Serialize MLIR module - serialize(*mlirModule, tempFilePath.str()); - - // Load serialized jeff module - auto serialized = readJeffFile(tempFilePath.str()); - - // Remove temporary file - if (llvm::sys::fs::remove(tempFilePath)) { - llvm::errs() << "Failed to remove temporary file\n"; - } + auto serialized = serialize(*mlirModule); // Compare textual representations - capnp::FlatArrayMessageReader originalMessage(original); - auto originalModule = originalMessage.getRoot(); - auto originalText = originalModule.toString().flatten(); - - capnp::FlatArrayMessageReader serializedMessage(serialized); - auto serializedModule = serializedMessage.getRoot(); - auto serializedText = serializedModule.toString().flatten(); + const auto originalText = readJeffFileToText(path.string()); + const auto serializedText = moduleTextFromBuffer(serialized); - llvm::errs() << "Original module:\n" << originalText.cStr() << "\n\n"; - llvm::errs() << "Serialized module:\n" << serializedText.cStr() << "\n\n"; + llvm::errs() << "Original module:\n" << originalText << "\n\n"; + llvm::errs() << "Serialized module:\n" << serializedText << "\n\n"; ASSERT_EQ(originalText, serializedText); } diff --git a/unittests/Translation/test_round_trip.cpp b/unittests/Translation/test_round_trip.cpp index 4cb2395..5920b49 100644 --- a/unittests/Translation/test_round_trip.cpp +++ b/unittests/Translation/test_round_trip.cpp @@ -7,10 +7,9 @@ #include #include #include -#include +#include #include #include -#include #include #include #include @@ -21,7 +20,6 @@ #include #include #include -#include #include namespace fs = std::filesystem; @@ -39,24 +37,31 @@ std::ostream& operator<<(std::ostream& os, const RoundTripTestCase& testCase) { class RoundTripTest : public ::testing::Test, public ::testing::WithParamInterface {}; -kj::Array readJeffFile(llvm::StringRef path) { - llvm::sys::fs::file_t file = 0; - if (llvm::sys::fs::openFileForRead(path, file)) { +std::string readJeffFileToText(llvm::StringRef path) { + auto file = llvm::sys::fs::openNativeFileForRead(path); + if (!file) { llvm::errs() << "Failed to open file: " << path << "\n"; llvm::report_fatal_error("Could not open file"); } + capnp::MallocMessageBuilder message; #ifdef _WIN32 - kj::AutoCloseHandle autoCloseHandle(file); + kj::AutoCloseHandle autoCloseHandle(*file); kj::HandleInputStream input(std::move(autoCloseHandle)); + capnp::readMessageCopy(input, message); #else - kj::AutoCloseFd autoCloseFd(file); - kj::FdInputStream input(std::move(autoCloseFd)); + const kj::AutoCloseFd autoCloseFd(*file); + capnp::readMessageCopyFromFd(autoCloseFd, message); #endif - capnp::MallocMessageBuilder message; - capnp::readMessageCopy(input, message); - return capnp::messageToFlatArray(message); + const auto module = message.getRoot(); + return module.toString().flatten().cStr(); +} + +std::string moduleTextFromBuffer(const kj::ArrayPtr& buffer) { + capnp::FlatArrayMessageReader message(buffer); + const auto module = message.getRoot(); + return module.toString().flatten().cStr(); } std::vector getTestCases() { @@ -93,44 +98,22 @@ TEST_P(RoundTripTest, RoundTrip) { const fs::path inputsDir = TEST_INPUTS_DIR; const auto& path = inputsDir / testCase.filename; - // Load original jeff module - auto original = readJeffFile(path.string()); - // Deserialize jeff module - auto mlirModule = deserialize(&context, path.string()); + auto mlirModule = deserializeFromFile(&context, path.string()); llvm::errs() << "Deserialized MLIR module:\n"; mlirModule->print(llvm::errs()); llvm::errs() << "\n\n"; - // Create temporary file - llvm::SmallString<128> tempFilePath; - if (llvm::sys::fs::createTemporaryFile("test", "jeff", tempFilePath)) { - llvm::report_fatal_error("Could not create temporary file"); - } - // Serialize MLIR module - serialize(*mlirModule, tempFilePath.str()); - - // Load serialized jeff module - auto serialized = readJeffFile(tempFilePath.str()); - - // Remove temporary file - if (llvm::sys::fs::remove(tempFilePath)) { - llvm::errs() << "Failed to remove temporary file\n"; - } + auto serialized = serialize(*mlirModule); // Compare textual representations - capnp::FlatArrayMessageReader originalMessage(original); - auto originalModule = originalMessage.getRoot(); - auto originalText = originalModule.toString().flatten(); - - capnp::FlatArrayMessageReader serializedMessage(serialized); - auto serializedModule = serializedMessage.getRoot(); - auto serializedText = serializedModule.toString().flatten(); + const auto originalText = readJeffFileToText(path.string()); + const auto serializedText = moduleTextFromBuffer(serialized); - llvm::errs() << "Original module:\n" << originalText.cStr() << "\n\n"; - llvm::errs() << "Serialized module:\n" << serializedText.cStr() << "\n\n"; + llvm::errs() << "Original module:\n" << originalText << "\n\n"; + llvm::errs() << "Serialized module:\n" << serializedText << "\n\n"; ASSERT_EQ(originalText, serializedText); }