Skip to content

Commit 6b88b4b

Browse files
authored
feat: port TokenizerModule to C++ (#393)
## Description <!-- Provide a concise and descriptive summary of the changes implemented in this PR. --> ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent f2939db commit 6b88b4b

File tree

20 files changed

+406
-55
lines changed

20 files changed

+406
-55
lines changed

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "RnExecutorchInstaller.h"
22

3+
#include <rnexecutorch/TokenizerModule.h>
34
#include <rnexecutorch/host_objects/JsiConversions.h>
45
#include <rnexecutorch/models/classification/Classification.h>
56
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
@@ -42,6 +43,11 @@ void RnExecutorchInstaller::injectJSIBindings(
4243
*jsiRuntime, "loadExecutorchModule",
4344
RnExecutorchInstaller::loadModel<BaseModel>(jsiRuntime, jsCallInvoker,
4445
"loadExecutorchModule"));
46+
47+
jsiRuntime->global().setProperty(
48+
*jsiRuntime, "loadTokenizerModule",
49+
RnExecutorchInstaller::loadModel<TokenizerModule>(
50+
jsiRuntime, jsCallInvoker, "loadTokenizerModule"));
4551
}
4652

4753
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ REGISTER_CONSTRUCTOR(ObjectDetection, std::string,
2828
std::shared_ptr<react::CallInvoker>);
2929
REGISTER_CONSTRUCTOR(BaseModel, std::string,
3030
std::shared_ptr<react::CallInvoker>);
31+
REGISTER_CONSTRUCTOR(TokenizerModule, std::string,
32+
std::shared_ptr<react::CallInvoker>);
3133

3234
using namespace facebook;
3335

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "TokenizerModule.h"
2+
#include <executorch/extension/module/module.h>
3+
#include <filesystem>
4+
#include <rnexecutorch/Log.h>
5+
#include <rnexecutorch/data_processing/FileUtils.h>
6+
7+
namespace rnexecutorch {
8+
using namespace facebook;
9+
10+
TokenizerModule::TokenizerModule(
11+
std::string source, std::shared_ptr<react::CallInvoker> callInvoker)
12+
: memorySizeLowerBound(std::filesystem::file_size(source)),
13+
tokenizer(tokenizers::Tokenizer::FromBlobJSON(
14+
fileutils::loadBytesFromFile(source))) {}
15+
16+
void TokenizerModule::ensureTokenizerLoaded(
17+
const std::string &methodName) const {
18+
if (!tokenizer) {
19+
throw std::runtime_error(
20+
methodName + " function was called on an uninitialized tokenizer!");
21+
}
22+
}
23+
24+
std::vector<int32_t> TokenizerModule::encode(std::string s) const {
25+
ensureTokenizerLoaded("encode");
26+
return tokenizer->Encode(s);
27+
}
28+
29+
std::string TokenizerModule::decode(std::vector<int32_t> vec,
30+
bool skipSpecialTokens) const {
31+
ensureTokenizerLoaded("decode");
32+
return tokenizer->Decode(vec, skipSpecialTokens);
33+
}
34+
35+
size_t TokenizerModule::getVocabSize() const {
36+
ensureTokenizerLoaded("getVocabSize");
37+
return tokenizer->GetVocabSize();
38+
}
39+
40+
std::string TokenizerModule::idToToken(int32_t tokenId) const {
41+
ensureTokenizerLoaded("idToToken");
42+
return tokenizer->IdToToken(tokenId);
43+
}
44+
45+
int32_t TokenizerModule::tokenToId(std::string token) const {
46+
ensureTokenizerLoaded("tokenToId");
47+
return tokenizer->TokenToId(token);
48+
}
49+
std::size_t TokenizerModule::getMemoryLowerBound() const noexcept {
50+
return memorySizeLowerBound;
51+
}
52+
} // namespace rnexecutorch
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <ReactCommon/CallInvoker.h>
4+
#include <string>
5+
#include <tokenizers-cpp/tokenizers_cpp.h>
6+
7+
namespace rnexecutorch {
8+
using namespace facebook;
9+
10+
class TokenizerModule {
11+
public:
12+
explicit TokenizerModule(std::string source,
13+
std::shared_ptr<react::CallInvoker> callInvoker);
14+
std::vector<int32_t> encode(std::string s) const;
15+
std::string decode(std::vector<int32_t> vec, bool skipSpecialTokens) const;
16+
std::string idToToken(int32_t tokenId) const;
17+
int32_t tokenToId(std::string token) const;
18+
std::size_t getVocabSize() const;
19+
std::size_t getMemoryLowerBound() const noexcept;
20+
21+
private:
22+
void ensureTokenizerLoaded(const std::string &methodName) const;
23+
std::unique_ptr<tokenizers::Tokenizer> tokenizer;
24+
const std::size_t memorySizeLowerBound{0};
25+
};
26+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/data_processing/FileUtils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

33
#include <chrono>
4+
#include <filesystem>
5+
#include <fstream>
46
#include <string>
57

68
namespace rnexecutorch::fileutils {
@@ -11,4 +13,18 @@ inline std::string getTimeID() {
1113
.count());
1214
}
1315

/// Reads the entire file at `path` into a string (binary-safe).
/// @throws std::runtime_error if the file cannot be opened, its size cannot
///         be determined, or the read comes up short.
inline std::string loadBytesFromFile(const std::string &path) {
  std::ifstream fs(path, std::ios::in | std::ios::binary);
  if (fs.fail()) {
    throw std::runtime_error("Failed to open tokenizer file");
  }
  fs.seekg(0, std::ios::end);
  const std::streampos end = fs.tellg();
  // tellg() reports -1 on failure; guard before converting to size_t.
  if (end < 0) {
    throw std::runtime_error("Failed to determine tokenizer file size");
  }
  fs.seekg(0, std::ios::beg);
  std::string data;
  data.resize(static_cast<size_t>(end));
  fs.read(data.data(), static_cast<std::streamsize>(data.size()));
  // A short read would otherwise silently return a zero-padded buffer.
  if (static_cast<size_t>(fs.gcount()) != data.size()) {
    throw std::runtime_error("Failed to read tokenizer file");
  }
  return data;
}
1430
} // namespace rnexecutorch::fileutils

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ inline std::string getValue<std::string>(const jsi::Value &val,
4141
return val.getString(runtime).utf8(runtime);
4242
}
4343

