Skip to content

Commit dc34ec1

Browse files
authored
feat: port + refactor ExecutorchModule to C++ (#345)
## Description This PR aims to replace the existing ExecuTorch bindings for ones that leverage JSI and resemble the underlying runtime more accurately. Currently WIP. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> #209 #338 - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 960dc3d commit dc34ec1

27 files changed

Lines changed: 451 additions & 129 deletions

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,11 @@ void RnExecutorchInstaller::injectJSIBindings(
3737
*jsiRuntime, "loadObjectDetection",
3838
RnExecutorchInstaller::loadModel<ObjectDetection>(
3939
jsiRuntime, jsCallInvoker, "loadObjectDetection"));
40+
41+
jsiRuntime->global().setProperty(
42+
*jsiRuntime, "loadExecutorchModule",
43+
RnExecutorchInstaller::loadModel<BaseModel>(jsiRuntime, jsCallInvoker,
44+
"loadExecutorchModule"));
4045
}
46+
4147
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <ReactCommon/CallInvoker.h>
88
#include <jsi/jsi.h>
99

10-
#include <rnexecutorch/TypeConstraints.h>
10+
#include <rnexecutorch/TypeConcepts.h>
1111
#include <rnexecutorch/host_objects/JsiConversions.h>
1212
#include <rnexecutorch/host_objects/ModelHostObject.h>
1313

@@ -26,7 +26,7 @@ class RnExecutorchInstaller {
2626
FetchUrlFunc_t fetchDataFromUrl);
2727

2828
private:
29-
template <DerivedFromBaseModel ModelT>
29+
template <DerivedFromOrSameAs<BaseModel> ModelT>
3030
static jsi::Function
3131
loadModel(jsi::Runtime *jsiRuntime,
3232
std::shared_ptr<react::CallInvoker> jsCallInvoker,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include <concepts>
4+
#include <type_traits>
5+
6+
namespace rnexecutorch {
7+
8+
template <typename T, typename Base>
9+
concept DerivedFromOrSameAs = std::is_base_of_v<Base, T>;
10+
11+
template <typename T>
12+
concept HasGenerate = requires(T t) {
13+
{ &T::generate };
14+
};
15+
16+
template <typename T>
17+
concept IsNumeric = std::is_arithmetic_v<T>;
18+
19+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/TypeConstraints.h

Lines changed: 0 additions & 12 deletions
This file was deleted.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
namespace rnexecutorch {
4+
5+
using executorch::aten::ScalarType;
6+
7+
struct JSTensorViewIn {
8+
void *dataPtr;
9+
std::vector<int32_t> sizes;
10+
ScalarType scalarType;
11+
};
12+
} // namespace rnexecutorch
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
#include <executorch/runtime/core/portable_type/scalar_type.h>
4+
#include <memory>
5+
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
6+
#include <vector>
7+
8+
namespace rnexecutorch {
9+
10+
using executorch::runtime::etensor::ScalarType;
11+
12+
struct JSTensorViewOut {
13+
std::shared_ptr<OwningArrayBuffer> dataPtr;
14+
std::vector<int32_t> sizes;
15+
ScalarType scalarType;
16+
17+
JSTensorViewOut(std::vector<int32_t> sizes, ScalarType scalarType,
18+
std::shared_ptr<OwningArrayBuffer> dataPtr)
19+
: sizes(std::move(sizes)), scalarType(scalarType),
20+
dataPtr(std::move(dataPtr)) {}
21+
};
22+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@
44
#include <type_traits>
55
#include <unordered_map>
66

7+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
78
#include <jsi/jsi.h>
9+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
10+
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
11+
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
812

13+
#include <rnexecutorch/TypeConcepts.h>
914
#include <rnexecutorch/models/object_detection/Constants.h>
1015
#include <rnexecutorch/models/object_detection/Utils.h>
1116

@@ -17,9 +22,12 @@ using namespace facebook;
1722

1823
template <typename T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);
1924

20-
template <>
21-
inline double getValue<double>(const jsi::Value &val, jsi::Runtime &runtime) {
22-
return val.asNumber();
25+
template <typename T>
26+
requires IsNumeric<T>
27+
inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) {
28+
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
29+
"Only integral and floating-point types are supported");
30+
return static_cast<T>(val.asNumber());
2331
}
2432

2533
template <>
@@ -33,6 +41,78 @@ inline std::string getValue<std::string>(const jsi::Value &val,
3341
return val.getString(runtime).utf8(runtime);
3442
}
3543

44+
template <>
45+
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
46+
jsi::Runtime &runtime) {
47+
jsi::Object obj = val.asObject(runtime);
48+
JSTensorViewIn tensorView;
49+
50+
int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber();
51+
tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt);
52+
53+
jsi::Value shapeValue = obj.getProperty(runtime, "sizes");
54+
jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime);
55+
size_t numShapeDims = shapeArray.size(runtime);
56+
tensorView.sizes.reserve(numShapeDims);
57+
58+
for (size_t i = 0; i < numShapeDims; ++i) {
59+
int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
60+
tensorView.sizes.push_back(static_cast<int32_t>(dim));
61+
}
62+
63+
// On JS side, TensorPtr objects hold a 'data' property which should be either
64+
// an ArrayBuffer or TypedArray
65+
jsi::Value dataValue = obj.getProperty(runtime, "dataPtr");
66+
jsi::Object dataObj = dataValue.asObject(runtime);
67+
68+
// Check if it's an ArrayBuffer or TypedArray
69+
if (dataObj.isArrayBuffer(runtime)) {
70+
jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime);
71+
tensorView.dataPtr = arrayBuffer.data(runtime);
72+
73+
} else {
74+
// Handle typed arrays (Float32Array, Int32Array, etc.)
75+
const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") &&
76+
dataObj.hasProperty(runtime, "byteOffset") &&
77+
dataObj.hasProperty(runtime, "byteLength") &&
78+
dataObj.hasProperty(runtime, "length");
79+
if (!isValidTypedArray) {
80+
throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray");
81+
}
82+
jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer");
83+
if (!bufferValue.isObject() ||
84+
!bufferValue.asObject(runtime).isArrayBuffer(runtime)) {
85+
throw jsi::JSError(runtime,
86+
"TypedArray buffer property must be an ArrayBuffer");
87+
}
88+
89+
jsi::ArrayBuffer arrayBuffer =
90+
bufferValue.asObject(runtime).getArrayBuffer(runtime);
91+
size_t byteOffset =
92+
getValue<int>(dataObj.getProperty(runtime, "byteOffset"), runtime);
93+
94+
tensorView.dataPtr =
95+
static_cast<uint8_t *>(arrayBuffer.data(runtime)) + byteOffset;
96+
}
97+
return tensorView;
98+
}
99+
100+
template <>
101+
inline std::vector<JSTensorViewIn>
102+
getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
103+
jsi::Runtime &runtime) {
104+
jsi::Array array = val.asObject(runtime).asArray(runtime);
105+
size_t length = array.size(runtime);
106+
std::vector<JSTensorViewIn> result;
107+
result.reserve(length);
108+
109+
for (size_t i = 0; i < length; ++i) {
110+
jsi::Value element = array.getValueAtIndex(runtime, i);
111+
result.push_back(getValue<JSTensorViewIn>(element, runtime));
112+
}
113+
return result;
114+
}
115+
36116
template <>
37117
inline std::vector<std::string>
38118
getValue<std::vector<std::string>>(const jsi::Value &val,
@@ -78,6 +158,51 @@ inline jsi::Value getJsiValue(std::shared_ptr<jsi::Object> valuePtr,
78158
return std::move(*valuePtr);
79159
}
80160

161+
inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
162+
jsi::Runtime &runtime) {
163+
jsi::Array array(runtime, vec.size());
164+
for (size_t i = 0; i < vec.size(); i++) {
165+
array.setValueAtIndex(runtime, i, jsi::Value(static_cast<int>(vec[i])));
166+
}
167+
return jsi::Value(runtime, array);
168+
}
169+
170+
inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) {
171+
return jsi::Value(runtime, val);
172+
}
173+
174+
inline jsi::Value
175+
getJsiValue(const std::vector<std::shared_ptr<OwningArrayBuffer>> &vec,
176+
jsi::Runtime &runtime) {
177+
jsi::Array array(runtime, vec.size());
178+
for (size_t i = 0; i < vec.size(); i++) {
179+
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]);
180+
array.setValueAtIndex(runtime, i, jsi::Value(runtime, arrayBuffer));
181+
}
182+
return jsi::Value(runtime, array);
183+
}
184+
185+
inline jsi::Value
186+
getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
187+
jsi::Runtime &runtime) {
188+
jsi::Array array(runtime, vec.size());
189+
for (size_t i = 0; i < vec.size(); i++) {
190+
jsi::Object tensorObj(runtime);
191+
192+
tensorObj.setProperty(runtime, "sizes",
193+
getJsiValue(vec[i]->sizes, runtime));
194+
195+
tensorObj.setProperty(runtime, "scalarType",
196+
jsi::Value(static_cast<int>(vec[i]->scalarType)));
197+
198+
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr);
199+
tensorObj.setProperty(runtime, "dataPtr", arrayBuffer);
200+
201+
array.setValueAtIndex(runtime, i, tensorObj);
202+
}
203+
return jsi::Value(runtime, array);
204+
}
205+
81206
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
82207
return jsi::String::createFromAscii(runtime, str);
83208
}

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
#include <ReactCommon/CallInvoker.h>
99

