Skip to content

Commit 9fa5c9f

Browse files
committed
wip
1 parent 3df12b7 commit 9fa5c9f

12 files changed

Lines changed: 157 additions & 11 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "RnExecutorchInstaller.h"
22

3+
#include <rnexecutorch/bindings/ExecutorchModule.h>
34
#include <rnexecutorch/host_objects/JsiConversions.h>
45
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
56
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
@@ -25,5 +26,10 @@ void RnExecutorchInstaller::injectJSIBindings(
2526
*jsiRuntime, "loadImageSegmentation",
2627
RnExecutorchInstaller::loadModel<ImageSegmentation>(
2728
jsiRuntime, jsCallInvoker, "loadImageSegmentation"));
29+
30+
jsiRuntime->global().setProperty(
31+
*jsiRuntime, "loadExecutorchModule",
32+
RnExecutorchInstaller::loadModel<ExecutorchModule>(
33+
jsiRuntime, jsCallInvoker, "loadExecutorchModule"));
2834
}
2935
} // namespace rnexecutorch
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "ExecutorchModule.h"
2+
3+
#include <rnexecutorch/Log.h>
4+
#include <sstream>
5+
6+
namespace rnexecutorch {
7+
8+
using ::executorch::extension::Module;
9+
using ::executorch::runtime::Error;
10+
using namespace facebook;
11+
12+
ExecutorchModule::ExecutorchModule(
13+
const std::string &modelSource,
14+
std::shared_ptr<react::CallInvoker> callInvoker)
15+
: module(std::make_unique<Module>(
16+
modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)),
17+
callInvoker(callInvoker) {
18+
Error loadError = module->load();
19+
if (loadError != Error::Ok) {
20+
throw std::runtime_error("Couldn't load the model, error: " +
21+
std::to_string(static_cast<uint32_t>(loadError)));
22+
}
23+
}
24+
25+
std::vector<int32_t> ExecutorchModule::getInputShape(std::string method_name,
26+
int index) {
27+
auto method_meta = module->method_meta(method_name);
28+
if (!method_meta.ok()) {
29+
throw std::runtime_error("Failed to load method with name " + method_name);
30+
}
31+
32+
std::vector<int32_t> input_shape;
33+
auto input_meta = method_meta->input_tensor_meta(index);
34+
if (!input_meta.ok()) {
35+
throw std::runtime_error("Failed to load forward input " +
36+
std::to_string(index));
37+
}
38+
39+
for (auto size : input_meta->sizes()) {
40+
input_shape.push_back(size);
41+
}
42+
return input_shape;
43+
}
44+
} // namespace rnexecutorch
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
#include <ReactCommon/CallInvoker.h>
6+
#include <executorch/extension/module/module.h>
7+
#include <jsi/jsi.h>
8+
9+
namespace rnexecutorch {
10+
11+
class ExecutorchModule {
12+
public:
13+
ExecutorchModule(const std::string &modelSource,
14+
std::shared_ptr<facebook::react::CallInvoker> callInvoker);
15+
std::vector<int32_t> getInputShape(std::string method_name, int index);
16+
17+
protected:
18+
std::unique_ptr<executorch::extension::Module> module;
19+
std::shared_ptr<facebook::react::CallInvoker> callInvoker;
20+
};
21+
22+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ inline double getValue<double>(const jsi::Value &val, jsi::Runtime &runtime) {
1818
return val.asNumber();
1919
}
2020

21+
template <>
22+
inline int getValue<int>(const jsi::Value &val, jsi::Runtime &runtime) {
23+
return val.asNumber();
24+
}
25+
2126
template <>
2227
inline bool getValue<bool>(const jsi::Value &val, jsi::Runtime &runtime) {
2328
return val.asBool();
@@ -74,6 +79,15 @@ inline jsi::Value getJsiValue(std::shared_ptr<jsi::Object> valuePtr,
7479
return std::move(*valuePtr);
7580
}
7681

82+
inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
83+
jsi::Runtime &runtime) {
84+
jsi::Array array(runtime, vec.size());
85+
for (size_t i = 0; i < vec.size(); i++) {
86+
array.setValueAtIndex(runtime, i, jsi::Value(static_cast<int>(vec[i])));
87+
}
88+
return jsi::Value(runtime, array);
89+
}
90+
7791
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
7892
return jsi::String::createFromAscii(runtime, str);
7993
}

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <rnexecutorch/host_objects/JsiConversions.h>
1212
#include <rnexecutorch/jsi/JsiHostObject.h>
1313
#include <rnexecutorch/jsi/Promise.h>
14+
#include <rnexecutorch/utils/TypeConstraints.h>
1415

