Skip to content

Commit a1e060c

Browse files
committed
wip
1 parent 029547b commit a1e060c

13 files changed

Lines changed: 188 additions & 80 deletions

File tree

common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "RnExecutorchInstaller.h"
22

3+
#include <rnexecutorch/bindings/ExecutorchModule.h>
34
#include <rnexecutorch/host_objects/JsiConversions.h>
45
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
56

@@ -19,5 +20,10 @@ void RnExecutorchInstaller::injectJSIBindings(
1920
*jsiRuntime, "loadStyleTransfer",
2021
RnExecutorchInstaller::loadModel<StyleTransfer>(jsiRuntime, jsCallInvoker,
2122
"loadStyleTransfer"));
23+
24+
jsiRuntime->global().setProperty(
25+
*jsiRuntime, "loadExecutorchModule",
26+
RnExecutorchInstaller::loadModel<ExecutorchModule>(
27+
jsiRuntime, jsCallInvoker, "loadExecutorchModule"));
2228
}
2329
} // namespace rnexecutorch

common/rnexecutorch/bindings/ExecutorchModule.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <fmt/core.h>
44
#include <rnexecutorch/Log.h>
5+
#include <unordered_set>
56

67
namespace rnexecutorch {
78

@@ -20,24 +21,28 @@ ExecutorchModule::ExecutorchModule(const std::string &modelSource,
2021
}
2122
}
2223

24+
std::unordered_set<std::string> ExecutorchModule::methodNames() {
25+
auto result = module->method_names();
26+
if (!result.ok()) {
27+
throw std::runtime_error("Failed to get method_names!");
28+
}
29+
return result.get();
30+
}
31+
32+
bool ExecutorchModule::isLoaded() { return module->is_loaded(); }
33+
2334
std::vector<int32_t> ExecutorchModule::getInputShape(std::string method_name,
2435
int index) {
2536
auto method_meta = module->method_meta(method_name);
2637
if (!method_meta.ok()) {
27-
throw std::runtime_error(
28-
fmt::format("Failed to load method with name {}", method_name));
38+
throw std::runtime_error("Failed to load method");
2939
}
3040

31-
std::vector<int32_t> input_shape;
3241
auto input_meta = method_meta->input_tensor_meta(index);
3342
if (!input_meta.ok()) {
34-
throw std::runtime_error(
35-
fmt::format("Failed to load forward input {}", index));
36-
}
37-
38-
for (auto size : input_meta->sizes()) {
39-
input_shape.push_back(size);
43+
throw std::runtime_error("Failed to load input for given method");
4044
}
41-
return input_shape;
45+
auto shape = input_meta->sizes();
46+
return std::vector<int32_t>(shape.begin(), shape.end());
4247
}
4348
} // namespace rnexecutorch

common/rnexecutorch/bindings/ExecutorchModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ class ExecutorchModule {
1313
ExecutorchModule(const std::string &modelSource,
1414
facebook::jsi::Runtime *runtime);
1515
std::vector<int32_t> getInputShape(std::string method_name, int index);
16+
std::unordered_set<std::string> methodNames();
17+
bool isLoaded();
1618

1719
protected:
1820
std::unique_ptr<executorch::extension::Module> module;

common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <jsi/jsi.h>
44
#include <type_traits>
5+
#include <unordered_set>
56

67
namespace rnexecutorch::jsiconversion {
78

@@ -11,6 +12,11 @@ using namespace facebook;
1112

1213
template <typename T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);
1314

15+
template <>
16+
inline int getValue<int>(const jsi::Value &val, jsi::Runtime &runtime) {
17+
return val.asNumber();
18+
}
19+
1420
template <>
1521
inline double getValue<double>(const jsi::Value &val, jsi::Runtime &runtime) {
1622
return val.asNumber();
@@ -58,6 +64,26 @@ inline jsi::Value getJsiValue(jsi::Object &&value, jsi::Runtime &runtime) {
5864
return jsi::Value(std::move(value));
5965
}
6066

67+
inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
68+
jsi::Runtime &runtime) {
69+
jsi::Array array(runtime, vec.size());
70+
for (size_t i = 0; i < vec.size(); ++i) {
71+
array.setValueAtIndex(runtime, i, jsi::Value(static_cast<int>(vec[i])));
72+
}
73+
return jsi::Value(runtime, array);
74+
}
75+
76+
inline jsi::Value getJsiValue(const std::unordered_set<std::string> &uset,
77+
jsi::Runtime &runtime) {
78+
jsi::Array array(runtime, uset.size());
79+
size_t idx = 0;
80+
for (const auto &str : uset) {
81+
array.setValueAtIndex(runtime, idx++,
82+
jsi::String::createFromAscii(runtime, str));
83+
}
84+
return jsi::Value(runtime, array);
85+
}
86+
6187
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
6288
return jsi::String::createFromAscii(runtime, str);
6389
}

