Skip to content

Commit 27e10ef

Browse files
committed
feat: port image segmentation to C++
1 parent 02c9d0e commit 27e10ef

File tree

25 files changed

+355
-524
lines changed

25 files changed

+355
-524
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt

Lines changed: 0 additions & 58 deletions
This file was deleted.

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class RnExecutorchPackage : TurboReactPackage() {
2828
OCR(reactContext)
2929
} else if (name == VerticalOCR.NAME) {
3030
VerticalOCR(reactContext)
31-
} else if (name == ImageSegmentation.NAME) {
32-
ImageSegmentation(reactContext)
3331
} else if (name == ETInstaller.NAME) {
3432
ETInstaller(reactContext)
3533
} else if (name == Tokenizer.NAME) {
@@ -119,17 +117,6 @@ class RnExecutorchPackage : TurboReactPackage() {
119117
true,
120118
)
121119

122-
moduleInfos[ImageSegmentation.NAME] =
123-
ReactModuleInfo(
124-
ImageSegmentation.NAME,
125-
ImageSegmentation.NAME,
126-
false, // canOverrideExistingModule
127-
false, // needsEagerInit
128-
true, // hasConstants
129-
false, // isCxxModule
130-
true,
131-
)
132-
133120
moduleInfos[Tokenizer.NAME] =
134121
ReactModuleInfo(
135122
Tokenizer.NAME,

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt

Lines changed: 0 additions & 26 deletions
This file was deleted.

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt

Lines changed: 0 additions & 139 deletions
This file was deleted.

common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "RnExecutorchInstaller.h"
22

33
#include <rnexecutorch/host_objects/JsiConversions.h>
4+
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
45
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
56

67
namespace rnexecutorch {
@@ -19,5 +20,10 @@ void RnExecutorchInstaller::injectJSIBindings(
1920
*jsiRuntime, "loadStyleTransfer",
2021
RnExecutorchInstaller::loadModel<StyleTransfer>(jsiRuntime, jsCallInvoker,
2122
"loadStyleTransfer"));
23+
24+
jsiRuntime->global().setProperty(
25+
*jsiRuntime, "loadImageSegmentation",
26+
RnExecutorchInstaller::loadModel<ImageSegmentation>(
27+
jsiRuntime, jsCallInvoker, "loadImageSegmentation"));
2228
}
2329
} // namespace rnexecutorch
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "Numerical.h"
2+
3+
#include <algorithm>
4+
#include <numeric>
5+
6+
namespace rnexecutorch::numerical {
7+
void softmax(std::vector<float> &v) {
8+
float max = *std::max_element(v.begin(), v.end());
9+
10+
for (float &x : v) {
11+
x = std::exp(x - max);
12+
}
13+
float sum = std::accumulate(v.begin(), v.end(), 0.f);
14+
for (float &x : v) {
15+
x /= sum;
16+
}
17+
}
18+
} // namespace rnexecutorch::numerical
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
namespace rnexecutorch::numerical {
6+
void softmax(std::vector<float> &v);
7+
}

common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#pragma once
22

3-
#include <jsi/jsi.h>
3+
#include <set>
44
#include <type_traits>
55

6+
#include <jsi/jsi.h>
7+
68
namespace rnexecutorch::jsiconversion {
79

810
using namespace facebook;
@@ -43,6 +45,25 @@ getValue<std::vector<std::string>>(const jsi::Value &val,
4345
return result;
4446
}
4547

48+
// Set with heterogenerous look-up (adding std::less<> enables querying
49+
// with std::string_view)
50+
template <>
51+
inline std::set<std::string, std::less<>>
52+
getValue<std::set<std::string, std::less<>>>(const jsi::Value &val,
53+
jsi::Runtime &runtime) {
54+
// C++ set from JS array
55+
56+
jsi::Array array = val.asObject(runtime).asArray(runtime);
57+
size_t length = array.size(runtime);
58+
std::set<std::string, std::less<>> result;
59+
60+
for (size_t i = 0; i < length; ++i) {
61+
jsi::Value element = array.getValueAtIndex(runtime, i);
62+
result.insert(getValue<std::string>(element, runtime));
63+
}
64+
return result;
65+
}
66+
4667
// Conversion from C++ types to jsi --------------------------------------------
4768

4869
// Implementation functions might return any type, but in a promise we can only
@@ -54,10 +75,6 @@ inline jsi::Value getJsiValue(jsi::Value &&value, jsi::Runtime &runtime) {
5475
return std::move(value);
5576
}
5677

57-
inline jsi::Value getJsiValue(jsi::Object &&value, jsi::Runtime &runtime) {
58-
return jsi::Value(std::move(value));
59-
}
60-
6178
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
6279
return jsi::String::createFromAscii(runtime, str);
6380
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <jsi/jsi.h>
4+
5+
namespace rnexecutorch {
6+
7+
using namespace facebook;
8+
9+
class OwningArrayBuffer : public jsi::MutableBuffer {
10+
public:
11+
OwningArrayBuffer(const size_t size) : size_(size) {
12+
data_ = new uint8_t[size];
13+
}
14+
~OwningArrayBuffer() override { delete[] data_; }
15+
16+
OwningArrayBuffer(const OwningArrayBuffer &) = delete;
17+
OwningArrayBuffer(OwningArrayBuffer &&) = delete;
18+
OwningArrayBuffer &operator=(const OwningArrayBuffer &) = delete;
19+
OwningArrayBuffer &operator=(OwningArrayBuffer &&) = delete;
20+
21+
[[nodiscard]] size_t size() const override { return size_; }
22+
uint8_t *data() override { return data_; }
23+
24+
private:
25+
uint8_t *data_;
26+
const size_t size_;
27+
};
28+
29+
} // namespace rnexecutorch

common/rnexecutorch/models/BaseModel.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ class BaseModel {
1414

1515
protected:
1616
std::unique_ptr<executorch::extension::Module> module;
17+
// If possible, models should not use the runtime to keep JSI internals away
18+
// from logic, however, sometimes this would incur too big of a penalty
19+
// (unnecessary copies). This is in BaseModel so that we can generalize JSI
20+
// loader method installation.
1721
facebook::jsi::Runtime *runtime;
1822
};
19-
2023
} // namespace rnexecutorch

0 commit comments

Comments
 (0)