Skip to content

Commit fb3be2d

Browse files
committed
Make serialize() and deserialize() independent of Cap'n Proto
1 parent 10c0ef4 commit fb3be2d

6 files changed

Lines changed: 65 additions & 41 deletions

File tree

include/jeff/Translation/Deserialize.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
#pragma once
22

3-
#include <capnp/common.h>
4-
#include <kj/common.h>
5-
#include <llvm/ADT/StringRef.h>
63
#include <mlir/IR/BuiltinOps.h>
74
#include <mlir/IR/MLIRContext.h>
85
#include <mlir/IR/OwningOpRef.h>
6+
#include <mlir/Support/LLVM.h>
7+
8+
#include <cstdint>
99

1010
/**
11-
* @brief Deserialize a flat word array into an MLIR module.
11+
* @brief Deserialize a byte buffer into an MLIR module.
1212
* @param context The MLIR context to use for the deserialization.
13-
* @param data A flat word array containing the serialized module.
13+
* @param data A byte buffer containing the serialized module.
1414
* @return An owning reference to the deserialized MLIR module.
1515
*/
1616
mlir::OwningOpRef<mlir::ModuleOp> deserialize(mlir::MLIRContext* context,
17-
kj::ArrayPtr<capnp::word> data);
17+
llvm::ArrayRef<uint8_t> data);
1818

1919
/**
2020
* @brief Deserialize a .jeff file into an MLIR module.

include/jeff/Translation/Serialize.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
#pragma once
22

3-
#include <capnp/common.h>
4-
#include <kj/array.h>
5-
#include <llvm/ADT/StringRef.h>
63
#include <mlir/IR/BuiltinOps.h>
4+
#include <mlir/Support/LLVM.h>
5+
6+
#include <cstdint>
77

88
/**
9-
* @brief Serialize an MLIR module into a flat word array.
9+
* @brief Serialize an MLIR module into a byte buffer.
1010
* @param module The MLIR module to serialize.
11-
* @return A flat word array containing the serialized module.
11+
* @return A byte buffer containing the serialized module.
1212
*
1313
* @details
1414
* Known limitations:
1515
*
1616
* - Only one-dimensional tensors with dynamic size are supported.
1717
*/
18-
kj::Array<capnp::word> serialize(mlir::ModuleOp module);
18+
llvm::SmallVector<uint8_t> serialize(mlir::ModuleOp module);
1919