common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <rnexecutorch/host_objects/JsiConversions.h>
1212
#include <rnexecutorch/jsi/JsiHostObject.h>
1313
#include <rnexecutorch/jsi/Promise.h>
14+
#include <rnexecutorch/utils/TypeConstraints.h>
1415

1516
namespace rnexecutorch {
1617

@@ -19,9 +20,28 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
1920
explicit ModelHostObject(const std::shared_ptr<Model> &model,
2021
std::shared_ptr<react::CallInvoker> callInvoker)
2122
: model(model), callInvoker(callInvoker) {
22-
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
23-
promiseHostFunction<&Model::forward>,
24-
"forward"));
23+
if constexpr (HasForward<Model>)
24+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
25+
promiseHostFunction<&Model::forward>,
26+
"forward"));
27+
28+
if constexpr (HasGetInputShape<Model>) {
29+
addFunctions(JSI_EXPORT_FUNCTION(
30+
ModelHostObject<Model>, promiseHostFunction<&Model::getInputShape>,
31+
"getInputShape"));
32+
}
33+
34+
if constexpr (HasMethodNames<Model>) {
35+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
36+
promiseHostFunction<&Model::methodNames>,
37+
"methodNames"));
38+
}
39+
40+
if constexpr (HasIsLoaded<Model>) {
41+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
42+
promiseHostFunction<&Model::isLoaded>,
43+
"isLoaded"));
44+
}
2545
}
2646

2747
// A generic host function that resolves a promise with a result of a

