Skip to content

Commit d4df366

Browse files
authored
feat: port LLMs to C++ (#415)
## Description This PR ports the current LLM functionality to C++, getting rid of the `Runner` within the `ExecutorchLib` framework. I've also made some changes to the build system. `tokenizers-cpp` - Previously the library was linked in `ExecutorchLib` in Xcode via a build script, now it is completely removed from the frameworks. - I've prebuilt static libraries from the `tokenizers-cpp` repo, which I uploaded to `common/ios/libs/tokenizers-cpp`, similarly to the pre-built ExecuTorch binaries. - The includes for tokenizers-cpp are now in `react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_cpp.h` - Made some changes to the `libs` directory structure, please see the podspec for reference. - These headers are then included from the llama runner source code. - Since `tokenizers-cpp` for Android is pre-built with the `ExecuTorch` aar, I'm not making any changes here. This will need to be updated once we bump the ExecuTorch runtime and when we can safely get rid of the aar/jitpack setup. We can keep tokenizers-cpp as our submodule and then just reference it in Android's CMake. `runner` - The runner source along with headers was moved from `ExecuTorchLib` to `common/runner`, similarly to `ada`, and is compiled on the fly when our library compiles. - In the current situation, I think that we will soon be able to get rid of the ET fork and the submodule. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. 
--> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 9d42f7f commit d4df366

91 files changed

Lines changed: 239 additions & 2005 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

packages/react-native-executorch/android/CMakeLists.txt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,28 @@ set(COMMON_CPP_DIR "${CMAKE_SOURCE_DIR}/../common")
1414
set(LIBS_DIR "${CMAKE_SOURCE_DIR}/../third-party/android/libs")
1515
set(INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../third-party/include")
1616

17+
# FIXME: Below you can see attempts at linking tokenizers-cpp
18+
# directly into react-native-executorch instead of it being linked against ExecuTorch
19+
# and then transitively to our library. Please go back to this when we bump ET runtime to the next version.
20+
# The problem with directly linking tokenizers-cpp using a submodule is that we get unresolved symbols for
21+
# some android logging libraries, which are referenced by sentencepiece.
22+
23+
# set(TOKENIZERS_CPP_DIR "${CMAKE_SOURCE_DIR}/../../../third-party/tokenizers-cpp")
24+
# add_subdirectory("${TOKENIZERS_CPP_DIR}" tokenizers-cpp)
25+
26+
# # Link Android log library to sentencepiece targets
27+
# if(TARGET sentencepiece-static)
28+
# target_link_libraries(sentencepiece-static INTERFACE log)
29+
# endif()
30+
# if(TARGET sentencepiece_train-static)
31+
# target_link_libraries(sentencepiece_train-static INTERFACE log)
32+
# endif()
33+
34+
# # Link log library to sentencepiece executables
35+
# foreach(exe spm_encode spm_decode spm_normalize spm_train spm_export_vocab)
36+
# if(TARGET ${exe})
37+
# target_link_libraries(${exe} log)
38+
# endif()
39+
# endforeach()
40+
1741
add_subdirectory("${ANDROID_CPP_DIR}")

packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt

Lines changed: 0 additions & 63 deletions
This file was deleted.

packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ class RnExecutorchPackage : TurboReactPackage() {
1414
name: String,
1515
reactContext: ReactApplicationContext,
1616
): NativeModule? =
17-
if (name == LLM.NAME) {
18-
LLM(reactContext)
19-
} else if (name == SpeechToText.NAME) {
17+
if (name == SpeechToText.NAME) {
2018
SpeechToText(reactContext)
2119
} else if (name == OCR.NAME) {
2220
OCR(reactContext)
@@ -31,16 +29,6 @@ class RnExecutorchPackage : TurboReactPackage() {
3129
override fun getReactModuleInfoProvider(): ReactModuleInfoProvider =
3230
ReactModuleInfoProvider {
3331
val moduleInfos: MutableMap<String, ReactModuleInfo> = HashMap()
34-
moduleInfos[LLM.NAME] =
35-
ReactModuleInfo(
36-
LLM.NAME,
37-
LLM.NAME,
38-
false, // canOverrideExistingModule
39-
false, // needsEagerInit
40-
true, // hasConstants
41-
false, // isCxxModule
42-
true,
43-
)
4432
moduleInfos[SpeechToText.NAME] =
4533
ReactModuleInfo(
4634
SpeechToText.NAME,

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <rnexecutorch/models/classification/Classification.h>
66
#include <rnexecutorch/models/image_embeddings/ImageEmbeddings.h>
77
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
8+
#include <rnexecutorch/models/llm/LLM.h>
89
#include <rnexecutorch/models/object_detection/ObjectDetection.h>
910
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
1011
#include <rnexecutorch/models/text_embeddings/TextEmbeddings.h>
@@ -55,10 +56,15 @@ void RnExecutorchInstaller::injectJSIBindings(
5556
*jsiRuntime, "loadImageEmbeddings",
5657
RnExecutorchInstaller::loadModel<ImageEmbeddings>(
5758
jsiRuntime, jsCallInvoker, "loadImageEmbeddings"));
59+
5860
jsiRuntime->global().setProperty(
5961
*jsiRuntime, "loadTextEmbeddings",
6062
RnExecutorchInstaller::loadModel<TextEmbeddings>(
6163
jsiRuntime, jsCallInvoker, "loadTextEmbeddings"));
64+
65+
jsiRuntime->global().setProperty(*jsiRuntime, "loadLLM",
66+
RnExecutorchInstaller::loadModel<LLM>(
67+
jsiRuntime, jsCallInvoker, "loadLLM"));
6268
}
6369

6470
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@ REGISTER_CONSTRUCTOR(BaseModel, std::string,
3030
std::shared_ptr<react::CallInvoker>);
3131
REGISTER_CONSTRUCTOR(TokenizerModule, std::string,
3232
std::shared_ptr<react::CallInvoker>);
33-
REGISTER_CONSTRUCTOR(ImageEmbeddings, std::string, std::shared_ptr<react::CallInvoker>);
33+
REGISTER_CONSTRUCTOR(ImageEmbeddings, std::string,
34+
std::shared_ptr<react::CallInvoker>);
3435
REGISTER_CONSTRUCTOR(TextEmbeddings, std::string, std::string,
3536
std::shared_ptr<react::CallInvoker>);
37+
REGISTER_CONSTRUCTOR(LLM, std::string, std::string,
38+
std::shared_ptr<react::CallInvoker>);
3639

3740
using namespace facebook;
3841

packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
#include <ada/ada.h>
77

8-
#include <rnexecutorch/Log.h>
98
#include <rnexecutorch/RnExecutorchInstaller.h>
109
#include <rnexecutorch/data_processing/FileUtils.h>
1110
#include <rnexecutorch/data_processing/base64.h>

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ inline std::string getValue<std::string>(const jsi::Value &val,
4141
return val.getString(runtime).utf8(runtime);
4242
}
4343

44+
template <>
45+
inline std::shared_ptr<jsi::Function>
46+
getValue<std::shared_ptr<jsi::Function>>(const jsi::Value &val,
47+
jsi::Runtime &runtime) {
48+
return std::make_shared<jsi::Function>(
49+
val.asObject(runtime).asFunction(runtime));
50+
}
51+
4452
template <>
4553
inline std::vector<int32_t>
4654
getValue<std::vector<int32_t>>(const jsi::Value &val, jsi::Runtime &runtime) {

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <cstdio>
44
#include <string>
55
#include <tuple>
6+
#include <type_traits>
67
#include <vector>
78

89
#include <ReactCommon/CallInvoker.h>
@@ -15,6 +16,7 @@
1516
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
1617
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
1718
#include <rnexecutorch/models/BaseModel.h>
19+
#include <rnexecutorch/models/llm/LLM.h>
1820

1921
namespace rnexecutorch {
2022

@@ -70,6 +72,60 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
7072
promiseHostFunction<&Model::tokenToId>,
7173
"tokenToId"));
7274
}
75+
76+
if constexpr (meta::SameAs<Model, LLM>) {
77+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
78+
promiseHostFunction<&Model::generate>,
79+
"generate"));
80+
81+
addFunctions(JSI_EXPORT_FUNCTION(
82+
ModelHostObject<Model>, synchronousHostFunction<&Model::interrupt>,
83+
"interrupt"));
84+
85+
addFunctions(
86+
JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
87+
}
88+
}
89+
90+
// A generic host function that runs synchronously, works analogously to the
91+
// generic promise host function.
92+
template <auto FnPtr> JSI_HOST_FUNCTION(synchronousHostFunction) {
93+
constexpr std::size_t functionArgCount = meta::getArgumentCount(FnPtr);
94+
if (functionArgCount != count) {
95+
char errorMessage[100];
96+
std::snprintf(errorMessage, sizeof(errorMessage),
97+
"Argument count mismatch, was expecting: %zu but got: %zu",
98+
functionArgCount, count);
99+
throw jsi::JSError(runtime, errorMessage);
100+
}
101+
102+
try {
103+
auto argsConverted = meta::createArgsTupleFromJsi(FnPtr, args, runtime);
104+
105+
if constexpr (std::is_void_v<decltype(std::apply(
106+
std::bind_front(FnPtr, model), argsConverted))>) {
107+
// For void functions, just call the function and return undefined
108+
std::apply(std::bind_front(FnPtr, model), std::move(argsConverted));
109+
return jsi::Value::undefined();
110+
} else {
111+
// For non-void functions, capture the result and convert it
112+
auto result =
113+
std::apply(std::bind_front(FnPtr, model), std::move(argsConverted));
114+
return jsiconversion::getJsiValue(std::move(result), runtime);
115+
}
116+
} catch (const std::runtime_error &e) {
117+
// This catch should be merged with the next one
118+
// (std::runtime_error inherits from std::exception) HOWEVER react
119+
// native has broken RTTI which breaks proper exception type
120+
// checking. Remove when the following change is present in our
121+
// version:
122+
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
123+
throw jsi::JSError(runtime, e.what());
124+
} catch (const std::exception &e) {
125+
throw jsi::JSError(runtime, e.what());
126+
} catch (...) {
127+
throw jsi::JSError(runtime, "Unknown error in synchronous function");
128+
}
73129
}
74130

75131
// A generic host function that resolves a promise with a result of a
@@ -101,15 +157,28 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
101157
std::thread([this, promise,
102158
argsConverted = std::move(argsConverted)]() {
103159
try {
104-
auto result = std::apply(std::bind_front(FnPtr, model),
105-
std::move(argsConverted));
106-
// The result is copied. It should either be quickly copiable,
107-
// or passed with a shared_ptr.
108-
callInvoker->invokeAsync([promise,
109-
result](jsi::Runtime &runtime) {
110-
promise->resolve(
111-
jsiconversion::getJsiValue(std::move(result), runtime));
112-
});
160+
if constexpr (std::is_void_v<decltype(std::apply(
161+
std::bind_front(FnPtr, model),
162+
argsConverted))>) {
163+
// For void functions, just call the function and resolve with
164+
// undefined
165+
std::apply(std::bind_front(FnPtr, model),
166+
std::move(argsConverted));
167+
callInvoker->invokeAsync([promise](jsi::Runtime &runtime) {
168+
promise->resolve(jsi::Value::undefined());
169+
});
170+
} else {
171+
// For non-void functions, capture the result and convert it
172+
auto result = std::apply(std::bind_front(FnPtr, model),
173+
std::move(argsConverted));
174+
// The result is copied. It should either be quickly copiable,
175+
// or passed with a shared_ptr.
176+
callInvoker->invokeAsync([promise,
177+
result](jsi::Runtime &runtime) {
178+
promise->resolve(
179+
jsiconversion::getJsiValue(std::move(result), runtime));
180+
});
181+
}
113182
} catch (const std::runtime_error &e) {
114183
// This catch should be merged with the next two
115184
// (std::runtime_error and jsi::JSError inherits from
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include "LLM.h"

#include <executorch/extension/tensor/tensor.h>
#include <filesystem>

namespace rnexecutorch {
using namespace facebook;
using executorch::runtime::Error;

/// Constructs the LLM wrapper: creates the underlying runner, eagerly loads
/// the model, and records a lower bound of the memory footprint (the combined
/// on-disk size of the model and tokenizer files).
/// @throws std::runtime_error when the runner fails to load.
/// @throws std::filesystem::filesystem_error when either file is inaccessible.
LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
         std::shared_ptr<react::CallInvoker> callInvoker)
    : runner(std::make_unique<example::Runner>(modelSource, tokenizerSource)),
      // Move instead of copy: avoids an unnecessary atomic refcount bump.
      callInvoker(std::move(callInvoker)) {
  const auto loadResult = runner->load();
  if (loadResult != Error::Ok) {
    throw std::runtime_error("Failed to load LLM runner, error code: " +
                             std::to_string(static_cast<int>(loadResult)));
  }
  // File sizes underestimate the true runtime footprint (weights are only a
  // part of it), hence "lower bound".
  memorySizeLowerBound =
      std::filesystem::file_size(std::filesystem::path(modelSource)) +
      std::filesystem::file_size(std::filesystem::path(tokenizerSource));
}

/// Runs token generation for the given prompt. Every produced token is
/// forwarded to the JS callback, scheduled onto the JS thread through the
/// CallInvoker. Blocks the calling thread until generation finishes.
/// @throws std::runtime_error if the runner is unloaded or generation fails.
void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
  if (!runner || !runner->is_loaded()) {
    throw std::runtime_error("Runner is not loaded");
  }

  // This native callback may be invoked off the JS thread, so hop to the JS
  // thread via the CallInvoker before touching the jsi::Function. The token
  // is captured by value because the reference dies once the callback
  // returns.
  auto nativeCallback = [this, callback](const std::string &token) {
    callInvoker->invokeAsync([callback, token](jsi::Runtime &runtime) {
      callback->call(runtime, jsi::String::createFromUtf8(runtime, token));
    });
  };

  const auto error = runner->generate(input, nativeCallback, {}, false);
  if (error != Error::Ok) {
    throw std::runtime_error("Failed to generate text, error code: " +
                             std::to_string(static_cast<int>(error)));
  }
}

/// Asks the runner to stop an in-flight generation.
/// @throws std::runtime_error when no model is loaded.
void LLM::interrupt() {
  if (!runner || !runner->is_loaded()) {
    throw std::runtime_error("Can't interrupt a model that's not loaded!");
  }
  runner->stop();
}

/// Returns the lower bound on memory usage computed at construction time.
std::size_t LLM::getMemoryLowerBound() const noexcept {
  return memorySizeLowerBound;
}

/// Destroys the underlying runner, releasing the model and tokenizer memory.
void LLM::unload() noexcept { runner.reset(); }

} // namespace rnexecutorch
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once

#include <cstddef>
#include <memory>
#include <string>

#include <ReactCommon/CallInvoker.h>
#include <jsi/jsi.h>
#include <runner/runner.h>

namespace rnexecutorch {
using namespace facebook;

/// JSI-facing wrapper around the ExecuTorch llama runner. Owns the runner
/// instance and the React Native CallInvoker used to deliver generated
/// tokens back to the JS thread.
class LLM {
public:
  /// Creates the runner and eagerly loads the model/tokenizer pair.
  /// @param modelSource path to the exported model file.
  /// @param tokenizerSource path to the tokenizer file.
  /// @param callInvoker invoker used to schedule callbacks on the JS thread.
  /// @throws std::runtime_error when loading fails.
  explicit LLM(const std::string &modelSource,
               const std::string &tokenizerSource,
               std::shared_ptr<react::CallInvoker> callInvoker);

  /// Generates tokens for `input`, passing each one to `callback` on the JS
  /// thread. Throws if the runner is not loaded or generation fails.
  void generate(std::string input, std::shared_ptr<jsi::Function> callback);
  /// Requests the runner to stop an in-flight generation; throws when no
  /// model is loaded.
  void interrupt();
  /// Releases the underlying runner and the memory it holds.
  void unload() noexcept;
  /// Lower bound of this model's memory usage: the combined file sizes of
  /// the model and tokenizer (actual runtime usage is higher).
  std::size_t getMemoryLowerBound() const noexcept;

private:
  // Zero-initialized so the value is well-defined even before the
  // constructor body assigns the real file-size sum.
  std::size_t memorySizeLowerBound{0};
  std::unique_ptr<example::Runner> runner;
  std::shared_ptr<react::CallInvoker> callInvoker;
};
} // namespace rnexecutorch

0 commit comments

Comments
 (0)