1516
namespace rnexecutorch {
1617

@@ -19,9 +20,16 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
1920
explicit ModelHostObject(const std::shared_ptr<Model> &model,
2021
std::shared_ptr<react::CallInvoker> callInvoker)
2122
: model(model), callInvoker(callInvoker) {
22-
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
23-
promiseHostFunction<&Model::forward>,
24-
"forward"));
23+
if constexpr (HasForward<Model>) {
24+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
25+
promiseHostFunction<&Model::forward>,
26+
"forward"));
27+
}
28+
if constexpr (HasGetInputShape<Model>) {
29+
addFunctions(JSI_EXPORT_FUNCTION(
30+
ModelHostObject<Model>, promiseHostFunction<&Model::getInputShape>,
31+
"getInputShape"));
32+
}
2533
}
2634

2735
// A generic host function that resolves a promise with a result of a

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ BaseModel::BaseModel(const std::string &modelSource,
2020
}
2121
}
2222

23-
std::vector<std::vector<int32_t>> BaseModel::getInputShape() {
23+
std::vector<std::vector<int32_t>> BaseModel::getAllInputShapes() {
2424
auto method_meta = module->method_meta("forward");
2525

2626
if (!method_meta.ok()) {

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class BaseModel {
1313
public:
1414
BaseModel(const std::string &modelSource,
1515
std::shared_ptr<react::CallInvoker> callInvoker);
16-
std::vector<std::vector<int32_t>> getInputShape();
16+
std::vector<std::vector<int32_t>> getAllInputShapes();
1717

1818
protected:
1919
std::unique_ptr<executorch::extension::Module> module;

packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ ImageSegmentation::ImageSegmentation(
1515
const std::string &modelSource,
1616
std::shared_ptr<react::CallInvoker> callInvoker)
1717
: BaseModel(modelSource, callInvoker) {
18-
auto inputTensors = getInputShape();
18+
auto inputTensors = getAllInputShapes();
1919
if (inputTensors.size() == 0) {
2020
throw std::runtime_error("Model seems to not take any input tensors.");
2121
}
@@ -58,9 +58,9 @@ ImageSegmentation::preprocess(const std::string &imageSource) {
5858
cv::resize(input, input, modelImageSize);
5959

6060
std::vector<float> inputVector = imageprocessing::colorMatToVector(input);
61-
return {
62-
executorch::extension::make_tensor_ptr(getInputShape()[0], inputVector),
63-
inputSize};
61+
return {executorch::extension::make_tensor_ptr(getAllInputShapes()[0],
62+
inputVector),
63+
inputSize};
6464
}
6565

6666
std::shared_ptr<jsi::Object> ImageSegmentation::postprocess(

packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using executorch::runtime::Error;
1717
StyleTransfer::StyleTransfer(const std::string &modelSource,
1818
std::shared_ptr<react::CallInvoker> callInvoker)
1919
: BaseModel(modelSource, callInvoker) {
20-
auto inputTensors = getInputShape();
20+
auto inputTensors = getAllInputShapes();
2121
if (inputTensors.size() == 0) {
2222
throw std::runtime_error("Model seems to not take any input tensors.");
2323
}
@@ -40,7 +40,7 @@ StyleTransfer::preprocess(const std::string &imageSource) {
4040
auto originalSize = image.size();
4141
cv::resize(image, image, modelImageSize);
4242

43-
return {imageprocessing::getTensorFromMatrix(getInputShape()[0], image),
43+
return {imageprocessing::getTensorFromMatrix(getAllInputShapes()[0], image),
4444
originalSize};
4545
}
4646

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include <concepts>
4+
5+
namespace rnexecutorch {
6+
template <typename T>
7+
concept HasForward = requires(T t) {
8+
{ &T::forward };
9+
};
10+
11+
template <typename T>
12+
concept HasMethodNames = requires(T t) {
13+
{ &T::methodNames };
14+
};
15+
16+
template <typename T>
17+
concept HasGetInputShape = requires(T t) {
18+
{ &T::getInputShape };
19+
};
20+
21+
template <typename T>
22+
concept HasIsLoaded = requires(T t) {
23+
{ &T::isLoaded };
24+
};
25+
} // namespace rnexecutorch

0 commit comments

Comments
 (0)