common/rnexecutorch/models/BaseModel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class BaseModel {
1111
public:
1212
BaseModel(const std::string &modelSource, facebook::jsi::Runtime *runtime);
1313
std::vector<std::vector<int32_t>> getInputShape();
14+
std::vector<int32_t> getInputShape(std::string method_name, int index);
1415

1516
protected:
1617
std::unique_ptr<executorch::extension::Module> module;
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include <concepts>
4+
5+
namespace rnexecutorch {
6+
template <typename T>
7+
concept HasForward = requires(T t) {
8+
{ &T::forward };
9+
};
10+
11+
template <typename T>
12+
concept HasMethodNames = requires(T t) {
13+
{ &T::methodNames };
14+
};
15+
16+
template <typename T>
17+
concept HasGetInputShape = requires(T t) {
18+
{ &T::getInputShape };
19+
};
20+
21+
template <typename T>
22+
concept HasIsLoaded = requires(T t) {
23+
{ &T::isLoaded };
24+
};
25+
} // namespace rnexecutorch

examples/computer-vision/ios/Podfile.lock

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ PODS:
4343
- hermes-engine (0.76.9):
4444
- hermes-engine/Pre-built (= 0.76.9)
4545
- hermes-engine/Pre-built (0.76.9)
46-
- opencv-rne (0.1.0)
46+
- opencv-rne (4.11.0)
4747
- RCT-Folly (2024.10.14.00):
4848
- boost
4949
- DoubleConversion
@@ -1322,11 +1322,11 @@ PODS:
13221322
- ReactCommon/turbomodule/bridging
13231323
- ReactCommon/turbomodule/core
13241324
- Yoga
1325-
- react-native-executorch (0.3.2):
1325+
- react-native-executorch (0.3.1):
13261326
- DoubleConversion
13271327
- glog
13281328
- hermes-engine
1329-
- opencv-rne (~> 0.1.0)
1329+
- opencv-rne (~> 4.11.0)
13301330
- RCT-Folly (= 2024.10.14.00)
13311331
- RCTRequired
13321332
- RCTTypeSafety
@@ -1343,6 +1343,7 @@ PODS:
13431343
- ReactCodegen
13441344
- ReactCommon/turbomodule/bridging
13451345
- ReactCommon/turbomodule/core
1346+
- sqlite3
13461347
- Yoga
13471348
- react-native-image-picker (7.2.3):
13481349
- DoubleConversion
@@ -1857,6 +1858,9 @@ PODS:
18571858
- ReactCommon/turbomodule/core
18581859
- Yoga
18591860
- SocketRocket (0.7.1)
1861+
- sqlite3 (3.49.2):
1862+
- sqlite3/common (= 3.49.2)
1863+
- sqlite3/common (3.49.2)
18601864
- Yoga (0.0.0)
18611865

18621866
DEPENDENCIES:
@@ -1944,6 +1948,7 @@ SPEC REPOS:
19441948
trunk:
19451949
- opencv-rne
19461950
- SocketRocket
1951+
- sqlite3
19471952

19481953
EXTERNAL SOURCES:
19491954
boost:
@@ -2117,7 +2122,7 @@ SPEC CHECKSUMS:
21172122
fmt: 01b82d4ca6470831d1cc0852a1af644be019e8f6
21182123
glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a
21192124
hermes-engine: 9e868dc7be781364296d6ee2f56d0c1a9ef0bb11
2120-
opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f
2125+
opencv-rne: 2305807573b6e29c8c87e3416ab096d09047a7a0
21212126
RCT-Folly: ea9d9256ba7f9322ef911169a9f696e5857b9e17
21222127
RCTDeprecation: ebe712bb05077934b16c6bf25228bdec34b64f83
21232128
RCTRequired: ca91e5dd26b64f577b528044c962baf171c6b716
@@ -2147,7 +2152,7 @@ SPEC CHECKSUMS:
21472152
React-logger: c4052eb941cca9a097ef01b59543a656dc088559
21482153
React-Mapbuffer: 33546a3ebefbccb8770c33a1f8a5554fa96a54de
21492154
React-microtasksnativemodule: d80ff86c8902872d397d9622f1a97aadcc12cead
2150-
react-native-executorch: 63ab47d8a0c602a4dee0acb36f6e8d9891ad357f
2155+
react-native-executorch: 0375510055856a3854e4128afc5db8126615254f
21512156
react-native-image-picker: dbc35687199a8bf89514e09b6b105557f9f63162
21522157
react-native-safe-area-context: cd916088cac5300c3266876218377518987b995e
21532158
react-native-skia: 9b4e1185bdc0d4e7e6488c5419b5643cc456dd2e
@@ -2181,6 +2186,7 @@ SPEC CHECKSUMS:
21812186
RNReanimated: 2e5069649cbab2c946652d3b97589b2ae0526220
21822187
RNSVG: b889dc9c1948eeea0576a16cc405c91c37a12c19
21832188
SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748
2189+
sqlite3: 3c950dc86011117c307eb0b28c4a7bb449dce9f1
21842190
Yoga: feb4910aba9742cfedc059e2b2902e22ffe9954a
21852191

21862192
PODFILE CHECKSUM: d2d76566c3147849493ab633854730a1f661227b

examples/computer-vision/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"metro-config": "^0.81.0",
1818
"react": "18.3.1",
1919
"react-native": "0.76.9",
20-
"react-native-executorch": "^0.3.2",
20+
"react-native-executorch": "/Users/jakubchmura/Desktop/SWM_AI/react-native-executorch/react-native-executorch-0.3.1-test2404.tgz",
2121
"react-native-image-picker": "^7.2.2",
2222
"react-native-loading-spinner-overlay": "^3.0.1",
2323
"react-native-reanimated": "~3.16.1",

examples/computer-vision/screens/StyleTransferScreen.tsx

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import { getImage } from '../utils';
44
import {
55
useStyleTransfer,
66
STYLE_TRANSFER_CANDY,
7+
NewExecutorchModule,
78
} from 'react-native-executorch';
89
import { View, StyleSheet, Image } from 'react-native';
10+
import { useEffect } from 'react';
911

1012
export const StyleTransferScreen = ({
1113
imageUri,
@@ -18,6 +20,23 @@ export const StyleTransferScreen = ({
1820
modelSource: STYLE_TRANSFER_CANDY,
1921
});
2022

23+
useEffect(() => {
24+
console.log('asdad');
25+
const loadModel = async () => {
26+
try {
27+
const model = new NewExecutorchModule();
28+
await model.load(STYLE_TRANSFER_CANDY);
29+
console.log(await model.getInputShape('forward', 0));
30+
console.log(await model.methodNames());
31+
console.log(await model.isLoaded());
32+
} catch (e) {
33+
console.error('Error loading model:', e);
34+
}
35+
};
36+
37+
loadModel();
38+
}, []);
39+
2140
const handleCameraPress = async (isCamera: boolean) => {
2241
const image = await getImage(isCamera);
2342
const uri = image?.uri;
@@ -30,7 +49,7 @@ export const StyleTransferScreen = ({
3049
if (imageUri) {
3150
try {
3251
const output = await model.forward(imageUri);
33-
setImageUri(output);
52+
setImageUri(output as string);
3453
} catch (e) {
3554
console.error(e);
3655
}

0 commit comments

Comments
 (0)