44+
template <>
45+
inline std::vector<int32_t>
46+
getValue<std::vector<int32_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
47+
jsi::Array array = val.asObject(runtime).asArray(runtime);
48+
size_t length = array.size(runtime);
49+
std::vector<int32_t> result;
50+
result.reserve(length);
51+
52+
for (size_t i = 0; i < length; ++i) {
53+
jsi::Value element = array.getValueAtIndex(runtime, i);
54+
result.push_back(getValue<int32_t>(element, runtime));
55+
}
56+
return result;
57+
}
58+
4459
template <>
4560
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
4661
jsi::Runtime &runtime) {
@@ -182,20 +197,18 @@ getJsiValue(const std::vector<std::shared_ptr<OwningArrayBuffer>> &vec,
182197
return jsi::Value(runtime, array);
183198
}
184199

185-
inline jsi::Value
186-
getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
187-
jsi::Runtime &runtime) {
200+
inline jsi::Value getJsiValue(const std::vector<JSTensorViewOut> &vec,
201+
jsi::Runtime &runtime) {
188202
jsi::Array array(runtime, vec.size());
189203
for (size_t i = 0; i < vec.size(); i++) {
190204
jsi::Object tensorObj(runtime);
191205

192-
tensorObj.setProperty(runtime, "sizes",
193-
getJsiValue(vec[i]->sizes, runtime));
206+
tensorObj.setProperty(runtime, "sizes", getJsiValue(vec[i].sizes, runtime));
194207

195208
tensorObj.setProperty(runtime, "scalarType",
196-
jsi::Value(static_cast<int>(vec[i]->scalarType)));
209+
jsi::Value(static_cast<int>(vec[i].scalarType)));
197210

198-
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr);
211+
jsi::ArrayBuffer arrayBuffer(runtime, vec[i].dataPtr);
199212
tensorObj.setProperty(runtime, "dataPtr", arrayBuffer);
200213

201214
array.setValueAtIndex(runtime, i, tensorObj);

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#include <ReactCommon/CallInvoker.h>
99

10-
#include <rnexecutorch/Log.h>
10+
#include <rnexecutorch/TokenizerModule.h>
1111
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1212
#include <rnexecutorch/host_objects/JsiConversions.h>
1313
#include <rnexecutorch/jsi/JsiHostObject.h>
@@ -45,6 +45,31 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
4545
promiseHostFunction<&Model::generate>,
4646
"generate"));
4747
}
48+
49+
if constexpr (meta::HasEncode<Model>) {
50+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
51+
promiseHostFunction<&Model::encode>,
52+
"encode"));
53+
}
54+
55+
if constexpr (meta::SameAs<Model, TokenizerModule>) {
56+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
57+
promiseHostFunction<&Model::encode>,
58+
"encode"));
59+
60+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
61+
promiseHostFunction<&Model::decode>,
62+
"decode"));
63+
addFunctions(JSI_EXPORT_FUNCTION(
64+
ModelHostObject<Model>, promiseHostFunction<&Model::getVocabSize>,
65+
"getVocabSize"));
66+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
67+
promiseHostFunction<&Model::idToToken>,
68+
"idToToken"));
69+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
70+
promiseHostFunction<&Model::tokenToId>,
71+
"tokenToId"));
72+
}
4873
}
4974

