Skip to content

Commit 24f4d2a

Browse files
feat: port image segmentation native code to C++ (#313)
## Description Port image segmentation native code to C++. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Related issues #256 #255 ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings --------- Co-authored-by: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com>
1 parent 0442aa3 commit 24f4d2a

File tree

33 files changed

+440
-546
lines changed

33 files changed

+440
-546
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt

Lines changed: 0 additions & 58 deletions
This file was deleted.

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class RnExecutorchPackage : TurboReactPackage() {
2828
OCR(reactContext)
2929
} else if (name == VerticalOCR.NAME) {
3030
VerticalOCR(reactContext)
31-
} else if (name == ImageSegmentation.NAME) {
32-
ImageSegmentation(reactContext)
3331
} else if (name == ETInstaller.NAME) {
3432
ETInstaller(reactContext)
3533
} else if (name == Tokenizer.NAME) {
@@ -119,17 +117,6 @@ class RnExecutorchPackage : TurboReactPackage() {
119117
true,
120118
)
121119

122-
moduleInfos[ImageSegmentation.NAME] =
123-
ReactModuleInfo(
124-
ImageSegmentation.NAME,
125-
ImageSegmentation.NAME,
126-
false, // canOverrideExistingModule
127-
false, // needsEagerInit
128-
true, // hasConstants
129-
false, // isCxxModule
130-
true,
131-
)
132-
133120
moduleInfos[Tokenizer.NAME] =
134121
ReactModuleInfo(
135122
Tokenizer.NAME,

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt

Lines changed: 0 additions & 26 deletions
This file was deleted.

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt

Lines changed: 0 additions & 139 deletions
This file was deleted.

common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "RnExecutorchInstaller.h"
22

33
#include <rnexecutorch/host_objects/JsiConversions.h>
4+
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
45
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
56

67
namespace rnexecutorch {
@@ -19,5 +20,10 @@ void RnExecutorchInstaller::injectJSIBindings(
1920
*jsiRuntime, "loadStyleTransfer",
2021
RnExecutorchInstaller::loadModel<StyleTransfer>(jsiRuntime, jsCallInvoker,
2122
"loadStyleTransfer"));
23+
24+
jsiRuntime->global().setProperty(
25+
*jsiRuntime, "loadImageSegmentation",
26+
RnExecutorchInstaller::loadModel<ImageSegmentation>(
27+
jsiRuntime, jsCallInvoker, "loadImageSegmentation"));
2228
}
2329
} // namespace rnexecutorch

common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class RnExecutorchInstaller {
4949
jsiconversion::getValue<std::string>(args[0], runtime);
5050

5151
auto modelImplementationPtr =
52-
std::make_shared<ModelT>(source, &runtime);
52+
std::make_shared<ModelT>(source, jsCallInvoker);
5353
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
5454
modelImplementationPtr, jsCallInvoker);
5555

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include "Numerical.h"
2+
3+
#include <algorithm>
4+
#include <numeric>
5+
6+
namespace rnexecutorch::numerical {
7+
void softmax(std::vector<float> &v) {
8+
float max = *std::max_element(v.begin(), v.end());
9+
10+
float sum = 0.0f;
11+
for (float &x : v) {
12+
x = std::exp(x - max);
13+
sum += x;
14+
}
15+
for (float &x : v) {
16+
x /= sum;
17+
}
18+
}
19+
} // namespace rnexecutorch::numerical
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
namespace rnexecutorch::numerical {
6+
void softmax(std::vector<float> &v);
7+
} // namespace rnexecutorch::numerical

common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#pragma once
22

3-
#include <jsi/jsi.h>
3+
#include <set>
44
#include <type_traits>
55

6+
#include <jsi/jsi.h>
7+
68
namespace rnexecutorch::jsiconversion {
79

810
using namespace facebook;
@@ -43,19 +45,33 @@ getValue<std::vector<std::string>>(const jsi::Value &val,
4345
return result;
4446
}
4547

48+
// C++ set from JS array. Set with heterogenerous look-up (adding std::less<>
49+
// enables querying with std::string_view).
50+
template <>
51+
inline std::set<std::string, std::less<>>
52+
getValue<std::set<std::string, std::less<>>>(const jsi::Value &val,
53+
jsi::Runtime &runtime) {
54+
55+
jsi::Array array = val.asObject(runtime).asArray(runtime);
56+
size_t length = array.size(runtime);
57+
std::set<std::string, std::less<>> result;
58+
59+
for (size_t i = 0; i < length; ++i) {
60+
jsi::Value element = array.getValueAtIndex(runtime, i);
61+
result.insert(getValue<std::string>(element, runtime));
62+
}
63+
return result;
64+
}
65+
4666
// Conversion from C++ types to jsi --------------------------------------------
4767

4868
// Implementation functions might return any type, but in a promise we can only
4969
// return jsi::Value or jsi::Object. For each type being returned
5070
// we add a function here.
5171

52-
// Identity function for the sake of completeness
53-
inline jsi::Value getJsiValue(jsi::Value &&value, jsi::Runtime &runtime) {
54-
return std::move(value);
55-
}
56-
57-
inline jsi::Value getJsiValue(jsi::Object &&value, jsi::Runtime &runtime) {
58-
return jsi::Value(std::move(value));
72+
inline jsi::Value getJsiValue(std::shared_ptr<jsi::Object> valuePtr,
73+
jsi::Runtime &runtime) {
74+
return std::move(*valuePtr);
5975
}
6076

6177
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {

common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
5555
try {
5656
auto result =
5757
std::apply(std::bind_front(FnPtr, model), argsConverted);
58-
59-
callInvoker->invokeAsync([promise, result = std::move(result)](
60-
jsi::Runtime &runtime) {
58+
// The result is copied. It should either be quickly copiable,
59+
// or passed with a shared_ptr.
60+
callInvoker->invokeAsync([promise,
61+
result](jsi::Runtime &runtime) {
6162
promise->resolve(
6263
jsiconversion::getJsiValue(std::move(result), runtime));
6364
});

0 commit comments

Comments
 (0)