Skip to content

Commit 26dc3a0

Browse files
committed
feat: port image segmentation to C++
1 parent 7d9abce commit 26dc3a0

File tree

28 files changed

+371
-573
lines changed

28 files changed

+371
-573
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt

Lines changed: 0 additions & 58 deletions
This file was deleted.

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class RnExecutorchPackage : TurboReactPackage() {
2828
OCR(reactContext)
2929
} else if (name == VerticalOCR.NAME) {
3030
VerticalOCR(reactContext)
31-
} else if (name == ImageSegmentation.NAME) {
32-
ImageSegmentation(reactContext)
3331
} else if (name == ETInstaller.NAME) {
3432
ETInstaller(reactContext)
3533
} else if (name == Tokenizer.NAME) {
@@ -119,17 +117,6 @@ class RnExecutorchPackage : TurboReactPackage() {
119117
true,
120118
)
121119

122-
moduleInfos[ImageSegmentation.NAME] =
123-
ReactModuleInfo(
124-
ImageSegmentation.NAME,
125-
ImageSegmentation.NAME,
126-
false, // canOverrideExistingModule
127-
false, // needsEagerInit
128-
true, // hasConstants
129-
false, // isCxxModule
130-
true,
131-
)
132-
133120
moduleInfos[Tokenizer.NAME] =
134121
ReactModuleInfo(
135122
Tokenizer.NAME,

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt

Lines changed: 0 additions & 26 deletions
This file was deleted.

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt

Lines changed: 0 additions & 139 deletions
This file was deleted.
Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,26 @@
11
#include "RnExecutorchInstaller.h"
22

33
#include <rnexecutorch/host_objects/JsiConversions.h>
4-
#include <rnexecutorch/host_objects/ModelHostObject.h>
5-
#include <rnexecutorch/jsi/JsiPromise.h>
6-
#include <rnexecutorch/models/StyleTransfer.h>
4+
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
5+
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
76

87
namespace rnexecutorch {
98

109
FetchUrlFunc_t fetchUrlFunc;
1110

12-
jsi::Function RnExecutorchInstaller::loadStyleTransfer(
13-
jsi::Runtime *jsiRuntime,
14-
const std::shared_ptr<react::CallInvoker> &jsCallInvoker) {
15-
return jsi::Function::createFromHostFunction(
16-
*jsiRuntime, jsi::PropNameID::forAscii(*jsiRuntime, "loadStyleTransfer"),
17-
0,
18-
[jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
19-
const jsi::Value *args, size_t count) -> jsi::Value {
20-
assert(count == 1);
21-
auto source = jsiconversion::getValue<std::string>(args[0], runtime);
22-
23-
auto styleTransferPtr =
24-
std::make_shared<StyleTransfer>(source, &runtime);
25-
auto styleTransferHostObject =
26-
std::make_shared<ModelHostObject<StyleTransfer>>(
27-
styleTransferPtr, &runtime, jsCallInvoker);
28-
29-
return jsi::Object::createFromHostObject(runtime,
30-
styleTransferHostObject);
31-
});
32-
}
33-
3411
void RnExecutorchInstaller::injectJSIBindings(
35-
jsi::Runtime *jsiRuntime,
36-
const std::shared_ptr<react::CallInvoker> &jsCallInvoker,
12+
jsi::Runtime *jsiRuntime, std::shared_ptr<react::CallInvoker> jsCallInvoker,
3713
FetchUrlFunc_t fetchDataFromUrl) {
3814
fetchUrlFunc = fetchDataFromUrl;
3915

4016
jsiRuntime->global().setProperty(
4117
*jsiRuntime, "loadStyleTransfer",
42-
loadStyleTransfer(jsiRuntime, jsCallInvoker));
18+
RnExecutorchInstaller::loadModel<StyleTransfer>(jsiRuntime, jsCallInvoker,
19+
"loadStyleTransfer"));
20+
21+
jsiRuntime->global().setProperty(
22+
*jsiRuntime, "loadImageSegmentation",
23+
RnExecutorchInstaller::loadModel<ImageSegmentation>(
24+
jsiRuntime, jsCallInvoker, "loadImageSegmentation"));
4325
}
4426
} // namespace rnexecutorch

common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,47 @@
77
#include <ReactCommon/CallInvoker.h>
88
#include <jsi/jsi.h>
99

10+
#include <rnexecutorch/host_objects/JsiConversions.h>
11+
#include <rnexecutorch/host_objects/ModelHostObject.h>
12+
1013
namespace rnexecutorch {
1114

1215
using FetchUrlFunc_t = std::function<std::vector<std::byte>(std::string)>;
16+
extern FetchUrlFunc_t fetchUrlFunc;
1317

1418
using namespace facebook;
1519

1620
class RnExecutorchInstaller {
1721
public:
1822
static void
1923
injectJSIBindings(jsi::Runtime *jsiRuntime,
20-
const std::shared_ptr<react::CallInvoker> &jsCallInvoker,
24+
std::shared_ptr<react::CallInvoker> jsCallInvoker,
2125
FetchUrlFunc_t fetchDataFromUrl);
2226

2327
private:
28+
template <typename ModelT>
2429
static jsi::Function
25-
loadStyleTransfer(jsi::Runtime *jsiRuntime,
26-
const std::shared_ptr<react::CallInvoker> &jsCallInvoker);
30+
loadModel(jsi::Runtime *jsiRuntime,
31+
std::shared_ptr<react::CallInvoker> jsCallInvoker,
32+
const std::string &loadFunctionName) {
33+
return jsi::Function::createFromHostFunction(
34+
*jsiRuntime,
35+
jsi::PropNameID::forAscii(*jsiRuntime, loadFunctionName.c_str()), 0,
36+
[jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
37+
const jsi::Value *args, size_t count) -> jsi::Value {
38+
// We expect a single input -- the path to the model binary
39+
assert(count == 1);
40+
auto source = jsiconversion::getValue<std::string>(args[0], runtime);
41+
42+
auto modelImplementationPtr =
43+
std::make_shared<ModelT>(source, &runtime);
44+
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
45+
modelImplementationPtr, jsCallInvoker);
46+
47+
return jsi::Object::createFromHostObject(runtime, modelHostObject);
48+
return jsi::Value();
49+
});
50+
}
2751
};
2852

2953
} // namespace rnexecutorch
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "Numerical.h"
2+
3+
#include <algorithm>
4+
#include <numeric>
5+
6+
namespace rnexecutorch::numerical {
7+
void softmax(std::vector<float> &v) {
8+
float max = *std::max_element(v.begin(), v.end());
9+
10+
for (float &x : v) {
11+
x = std::exp(x - max);
12+
}
13+
float sum = std::accumulate(v.begin(), v.end(), 0.f);
14+
for (float &x : v) {
15+
x /= sum;
16+
}
17+
}
18+
} // namespace rnexecutorch::numerical
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
namespace rnexecutorch::numerical {
6+
void softmax(std::vector<float> &v);
7+
}

0 commit comments

Comments
 (0)