5075
// A generic host function that resolves a promise with a result of a
@@ -76,8 +101,8 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
76101
std::thread([this, promise,
77102
argsConverted = std::move(argsConverted)]() {
78103
try {
79-
auto result =
80-
std::apply(std::bind_front(FnPtr, model), argsConverted);
104+
auto result = std::apply(std::bind_front(FnPtr, model),
105+
std::move(argsConverted));
81106
// The result is copied. It should either be quickly copiable,
82107
// or passed with a shared_ptr.
83108
callInvoker->invokeAsync([promise,

packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) {
1414
return sizeof...(Types);
1515
}
1616

17+
// Const-qualified counterpart of getArgumentCount: yields the arity of a
// const member function at compile time. The pointer is unnamed since only
// its type matters.
template <typename Model, typename R, typename... Types>
constexpr std::size_t getArgumentCount(R (Model::*)(Types...) const) {
  return sizeof...(Types);
}
21+
1722
template <typename... Types, std::size_t... I>
1823
std::tuple<Types...> fillTupleFromArgs(std::index_sequence<I...>,
1924
const jsi::Value *args,
@@ -34,4 +39,12 @@ std::tuple<Types...> createArgsTupleFromJsi(R (Model::*f)(Types...),
3439
return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
3540
runtime);
3641
}
42+
43+
template <typename Model, typename R, typename... Types>
44+
std::tuple<Types...> createArgsTupleFromJsi(R (Model::*f)(Types...) const,
45+
const jsi::Value *args,
46+
jsi::Runtime &runtime) {
47+
return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
48+
runtime);
49+
}
3750
} // namespace rnexecutorch::meta

packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,24 @@ namespace rnexecutorch::meta {
88
template <typename T, typename Base>
99
concept DerivedFromOrSameAs = std::is_base_of_v<Base, T>;
1010

11+
template <typename T, typename Base>
12+
concept SameAs = std::is_same_v<Base, T>;
13+
1114
template <typename T>
1215
concept HasGenerate = requires(T t) {
1316
{ &T::generate };
1417
};
1518

19+
template <typename T>
20+
concept HasEncode = requires(T t) {
21+
{ &T::encode };
22+
};
23+
24+
template <typename T>
25+
concept HasDecode = requires(T t) {
26+
{ &T::decode };
27+
};
28+
1629
template <typename T>
1730
concept IsNumeric = std::is_arithmetic_v<T>;
1831

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
8383
return output;
8484
}
8585

86-
std::vector<std::shared_ptr<JSTensorViewOut>>
86+
std::vector<JSTensorViewOut>
8787
BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
8888
if (!module) {
8989
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
@@ -114,7 +114,7 @@ BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
114114
}
115115

116116
auto &outputs = result.get();
117-
std::vector<std::shared_ptr<JSTensorViewOut>> output;
117+
std::vector<JSTensorViewOut> output;
118118
output.reserve(outputs.size());
119119

120120
// Convert ET outputs to a vector of JSTensorViewOut which are later
@@ -125,8 +125,7 @@ BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
125125
size_t bufferSize = outputTensor.numel() * outputTensor.element_size();
126126
auto buffer = std::make_shared<OwningArrayBuffer>(bufferSize);
127127
std::memcpy(buffer->data(), outputTensor.const_data_ptr(), bufferSize);
128-
auto jsTensor = std::make_shared<JSTensorViewOut>(
129-
sizes, outputTensor.scalar_type(), buffer);
128+
auto jsTensor = JSTensorViewOut(sizes, outputTensor.scalar_type(), buffer);
130129
output.emplace_back(jsTensor);
131130
}
132131
return output;

0 commit comments

Comments
 (0)