Skip to content

Commit ffce3f9

Browse files
committed
wip
1 parent 8560b53 commit ffce3f9

6 files changed

Lines changed: 89 additions & 1 deletion

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "TokenizerModule.h"
2+
#include <executorch/extension/module/module.h>
3+
#include <fstream>
4+
#include <rnexecutorch/Log.h>
5+
#include <rnexecutorch/data_processing/FileUtils.h>
6+
7+
namespace rnexecutorch {
8+
using namespace facebook;
9+
10+
// Constructs the module by loading a HuggingFace-style tokenizer from the
// JSON blob stored at `source`.
// The CallInvoker parameter is accepted for interface parity with the other
// native modules but is not used by this constructor.
// Throws (via loadBytesFromFile) if the file cannot be read.
TokenizerModule::TokenizerModule(
    std::string source, std::shared_ptr<react::CallInvoker> callInvoker)
    : tokenizer(tokenizers::Tokenizer::FromBlobJSON(
          fileutils::loadBytesFromFile(source))) {}
15+
16+
std::vector<int32_t> TokenizerModule::encode(std::string s) {
17+
if (!tokenizer) {
18+
throw std::runtime_error("Encode called on an uninitialized tokenizer!");
19+
};
20+
return tokenizer->Encode(s);
21+
}
22+
23+
std::string TokenizerModule::decode(std::vector<int32_t> vec) {
24+
if (!tokenizer) {
25+
throw std::runtime_error("Decode called on an uninitialized tokenizer!");
26+
}
27+
return tokenizer->Decode(vec);
28+
}
29+
30+
// Lower-bound estimate of native memory held by this module, reported to JS.
// NOTE(review): hard-coded placeholder (1) — presumably a real estimate is
// still TODO (commit is marked "wip"); confirm intended units and value.
int TokenizerModule::getMemoryLowerBound() { return 1; }
31+
32+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@ concept HasGenerate = requires(T t) {
1313
{ &T::generate };
1414
};
1515

16+
// Satisfied when T declares a member named `encode`; ModelHostObject uses it
// to conditionally expose an "encode" function to JS. Mirrors HasGenerate.
// Note: this only checks that `&T::encode` is well-formed — it does not
// constrain the member's signature.
template <typename T>
concept HasEncode = requires(T t) {
  { &T::encode };
};
20+
21+
// Satisfied when T declares a member named `decode`; ModelHostObject uses it
// to conditionally expose a "decode" function to JS. Mirrors HasGenerate.
// Note: this only checks that `&T::decode` is well-formed — it does not
// constrain the member's signature.
template <typename T>
concept HasDecode = requires(T t) {
  { &T::decode };
};
25+
1626
template <typename T>
1727
concept IsNumeric = std::is_arithmetic_v<T>;
1828

packages/react-native-executorch/common/rnexecutorch/data_processing/FileUtils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

33
#include <chrono>
4+
#include <filesystem>
5+
#include <fstream>
46
#include <string>
57

68
namespace rnexecutorch::fileutils {
@@ -11,4 +13,18 @@ inline std::string getTimeID() {
1113
.count());
1214
}
1315

16+
// Reads the entire contents of the file at `path` into a std::string of raw
// bytes (used e.g. to load tokenizer JSON blobs).
// @param path filesystem path of the file to read.
// @return the file's bytes.
// @throws std::runtime_error if the file cannot be opened, sized, or fully
//         read — a short read never returns truncated data silently.
inline std::string loadBytesFromFile(const std::string &path) {
  std::ifstream fs(path, std::ios::in | std::ios::binary);
  if (fs.fail()) {
    // fixed: message used to claim "tokenizer file" in this generic utility;
    // include the path so failures are diagnosable.
    throw std::runtime_error("Failed to open file: " + path);
  }
  fs.seekg(0, std::ios::end);
  const auto end = fs.tellg();
  // fixed: tellg() returns -1 on failure; the old code cast it straight to
  // size_t, which would attempt a gigantic allocation.
  if (end == std::ifstream::pos_type(-1)) {
    throw std::runtime_error("Failed to determine size of file: " + path);
  }
  fs.seekg(0, std::ios::beg);
  std::string data(static_cast<size_t>(end), '\0');
  // fixed: check the read actually succeeded instead of ignoring the result.
  if (!fs.read(data.data(), static_cast<std::streamsize>(data.size()))) {
    throw std::runtime_error("Failed to read file: " + path);
  }
  return data;
}
29+
1430
} // namespace rnexecutorch::fileutils

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ inline std::string getValue<std::string>(const jsi::Value &val,
4141
return val.getString(runtime).utf8(runtime);
4242
}
4343

44+
template <>
45+
inline std::vector<int32_t>
46+
getValue<std::vector<int32_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
47+
jsi::Array array = val.asObject(runtime).asArray(runtime);
48+
size_t length = array.size(runtime);
49+
std::vector<int32_t> result;
50+
result.reserve(length);
51+
52+
for (size_t i = 0; i < length; ++i) {
53+
jsi::Value element = array.getValueAtIndex(runtime, i);
54+
result.push_back(getValue<int32_t>(element, runtime));
55+
}
56+
return result;
57+
}
58+
4459
template <>
4560
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
4661
jsi::Runtime &runtime) {

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
4444
promiseHostFunction<&Model::generate>,
4545
"generate"));
4646
}
47+
48+
if constexpr (HasEncode<Model>) {
49+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
50+
promiseHostFunction<&Model::encode>,
51+
"encode"));
52+
}
53+
54+
if constexpr (HasDecode<Model>) {
55+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
56+
promiseHostFunction<&Model::decode>,
57+
"decode"));
58+
}
4759
}
4860

4961
// A generic host function that resolves a promise with a result of a

packages/react-native-executorch/src/index.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ declare global {
99
var loadClassification: (source: string) => any;
1010
var loadObjectDetection: (source: string) => any;
1111
var loadExecutorchModule: (source: string) => any;
12+
var loadTokenizerModule: (source: string) => any;
1213
}
1314
// eslint-disable no-var
1415
if (
1516
global.loadStyleTransfer == null ||
1617
global.loadImageSegmentation == null ||
1718
global.loadExecutorchModule == null ||
1819
global.loadClassification == null ||
19-
global.loadObjectDetection == null
20+
global.loadObjectDetection == null ||
21+
global.loadTokenizerModule == null
2022
) {
2123
if (!ETInstallerNativeModule) {
2224
throw new Error(
@@ -54,6 +56,7 @@ export * from './modules/natural_language_processing/LLMModule';
5456
export * from './modules/natural_language_processing/SpeechToTextModule';
5557
export * from './modules/natural_language_processing/TextEmbeddingsModule';
5658
export * from './modules/natural_language_processing/TokenizerModule';
59+
export * from './modules/natural_language_processing/NewTokenizerModule';
5760

5861
// utils
5962
export * from './utils/ResourceFetcher';

0 commit comments

Comments
 (0)