Skip to content

Commit c2064bb

Browse files
committed
fix: adapt segmentation for refactored promises
1 parent 4a786da commit c2064bb

File tree

9 files changed

+91
-80
lines changed

9 files changed

+91
-80
lines changed

common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class RnExecutorchInstaller {
4949
jsiconversion::getValue<std::string>(args[0], runtime);
5050

5151
auto modelImplementationPtr =
52-
std::make_shared<ModelT>(source, &runtime);
52+
std::make_shared<ModelT>(source, jsCallInvoker);
5353
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
5454
modelImplementationPtr, jsCallInvoker);
5555

common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ getValue<std::set<std::string, std::less<>>>(const jsi::Value &val,
7171
// we add a function here.
7272

7373
// Identity function for the sake of completeness
74-
inline jsi::Value getJsiValue(jsi::Value &&value, jsi::Runtime &runtime) {
75-
return std::move(value);
74+
inline jsi::Value getJsiValue(std::unique_ptr<jsi::Object> &&valuePtr,
75+
jsi::Runtime &runtime) {
76+
return std::move(*valuePtr);
7677
}
7778

7879
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {

common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
5151
auto result = std::apply(
5252
std::bind_front(&Model::forward, model), argsConverted);
5353

54-
callInvoker->invokeAsync([promise, result = std::move(result)](
55-
jsi::Runtime &runtime) {
54+
callInvoker->invokeSync([promise,
55+
&result](jsi::Runtime &runtime) {
5656
promise->resolve(
5757
jsiconversion::getJsiValue(std::move(result), runtime));
5858
});

common/rnexecutorch/models/BaseModel.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
namespace rnexecutorch {
66

7+
using namespace facebook;
78
using ::executorch::extension::Module;
89
using ::executorch::runtime::Error;
910

1011
BaseModel::BaseModel(const std::string &modelSource,
11-
facebook::jsi::Runtime *runtime)
12+
std::shared_ptr<react::CallInvoker> callInvoker)
1213
: module(std::make_unique<Module>(
1314
modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)),
14-
runtime(runtime) {
15+
callInvoker(callInvoker) {
1516
Error loadError = module->load();
1617
if (loadError != Error::Ok) {
1718
throw std::runtime_error("Couldn't load the model, error: " +

common/rnexecutorch/models/BaseModel.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,25 @@
22

33
#include <string>
44

5+
#include <ReactCommon/CallInvoker.h>
56
#include <executorch/extension/module/module.h>
67
#include <jsi/jsi.h>
78

89
namespace rnexecutorch {
10+
using namespace facebook;
911

1012
class BaseModel {
1113
public:
12-
BaseModel(const std::string &modelSource, facebook::jsi::Runtime *runtime);
14+
BaseModel(const std::string &modelSource,
15+
std::shared_ptr<react::CallInvoker> callInvoker);
1316
std::vector<std::vector<int32_t>> getInputShape();
1417

1518
protected:
1619
std::unique_ptr<executorch::extension::Module> module;
17-
// If possible, models should not use the runtime to keep JSI internals away
18-
// from logic, however, sometimes this would incur too big of a penalty
19-
// (unnecessary copies). This is in BaseModel so that we can generalize JSI
20-
// loader method installation.
21-
facebook::jsi::Runtime *runtime;
20+
// If possible, models should not use the JS runtime to keep JSI internals
21+
// away from logic, however, sometimes this would incur too big of a penalty
22+
// (unnecessary copies instead of working on JS memory). In this case
23+
// CallInvoker can be used to get jsi::Runtime, and use it in a safe manner.
24+
std::shared_ptr<react::CallInvoker> callInvoker;
2225
};
2326
} // namespace rnexecutorch

common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp

Lines changed: 57 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99

1010
namespace rnexecutorch {
1111

12-
ImageSegmentation::ImageSegmentation(const std::string &modelSource,
13-
jsi::Runtime *runtime)
14-
: BaseModel(modelSource, runtime) {
12+
ImageSegmentation::ImageSegmentation(
13+
const std::string &modelSource,
14+
std::shared_ptr<react::CallInvoker> callInvoker)
15+
: BaseModel(modelSource, callInvoker) {
1516

16-
std::vector<int32_t> modelInputShape = getInputShape();
17+
std::vector<int32_t> modelInputShape = getInputShape()[0];
1718
modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
1819
modelInputShape[modelInputShape.size() - 2]);
1920
numModelPixels = modelImageSize.area();
2021
}
2122

22-
jsi::Value
23+
std::unique_ptr<jsi::Object>
2324
ImageSegmentation::forward(std::string imageSource,
2425
std::set<std::string, std::less<>> classesOfInterest,
2526
bool resize) {
@@ -36,7 +37,20 @@ ImageSegmentation::forward(std::string imageSource,
3637
classesOfInterest, resize);
3738
}
3839

39-
jsi::Value ImageSegmentation::postprocess(
40+
std::pair<TensorPtr, cv::Size>
41+
ImageSegmentation::preprocess(const std::string &imageSource) {
42+
cv::Mat input = imageprocessing::readImage(imageSource);
43+
cv::Size inputSize = input.size();
44+
45+
cv::resize(input, input, modelImageSize);
46+
47+
std::vector<float> inputVector = imageprocessing::colorMatToVector(input);
48+
return {
49+
executorch::extension::make_tensor_ptr(getInputShape()[0], inputVector),
50+
inputSize};
51+
}
52+
53+
std::unique_ptr<jsi::Object> ImageSegmentation::postprocess(
4054
const Tensor &tensor, cv::Size originalSize,
4155
std::set<std::string, std::less<>> classesOfInterest, bool resize) {
4256

@@ -84,11 +98,11 @@ jsi::Value ImageSegmentation::postprocess(
8498
reinterpret_cast<int32_t *>(argmax->data())[pixel] = maxInd;
8599
}
86100

87-
std::unordered_map<std::string_view, std::shared_ptr<OwningArrayBuffer>>
88-
buffersToReturn;
101+
auto buffersToReturn = std::make_shared<std::unordered_map<
102+
std::string_view, std::shared_ptr<OwningArrayBuffer>>>();
89103
for (std::size_t cl = 0; cl < numClasses; ++cl) {
90104
if (classesOfInterest.contains(deeplabv3_resnet50_labels[cl])) {
91-
buffersToReturn[deeplabv3_resnet50_labels[cl]] = resultClasses[cl];
105+
(*buffersToReturn)[deeplabv3_resnet50_labels[cl]] = resultClasses[cl];
92106
}
93107
}
94108

@@ -102,7 +116,7 @@ jsi::Value ImageSegmentation::postprocess(
102116
std::memcpy(argmax->data(), argmaxMat.data,
103117
originalSize.area() * sizeof(int32_t));
104118

105-
for (auto &[label, arrayBuffer] : buffersToReturn) {
119+
for (auto &[label, arrayBuffer] : *buffersToReturn) {
106120
cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data());
107121
cv::resize(classMat, classMat, originalSize);
108122
arrayBuffer = std::make_shared<OwningArrayBuffer>(originalSize.area() *
@@ -114,53 +128,41 @@ jsi::Value ImageSegmentation::postprocess(
114128
return populateDictionary(argmax, buffersToReturn);
115129
}
116130

117-
jsi::Value ImageSegmentation::populateDictionary(
131+
std::unique_ptr<jsi::Object> ImageSegmentation::populateDictionary(
118132
std::shared_ptr<OwningArrayBuffer> argmax,
119-
std::unordered_map<std::string_view, std::shared_ptr<OwningArrayBuffer>>
133+
std::shared_ptr<std::unordered_map<std::string_view,
134+
std::shared_ptr<OwningArrayBuffer>>>
120135
classesToOutput) {
121-
jsi::Object dict(*runtime);
122-
123-
auto argmaxArrayBuffer = jsi::ArrayBuffer(*runtime, argmax);
124-
125-
auto int32ArrayCtor =
126-
runtime->global().getPropertyAsFunction(*runtime, "Int32Array");
127-
auto int32Array =
128-
int32ArrayCtor.callAsConstructor(*runtime, argmaxArrayBuffer)
129-
.getObject(*runtime);
130-
dict.setProperty(*runtime, "ARGMAX", int32Array);
131-
132-
std::size_t dictIndex = 1;
133-
for (auto &[classLabel, owningBuffer] : classesToOutput) {
134-
auto classArrayBuffer = jsi::ArrayBuffer(*runtime, owningBuffer);
135-
136-
auto float32ArrayCtor =
137-
runtime->global().getPropertyAsFunction(*runtime, "Float32Array");
138-
auto float32Array =
139-
float32ArrayCtor.callAsConstructor(*runtime, classArrayBuffer)
140-
.getObject(*runtime);
141-
142-
dict.setProperty(*runtime,
143-
jsi::String::createFromAscii(*runtime, classLabel.data()),
144-
float32Array);
145-
}
146-
return dict;
147-
}
148-
149-
std::pair<TensorPtr, cv::Size>
150-
ImageSegmentation::preprocess(const std::string &imageSource) {
151-
cv::Mat input = imageprocessing::readImage(imageSource);
152-
cv::Size inputSize = input.size();
153-
154-
std::vector<int32_t> modelInputShape = getInputShape();
155-
cv::Size modelImageSize =
156-
cv::Size(modelInputShape[modelInputShape.size() - 1],
157-
modelInputShape[modelInputShape.size() - 2]);
158-
159-
cv::resize(input, input, modelImageSize);
160-
161-
std::vector<float> inputVector = imageprocessing::colorMatToVector(input);
162-
return {executorch::extension::make_tensor_ptr(modelInputShape, inputVector),
163-
inputSize};
136+
std::unique_ptr<jsi::Object> dictPtr;
137+
138+
callInvoker->invokeSync(
139+
[argmax, classesToOutput, &dictPtr](jsi::Runtime &runtime) {
140+
dictPtr = std::make_unique<jsi::Object>(runtime);
141+
auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax);
142+
143+
auto int32ArrayCtor =
144+
runtime.global().getPropertyAsFunction(runtime, "Int32Array");
145+
auto int32Array =
146+
int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer)
147+
.getObject(runtime);
148+
dictPtr->setProperty(runtime, "ARGMAX", int32Array);
149+
150+
for (auto &[classLabel, owningBuffer] : *classesToOutput) {
151+
auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer);
152+
153+
auto float32ArrayCtor =
154+
runtime.global().getPropertyAsFunction(runtime, "Float32Array");
155+
auto float32Array =
156+
float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer)
157+
.getObject(runtime);
158+
159+
dictPtr->setProperty(
160+
runtime, jsi::String::createFromAscii(runtime, classLabel.data()),
161+
float32Array);
162+
}
163+
});
164+
165+
return dictPtr;
164166
}
165167

166168
} // namespace rnexecutorch

common/rnexecutorch/models/image_segmentation/ImageSegmentation.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,22 @@ using executorch::extension::TensorPtr;
1919

2020
class ImageSegmentation : public BaseModel {
2121
public:
22-
ImageSegmentation(const std::string &modelSource, jsi::Runtime *runtime);
23-
jsi::Value forward(std::string imageSource,
24-
std::set<std::string, std::less<>> classesOfInterest,
25-
bool resize);
22+
ImageSegmentation(const std::string &modelSource,
23+
std::shared_ptr<react::CallInvoker> callInvoker);
24+
std::unique_ptr<jsi::Object>
25+
forward(std::string imageSource,
26+
std::set<std::string, std::less<>> classesOfInterest, bool resize);
2627

2728
private:
2829
std::pair<TensorPtr, cv::Size> preprocess(const std::string &imageSource);
29-
jsi::Value postprocess(const Tensor &tensor, cv::Size originalSize,
30-
std::set<std::string, std::less<>> classesOfInterest,
31-
bool resize);
32-
jsi::Value populateDictionary(
30+
std::unique_ptr<jsi::Object>
31+
postprocess(const Tensor &tensor, cv::Size originalSize,
32+
std::set<std::string, std::less<>> classesOfInterest,
33+
bool resize);
34+
std::unique_ptr<jsi::Object> populateDictionary(
3335
std::shared_ptr<OwningArrayBuffer> argmax,
34-
std::unordered_map<std::string_view, std::shared_ptr<OwningArrayBuffer>>
36+
std::shared_ptr<std::unordered_map<std::string_view,
37+
std::shared_ptr<OwningArrayBuffer>>>
3538
classesToOutput);
3639

3740
static constexpr std::size_t numClasses{deeplabv3_resnet50_labels.size()};

common/rnexecutorch/models/style_transfer/StyleTransfer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ using executorch::extension::TensorPtr;
1515
using executorch::runtime::Error;
1616

1717
StyleTransfer::StyleTransfer(const std::string &modelSource,
18-
jsi::Runtime *runtime)
19-
: BaseModel(modelSource, runtime) {
18+
std::shared_ptr<react::CallInvoker> callInvoker)
19+
: BaseModel(modelSource, callInvoker) {
2020
std::vector<int32_t> modelInputShape = getInputShape()[0];
2121
modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
2222
modelInputShape[modelInputShape.size() - 2]);

common/rnexecutorch/models/style_transfer/StyleTransfer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ using executorch::extension::TensorPtr;
1717

1818
class StyleTransfer : public BaseModel {
1919
public:
20-
StyleTransfer(const std::string &modelSource, jsi::Runtime *runtime);
20+
StyleTransfer(const std::string &modelSource,
21+
std::shared_ptr<react::CallInvoker> callInvoker);
2122
std::string forward(std::string imageSource);
2223

2324
private:

0 commit comments

Comments
 (0)