2020
/**
2121
* @brief Serialize an MLIR module into a .jeff file.

lib/Translation/Deserialize.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,11 +1581,14 @@ void deserializeFunction(mlir::ImplicitLocOpBuilder& builder, jeff::Function::Re
15811581
} // namespace
15821582

15831583
mlir::OwningOpRef<mlir::ModuleOp> deserialize(mlir::MLIRContext* context,
1584-
kj::ArrayPtr<capnp::word> data) {
1584+
llvm::ArrayRef<uint8_t> data) {
15851585
DeserializationContext ctx;
15861586

1587+
auto words = kj::heapArray<capnp::word>(data.size() / sizeof(capnp::word));
1588+
std::memcpy(words.begin(), data.data(), data.size());
1589+
15871590
// Get jeff module from data
1588-
capnp::FlatArrayMessageReader message(data);
1591+
capnp::FlatArrayMessageReader message(words);
15891592
jeff::Module::Reader jeffModule = message.getRoot<jeff::Module>();
15901593

15911594
// Create MLIR builder
@@ -1668,7 +1671,8 @@ mlir::OwningOpRef<mlir::ModuleOp> deserializeFromFile(mlir::MLIRContext* context
16681671
kj::FdInputStream input(std::move(autoCloseFd));
16691672
#endif
16701673

1671-
capnp::MallocMessageBuilder message;
1672-
capnp::readMessageCopy(input, message);
1673-
return deserialize(context, capnp::messageToFlatArray(message));
1674+
auto data = input.readAllBytes();
1675+
auto bytes = data.asBytes();
1676+
llvm::ArrayRef<uint8_t> buffer(reinterpret_cast<const uint8_t*>(bytes.begin()), bytes.size());
1677+
return deserialize(context, buffer);
16741678
}

lib/Translation/Serialize.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,10 +1721,14 @@ void writeMessage(mlir::ModuleOp module, capnp::MallocMessageBuilder& message) {
17211721

17221722
} // namespace
17231723

1724-
kj::Array<capnp::word> serialize(mlir::ModuleOp module) {
1724+
llvm::SmallVector<uint8_t> serialize(mlir::ModuleOp module) {
17251725
capnp::MallocMessageBuilder message;
17261726
writeMessage(module, message);
1727-
return capnp::messageToFlatArray(message);
1727+
1728+
auto words = capnp::messageToFlatArray(message);
1729+
auto bytes = words.asBytes();
1730+
return llvm::SmallVector<uint8_t>(reinterpret_cast<const uint8_t*>(bytes.begin()),
1731+
reinterpret_cast<const uint8_t*>(bytes.end()));
17281732
}
17291733

17301734
void serializeToFile(mlir::ModuleOp module, llvm::StringRef path) {

unittests/Conversion/test_native_round_trip.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
#include <capnp/serialize.h>
1010
#include <gtest/gtest.h>
1111
#include <jeff.capnp.h>
12-
#include <kj/array.h>
1312
#include <kj/io.h>
1413
#include <kj/string-tree.h>
14+
#include <llvm/ADT/ArrayRef.h>
15+
#include <llvm/ADT/SmallVector.h>
1516
#include <llvm/Support/ErrorHandling.h>
1617
#include <llvm/Support/FileSystem.h>
1718
#include <llvm/Support/raw_ostream.h>
@@ -45,7 +46,7 @@ std::ostream& operator<<(std::ostream& os, const NativeRoundTripTestCase& testCa
4546
class NativeRoundTripTest : public ::testing::Test,
4647
public ::testing::WithParamInterface<NativeRoundTripTestCase> {};
4748

48-
kj::Array<capnp::word> readJeffFile(llvm::StringRef path) {
49+
llvm::SmallVector<uint8_t> readJeffFile(llvm::StringRef path) {
4950
llvm::sys::fs::file_t file = 0;
5051
if (llvm::sys::fs::openFileForRead(path, file)) {
5152
llvm::errs() << "Failed to open file: " << path << "\n";
@@ -60,9 +61,20 @@ kj::Array<capnp::word> readJeffFile(llvm::StringRef path) {
6061
kj::FdInputStream input(std::move(autoCloseFd));
6162
#endif
6263

64+
auto words = input.readAllBytes();
65+
auto bytes = words.asBytes();
66+
return llvm::SmallVector<uint8_t>(reinterpret_cast<const uint8_t*>(bytes.begin()),
67+
reinterpret_cast<const uint8_t*>(bytes.end()));
68+
}
69+
70+
std::string moduleTextFromBytes(llvm::ArrayRef<uint8_t> data) {
71+
kj::ArrayPtr<const kj::byte> bytes(reinterpret_cast<const kj::byte*>(data.data()), data.size());
72+
kj::ArrayInputStream input(bytes);
73+
6374
capnp::MallocMessageBuilder message;
6475
capnp::readMessageCopy(input, message);
65-
return capnp::messageToFlatArray(message);
76+
auto module = message.getRoot<jeff::Module>();
77+
return std::string(module.toString().flatten().cStr());
6678
}
6779

6880
mlir::LogicalResult convertJeffToNative(mlir::ModuleOp module) {
@@ -165,16 +177,12 @@ TEST_P(NativeRoundTripTest, RoundTrip) {
165177
auto serialized = serialize(*mlirModule);
166178

167179
// Compare textual representations
168-
capnp::FlatArrayMessageReader originalMessage(original);
169-
auto originalModule = originalMessage.getRoot<jeff::Module>();
170-
auto originalText = originalModule.toString().flatten();
180+
auto originalText = moduleTextFromBytes(original);
171181

172-
capnp::FlatArrayMessageReader serializedMessage(serialized);
173-
auto serializedModule = serializedMessage.getRoot<jeff::Module>();
174-
auto serializedText = serializedModule.toString().flatten();
182+
auto serializedText = moduleTextFromBytes(serialized);
175183

176-
llvm::errs() << "Original module:\n" << originalText.cStr() << "\n\n";
177-
llvm::errs() << "Serialized module:\n" << serializedText.cStr() << "\n\n";
184+
llvm::errs() << "Original module:\n" << originalText << "\n\n";
185+
llvm::errs() << "Serialized module:\n" << serializedText << "\n\n";
178186

179187
ASSERT_EQ(originalText, serializedText);
180188
}

unittests/Translation/test_round_trip.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
#include <capnp/serialize.h>
88
#include <gtest/gtest.h>
99
#include <jeff.capnp.h>
10-
#include <kj/array.h>
1110
#include <kj/io.h>
1211
#include <kj/string-tree.h>
12+
#include <llvm/ADT/ArrayRef.h>
13+
#include <llvm/ADT/SmallVector.h>
1314
#include <llvm/Support/ErrorHandling.h>
1415
#include <llvm/Support/FileSystem.h>
1516
#include <llvm/Support/raw_ostream.h>
@@ -38,7 +39,7 @@ std::ostream& operator<<(std::ostream& os, const RoundTripTestCase& testCase) {
3839
class RoundTripTest : public ::testing::Test,
3940
public ::testing::WithParamInterface<RoundTripTestCase> {};
4041

41-
kj::Array<capnp::word> readJeffFile(llvm::StringRef path) {
42+
llvm::SmallVector<uint8_t> readJeffFile(llvm::StringRef path) {
4243
llvm::sys::fs::file_t file = 0;
4344
if (llvm::sys::fs::openFileForRead(path, file)) {
4445
llvm::errs() << "Failed to open file: " << path << "\n";
@@ -53,9 +54,20 @@ kj::Array<capnp::word> readJeffFile(llvm::StringRef path) {
5354
kj::FdInputStream input(std::move(autoCloseFd));
5455
#endif
5556

57+
auto words = input.readAllBytes();
58+
auto bytes = words.asBytes();
59+
return llvm::SmallVector<uint8_t>(reinterpret_cast<const uint8_t*>(bytes.begin()),
60+
reinterpret_cast<const uint8_t*>(bytes.end()));
61+
}
62+
63+
std::string moduleTextFromBytes(llvm::ArrayRef<uint8_t> data) {
64+
kj::ArrayPtr<const kj::byte> bytes(reinterpret_cast<const kj::byte*>(data.data()), data.size());
65+
kj::ArrayInputStream input(bytes);
66+
5667
capnp::MallocMessageBuilder message;
5768
capnp::readMessageCopy(input, message);
58-
return capnp::messageToFlatArray(message);
69+
auto module = message.getRoot<jeff::Module>();
70+
return std::string(module.toString().flatten().cStr());
5971
}
6072

6173
std::vector<RoundTripTestCase> getTestCases() {
@@ -106,16 +118,12 @@ TEST_P(RoundTripTest, RoundTrip) {
106118
auto serialized = serialize(*mlirModule);
107119

108120
// Compare textual representations
109-
capnp::FlatArrayMessageReader originalMessage(original);
110-
auto originalModule = originalMessage.getRoot<jeff::Module>();
111-
auto originalText = originalModule.toString().flatten();
121+
auto originalText = moduleTextFromBytes(original);
112122

113-
capnp::FlatArrayMessageReader serializedMessage(serialized);
114-
auto serializedModule = serializedMessage.getRoot<jeff::Module>();
115-
auto serializedText = serializedModule.toString().flatten();
123+
auto serializedText = moduleTextFromBytes(serialized);
116124

117-
llvm::errs() << "Original module:\n" << originalText.cStr() << "\n\n";
118-
llvm::errs() << "Serialized module:\n" << serializedText.cStr() << "\n\n";
125+
llvm::errs() << "Original module:\n" << originalText << "\n\n";
126+
llvm::errs() << "Serialized module:\n" << serializedText << "\n\n";
119127

120128
ASSERT_EQ(originalText, serializedText);
121129
}

0 commit comments

Comments
 (0)