Skip to content
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e299c3f
chore: add runner to includes
chmjkb Jun 24, 2025
45938f3
wip: add an example header
chmjkb Jun 24, 2025
5cda73f
feat: support void functions in ModelHostObject
chmjkb Jun 26, 2025
0383fcc
feat: adjust installer to work with llms
chmjkb Jun 26, 2025
6a217bf
feat: add JSI conversion for js callbacks
chmjkb Jun 26, 2025
d985c25
feat: add llm to modelhostobject
chmjkb Jun 26, 2025
189440a
feat: add LLM runner
chmjkb Jun 26, 2025
6e6703d
feat: adjust controller to match the new native impl
chmjkb Jun 26, 2025
6fdd91b
fix: check if native llm is installed
chmjkb Jun 26, 2025
bc83f01
remove a bunch of code 💅🏻
chmjkb Jun 26, 2025
158265f
remove a bunch of code 💅🏻
chmjkb Jun 26, 2025
b00c5f0
chore: move runner to common/
chmjkb Jun 26, 2025
6fdf271
chore: update runner.h
chmjkb Jun 26, 2025
9264242
chore: update xcframework
chmjkb Jun 26, 2025
3740b5b
chore: get rid of tokenizers_c.h
chmjkb Jun 26, 2025
2c5bd57
chore: update executorchlib xcodeproj
chmjkb Jun 26, 2025
23d61ff
chore: rename runner.{h,cpp} to LLM.{h,cpp}
chmjkb Jun 27, 2025
ec40ff8
fix: define moduleInfos
chmjkb Jun 27, 2025
4f2810e
fix: fix includes after renaming runner
chmjkb Jun 27, 2025
64785cc
chore: move executorch ios libs, add tokenizers-cpp static libs
chmjkb Jun 27, 2025
c3a7d17
fix: update podspec to match the new ios libs structure
chmjkb Jun 27, 2025
fcda895
wip: android cmake
chmjkb Jun 27, 2025
50b19cc
chore: unify memory lower bound member naming with BaseModel.cpp
chmjkb Jun 27, 2025
d80e855
chore: remove Android static-libs for tokenizers-cpp as they are not …
chmjkb Jun 27, 2025
5117e65
chore: remove --force_load flag from tokenizers-cpp static libs
chmjkb Jun 27, 2025
c3b1a84
fix: Ensure correctness of podspec libs path, add typing to llm contro…
chmjkb Jun 27, 2025
afb1912
chore: remove outdated comment
chmjkb Jun 27, 2025
2cf6c6a
feat: add generic synchronous host function wrapper
chmjkb Jun 27, 2025
2acd171
chore: remove accidental export of getMemoryLowerBound
chmjkb Jun 27, 2025
ffe6387
fix: handle void returns in synchronous host function wrapper
chmjkb Jun 27, 2025
3826a29
chore: remove log includes
chmjkb Jun 30, 2025
cf72d6a
chore: make functions noexcept & const
chmjkb Jun 30, 2025
ec82b0e
logging: improve error messages
chmjkb Jun 30, 2025
f23587b
chore: remove unused include
chmjkb Jun 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions packages/react-native-executorch/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,28 @@ set(COMMON_CPP_DIR "${CMAKE_SOURCE_DIR}/../common")
set(LIBS_DIR "${CMAKE_SOURCE_DIR}/../third-party/android/libs")
set(INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../third-party/include")

# FIXME: Below you can see unsuccessful attempts at linking tokenizers-cpp
# directly into react-native-executorch instead of it being linked against ExecuTorch
# and then transitively to our library. Please go back to this when we bump ET runtime to the next version.
# The problem with directly linking tokenizers-cpp using a submodule is that we get unresolved symbols for
# some android logging libraries, which are referenced by sentencepiece.

# set(TOKENIZERS_CPP_DIR "${CMAKE_SOURCE_DIR}/../../../third-party/tokenizers-cpp")
# add_subdirectory("${TOKENIZERS_CPP_DIR}" tokenizers-cpp)

# # Link Android log library to sentencepiece targets
# if(TARGET sentencepiece-static)
# target_link_libraries(sentencepiece-static INTERFACE log)
# endif()
# if(TARGET sentencepiece_train-static)
# target_link_libraries(sentencepiece_train-static INTERFACE log)
# endif()

# # Link log library to sentencepiece executables
# foreach(exe spm_encode spm_decode spm_normalize spm_train spm_export_vocab)
# if(TARGET ${exe})
# target_link_libraries(${exe} log)
# endif()
# endforeach()

add_subdirectory("${ANDROID_CPP_DIR}")

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ class RnExecutorchPackage : TurboReactPackage() {
name: String,
reactContext: ReactApplicationContext,
): NativeModule? =
if (name == LLM.NAME) {
LLM(reactContext)
} else if (name == SpeechToText.NAME) {
if (name == SpeechToText.NAME) {
SpeechToText(reactContext)
} else if (name == OCR.NAME) {
OCR(reactContext)
Expand All @@ -31,16 +29,6 @@ class RnExecutorchPackage : TurboReactPackage() {
override fun getReactModuleInfoProvider(): ReactModuleInfoProvider =
ReactModuleInfoProvider {
val moduleInfos: MutableMap<String, ReactModuleInfo> = HashMap()
moduleInfos[LLM.NAME] =
ReactModuleInfo(
LLM.NAME,
LLM.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
true, // hasConstants
false, // isCxxModule
true,
)
moduleInfos[SpeechToText.NAME] =
ReactModuleInfo(
SpeechToText.NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <rnexecutorch/models/classification/Classification.h>
#include <rnexecutorch/models/image_embeddings/ImageEmbeddings.h>
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
#include <rnexecutorch/models/llm/LLM.h>
#include <rnexecutorch/models/object_detection/ObjectDetection.h>
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
#include <rnexecutorch/models/text_embeddings/TextEmbeddings.h>
Expand Down Expand Up @@ -55,10 +56,15 @@ void RnExecutorchInstaller::injectJSIBindings(
*jsiRuntime, "loadImageEmbeddings",
RnExecutorchInstaller::loadModel<ImageEmbeddings>(
jsiRuntime, jsCallInvoker, "loadImageEmbeddings"));

jsiRuntime->global().setProperty(
*jsiRuntime, "loadTextEmbeddings",
RnExecutorchInstaller::loadModel<TextEmbeddings>(
jsiRuntime, jsCallInvoker, "loadTextEmbeddings"));

jsiRuntime->global().setProperty(*jsiRuntime, "loadLLM",
RnExecutorchInstaller::loadModel<LLM>(
jsiRuntime, jsCallInvoker, "loadLLM"));
}

} // namespace rnexecutorch
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ REGISTER_CONSTRUCTOR(BaseModel, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(TokenizerModule, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(ImageEmbeddings, std::string, std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(ImageEmbeddings, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(TextEmbeddings, std::string, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(LLM, std::string, std::string,
std::shared_ptr<react::CallInvoker>);

using namespace facebook;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include <ada/ada.h>

#include <rnexecutorch/Log.h>
#include <rnexecutorch/RnExecutorchInstaller.h>
#include <rnexecutorch/data_processing/FileUtils.h>
#include <rnexecutorch/data_processing/base64.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ inline std::string getValue<std::string>(const jsi::Value &val,
return val.getString(runtime).utf8(runtime);
}

template <>
inline std::shared_ptr<jsi::Function>
getValue<std::shared_ptr<jsi::Function>>(const jsi::Value &val,
                                         jsi::Runtime &runtime) {
  // jsi::Function is move-only; wrap it in a shared_ptr so the JS callback
  // can be stored and invoked after this conversion call returns.
  auto jsFunction = val.asObject(runtime).asFunction(runtime);
  return std::make_shared<jsi::Function>(std::move(jsFunction));
}

template <>
inline std::vector<int32_t>
getValue<std::vector<int32_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cstdio>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include <ReactCommon/CallInvoker.h>
Expand All @@ -15,6 +16,7 @@
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
#include <rnexecutorch/models/BaseModel.h>
#include <rnexecutorch/models/llm/LLM.h>

namespace rnexecutorch {

Expand Down Expand Up @@ -70,6 +72,60 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
promiseHostFunction<&Model::tokenToId>,
"tokenToId"));
}

if constexpr (meta::SameAs<Model, LLM>) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
promiseHostFunction<&Model::generate>,
"generate"));

addFunctions(JSI_EXPORT_FUNCTION(
ModelHostObject<Model>, synchronousHostFunction<&Model::interrupt>,
"interrupt"));

addFunctions(
JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
}
}

// A generic host function that runs synchronously on the calling (JS) thread.
// Works analogously to the generic promise host function, but returns the
// converted result directly instead of resolving a promise.
//
// FnPtr is a pointer to a member function of Model; the incoming JSI
// arguments are converted to that function's parameter types before the call.
// Note: `runtime`, `args`, and `count` are introduced by the
// JSI_HOST_FUNCTION macro signature.
template <auto FnPtr> JSI_HOST_FUNCTION(synchronousHostFunction) {
constexpr std::size_t functionArgCount = meta::getArgumentCount(FnPtr);
if (functionArgCount != count) {
char errorMessage[100];
std::snprintf(errorMessage, sizeof(errorMessage),
"Argument count mismatch, was expecting: %zu but got: %zu",
functionArgCount, count);
throw jsi::JSError(runtime, errorMessage);
}

try {
auto argsConverted = meta::createArgsTupleFromJsi(FnPtr, args, runtime);

// decltype(std::apply(...)) is the wrapped member function's return type;
// branch at compile time on whether it is void.
if constexpr (std::is_void_v<decltype(std::apply(
std::bind_front(FnPtr, model), argsConverted))>) {
// For void functions, just call the function and return undefined
std::apply(std::bind_front(FnPtr, model), std::move(argsConverted));
return jsi::Value::undefined();
} else {
// For non-void functions, capture the result and convert it
auto result =
std::apply(std::bind_front(FnPtr, model), std::move(argsConverted));
return jsiconversion::getJsiValue(std::move(result), runtime);
}
} catch (const std::runtime_error &e) {
// This catch should be merged with the next one
// (std::runtime_error inherits from std::exception) HOWEVER react
// native has broken RTTI which breaks proper exception type
// checking. Remove when the following change is present in our
// version:
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
throw jsi::JSError(runtime, e.what());
} catch (const std::exception &e) {
throw jsi::JSError(runtime, e.what());
} catch (...) {
// Anything that is not a std::exception (or slips past broken RTTI)
// still surfaces to JS as a JSError rather than terminating.
throw jsi::JSError(runtime, "Unknown error in synchronous function");
}
}

// A generic host function that resolves a promise with a result of a
Expand Down Expand Up @@ -101,15 +157,28 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
std::thread([this, promise,
argsConverted = std::move(argsConverted)]() {
try {
auto result = std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
// The result is copied. It should either be quickly copiable,
// or passed with a shared_ptr.
callInvoker->invokeAsync([promise,
result](jsi::Runtime &runtime) {
promise->resolve(
jsiconversion::getJsiValue(std::move(result), runtime));
});
if constexpr (std::is_void_v<decltype(std::apply(
std::bind_front(FnPtr, model),
argsConverted))>) {
// For void functions, just call the function and resolve with
// undefined
std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
callInvoker->invokeAsync([promise](jsi::Runtime &runtime) {
promise->resolve(jsi::Value::undefined());
});
} else {
// For non-void functions, capture the result and convert it
auto result = std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
// The result is copied. It should either be quickly copiable,
// or passed with a shared_ptr.
callInvoker->invokeAsync([promise,
result](jsi::Runtime &runtime) {
promise->resolve(
jsiconversion::getJsiValue(std::move(result), runtime));
});
}
} catch (const std::runtime_error &e) {
// This catch should be merged with the next two
// (std::runtime_error and jsi::JSError inherits from
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "LLM.h"

#include <executorch/extension/tensor/tensor.h>
#include <filesystem>

namespace rnexecutorch {
using namespace facebook;
using executorch::extension::TensorPtr;
using executorch::runtime::Error;

/// Constructs the LLM wrapper: creates the runner, loads the model and
/// tokenizer eagerly, and records a lower bound on memory usage.
/// @param modelSource path to the model file on disk.
/// @param tokenizerSource path to the tokenizer file on disk.
/// @param callInvoker used to schedule JS-thread work for token callbacks.
/// @throws std::runtime_error if the runner fails to load.
LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
         std::shared_ptr<react::CallInvoker> callInvoker)
    : runner(std::make_unique<example::Runner>(modelSource, tokenizerSource)),
      // Move instead of copy: avoids a needless atomic refcount bump.
      callInvoker(std::move(callInvoker)) {
  auto loadResult = runner->load();
  if (loadResult != Error::Ok) {
    throw std::runtime_error("Failed to load LLM runner, error code: " +
                             std::to_string(static_cast<int>(loadResult)));
  }
  // Approximate the memory footprint by the on-disk sizes of the model and
  // tokenizer. Both files were just loaded successfully, so file_size is not
  // expected to throw here.
  memorySizeLowerBound =
      std::filesystem::file_size(std::filesystem::path(modelSource)) +
      std::filesystem::file_size(std::filesystem::path(tokenizerSource));
}

/// Generates text from `input`, streaming each produced token to `callback`.
/// The JS callback is always invoked on the JS thread via the CallInvoker.
/// @throws std::runtime_error if the runner is not loaded or generation fails.
void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
  if (!runner || !runner->is_loaded()) {
    throw std::runtime_error("Runner is not loaded");
  }

  // Create a native callback that will invoke the JS callback on the JS
  // thread. Move the shared_ptr into the closure (avoids an extra atomic
  // refcount bump); the jsi::Function stays alive until every queued
  // invocation has run.
  auto nativeCallback = [this,
                         callback = std::move(callback)](const std::string &token) {
    callInvoker->invokeAsync([callback, token](jsi::Runtime &runtime) {
      callback->call(runtime, jsi::String::createFromUtf8(runtime, token));
    });
  };

  auto error = runner->generate(input, nativeCallback, {}, false);
  if (error != executorch::runtime::Error::Ok) {
    throw std::runtime_error("Failed to generate text, error code: " +
                             std::to_string(static_cast<int>(error)));
  }
}

// Requests the runner to stop an in-flight generation.
// Throws if there is no loaded model to interrupt.
void LLM::interrupt() {
  const bool isLoaded = runner != nullptr && runner->is_loaded();
  if (!isLoaded) {
    throw std::runtime_error("Can't interrupt a model that's not loaded!");
  }
  runner->stop();
}

// Returns the conservative memory estimate (model + tokenizer file sizes)
// computed at construction time.
std::size_t LLM::getMemoryLowerBound() const noexcept {
  return memorySizeLowerBound;
}

void LLM::unload() noexcept { runner.reset(nullptr); }

} // namespace rnexecutorch
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <memory>
#include <string>

#include <ReactCommon/CallInvoker.h>
#include <jsi/jsi.h>
#include <runner/runner.h>

namespace rnexecutorch {
using namespace facebook;

/// JSI-facing wrapper around the ExecuTorch LLM example runner.
///
/// Loads the model and tokenizer eagerly in the constructor (throwing
/// std::runtime_error on failure) and streams generated tokens back to
/// JavaScript through the provided CallInvoker.
class LLM {
public:
  explicit LLM(const std::string &modelSource,
               const std::string &tokenizerSource,
               std::shared_ptr<react::CallInvoker> callInvoker);

  /// Runs generation on `input`, invoking `callback` once per emitted token.
  void generate(std::string input, std::shared_ptr<jsi::Function> callback);
  /// Requests the runner to stop an in-flight generation.
  void interrupt();
  /// Destroys the runner; generate/interrupt must not be called afterwards.
  void unload() noexcept;
  /// Lower bound on memory use, derived from model + tokenizer file sizes.
  std::size_t getMemoryLowerBound() const noexcept;

private:
  // Brace-initialized so the value is never indeterminate (the constructor
  // assigns the real estimate after the runner loads).
  std::size_t memorySizeLowerBound{0};
  std::unique_ptr<example::Runner> runner;
  std::shared_ptr<react::CallInvoker> callInvoker;
};
} // namespace rnexecutorch
Loading
Loading