1010
#include <rnexecutorch/Log.h>
11-
#include <rnexecutorch/TypeConstraints.h>
11+
#include <rnexecutorch/TypeConcepts.h>
12+
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1213
#include <rnexecutorch/host_objects/JsiConversions.h>
1314
#include <rnexecutorch/jsi/JsiHostObject.h>
1415
#include <rnexecutorch/jsi/Promise.h>
16+
#include <rnexecutorch/models/BaseModel.h>
1517

1618
namespace rnexecutorch {
1719

@@ -20,13 +22,28 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
2022
explicit ModelHostObject(const std::shared_ptr<Model> &model,
2123
std::shared_ptr<react::CallInvoker> callInvoker)
2224
: model(model), callInvoker(callInvoker) {
23-
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
24-
promiseHostFunction<&Model::forward>,
25-
"forward"));
26-
if constexpr (DerivedFromBaseModel<Model>) {
25+
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
2726
addFunctions(
2827
JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
2928
}
29+
30+
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
31+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
32+
promiseHostFunction<&Model::forwardJS>,
33+
"forward"));
34+
}
35+
36+
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
37+
addFunctions(JSI_EXPORT_FUNCTION(
38+
ModelHostObject<Model>, promiseHostFunction<&Model::getInputShape>,
39+
"getInputShape"));
40+
}
41+
42+
if constexpr (HasGenerate<Model>) {
43+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
44+
promiseHostFunction<&Model::generate>,
45+
"generate"));
46+
}
3047
}
3148

3249
// A generic host function that resolves a promise with a result of a

0 commit comments

Comments
 (0)