Skip to content

Commit a7e3e4e

Browse files
authored
fix: refactor C++ JSI promises (#286)
## Description In the C++ code we need some interface between raw JSI and our logic using JS promises. The current implementation uses too many nested lambdas, as well requires `std::function` bindings to each lambda. ### Type of change - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Related issues #255
1 parent 1a567c3 commit a7e3e4e

File tree

9 files changed

+151
-152
lines changed

9 files changed

+151
-152
lines changed

android/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ android {
142142
sourceCompatibility JavaVersion.VERSION_1_8
143143
targetCompatibility JavaVersion.VERSION_1_8
144144
}
145+
145146
}
146147

147148
repositories {

common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include <rnexecutorch/host_objects/JsiConversions.h>
44
#include <rnexecutorch/host_objects/ModelHostObject.h>
5-
#include <rnexecutorch/jsi/JsiPromise.h>
65
#include <rnexecutorch/models/StyleTransfer.h>
76

87
namespace rnexecutorch {
@@ -27,8 +26,8 @@ jsi::Function RnExecutorchInstaller::loadStyleTransfer(
2726
auto styleTransferPtr =
2827
std::make_shared<StyleTransfer>(source, &runtime);
2928
auto styleTransferHostObject =
30-
std::make_shared<ModelHostObject<StyleTransfer>>(
31-
styleTransferPtr, &runtime, jsCallInvoker);
29+
std::make_shared<ModelHostObject<StyleTransfer>>(styleTransferPtr,
30+
jsCallInvoker);
3231

3332
return jsi::Object::createFromHostObject(runtime,
3433
styleTransferHostObject);

common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,26 @@
55
#include <tuple>
66
#include <vector>
77

8+
#include <ReactCommon/CallInvoker.h>
9+
810
#include <rnexecutorch/Log.h>
911
#include <rnexecutorch/host_objects/JsiConversions.h>
1012
#include <rnexecutorch/jsi/JsiHostObject.h>
11-
#include <rnexecutorch/jsi/JsiPromise.h>
13+
#include <rnexecutorch/jsi/Promise.h>
1214

1315
namespace rnexecutorch {
1416

1517
template <typename Model> class ModelHostObject : public JsiHostObject {
1618
public:
17-
explicit ModelHostObject(
18-
const std::shared_ptr<Model> &model, jsi::Runtime *runtime,
19-
const std::shared_ptr<react::CallInvoker> &callInvoker)
20-
: model(model), promiseVendor(runtime, callInvoker) {
21-
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, forward));
19+
explicit ModelHostObject(const std::shared_ptr<Model> &model,
20+
std::shared_ptr<react::CallInvoker> callInvoker)
21+
: model(model), callInvoker(callInvoker) {
22+
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>, forward));
2223
}
2324

2425
JSI_HOST_FUNCTION(forward) {
25-
auto promise = promiseVendor.createPromise(
26+
auto promise = Promise::createPromise(
27+
runtime, callInvoker,
2628
[this, count, args, &runtime](std::shared_ptr<Promise> promise) {
2729
constexpr std::size_t forwardArgCount =
2830
jsiconversion::getArgumentCount(&Model::forward);
@@ -32,48 +34,64 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
3234
errorMessage, sizeof(errorMessage),
3335
"Argument count mismatch, was expecting: %zu but got: %zu",
3436
forwardArgCount, count);
35-
3637
promise->reject(errorMessage);
3738
return;
3839
}
3940

40-
// Do the asynchronous work
41-
std::thread([this, promise = std::move(promise), args, &runtime]() {
42-
try {
43-
auto argsConverted = jsiconversion::createArgsTupleFromJsi(
44-
&Model::forward, args, runtime);
45-
auto result = std::apply(std::bind_front(&Model::forward, model),
46-
argsConverted);
41+
try {
42+
auto argsConverted = jsiconversion::createArgsTupleFromJsi(
43+
&Model::forward, args, runtime);
4744

48-
promise->resolve([result =
49-
std::move(result)](jsi::Runtime &runtime) {
50-
return jsiconversion::getJsiValue(std::move(result), runtime);
51-
});
52-
} catch (const std::runtime_error &e) {
53-
// This catch should be merged with the next one
54-
// (std::runtime_error inherits from std::exception) HOWEVER react
55-
// native has broken RTTI which breaks proper exception type
56-
// checking. Remove when the following change is present in our
57-
// version:
58-
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
59-
promise->reject(e.what());
60-
return;
61-
} catch (const std::exception &e) {
62-
promise->reject(e.what());
63-
return;
64-
} catch (...) {
65-
promise->reject("Unknown error");
66-
return;
67-
}
68-
}).detach();
45+
// We need to dispatch a thread if we want the forward to be
46+
// asynchronous. In this thread all accesses to jsi::Runtime need to
47+
// be done via the callInvoker.
48+
std::thread([this, promise,
49+
argsConverted = std::move(argsConverted)]() {
50+
try {
51+
auto result = std::apply(
52+
std::bind_front(&Model::forward, model), argsConverted);
53+
54+
callInvoker->invokeAsync([promise, result = std::move(result)](
55+
jsi::Runtime &runtime) {
56+
promise->resolve(
57+
jsiconversion::getJsiValue(std::move(result), runtime));
58+
});
59+
} catch (const std::runtime_error &e) {
60+
// This catch should be merged with the next two
61+
// (std::runtime_error and jsi::JSError inherits from
62+
// std::exception) HOWEVER react native has broken RTTI which
63+
// breaks proper exception type checking. Remove when the
64+
// following change is present in our version:
65+
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
66+
callInvoker->invokeAsync(
67+
[&e, promise]() { promise->reject(e.what()); });
68+
return;
69+
} catch (const jsi::JSError &e) {
70+
callInvoker->invokeAsync(
71+
[&e, promise]() { promise->reject(e.what()); });
72+
return;
73+
} catch (const std::exception &e) {
74+
callInvoker->invokeAsync(
75+
[&e, promise]() { promise->reject(e.what()); });
76+
return;
77+
} catch (...) {
78+
callInvoker->invokeAsync(
79+
[promise]() { promise->reject("Unknown error"); });
80+
return;
81+
}
82+
}).detach();
83+
} catch (...) {
84+
promise->reject(
85+
"Couldn't parse JS arguments in native forward function");
86+
}
6987
});
7088

7189
return promise;
7290
}
7391

7492
private:
7593
std::shared_ptr<Model> model;
76-
PromiseVendor promiseVendor;
94+
std::shared_ptr<react::CallInvoker> callInvoker;
7795
};
7896

7997
} // namespace rnexecutorch

common/rnexecutorch/jsi/JsiPromise.cpp

Lines changed: 0 additions & 60 deletions
This file was deleted.

common/rnexecutorch/jsi/JsiPromise.h

Lines changed: 0 additions & 48 deletions
This file was deleted.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include "Promise.h"
2+
3+
namespace rnexecutorch {
4+
5+
Promise::Promise(jsi::Runtime &runtime,
6+
std::shared_ptr<react::CallInvoker> callInvoker,
7+
jsi::Value resolver, jsi::Value rejecter)
8+
: runtime(runtime), callInvoker(callInvoker),
9+
_resolver(std::move(resolver)), _rejecter(std::move(rejecter)) {}
10+
11+
void Promise::resolve(jsi::Value &&result) {
12+
_resolver.asObject(runtime).asFunction(runtime).call(runtime, result);
13+
}
14+
15+
void Promise::reject(std::string message) {
16+
jsi::JSError error(runtime, message);
17+
_rejecter.asObject(runtime).asFunction(runtime).call(runtime, error.value());
18+
}
19+
20+
} // namespace rnexecutorch

common/rnexecutorch/jsi/Promise.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
6+
#include <ReactCommon/CallInvoker.h>
7+
#include <jsi/jsi.h>
8+
9+
namespace rnexecutorch {
10+
11+
using namespace facebook;
12+
13+
class Promise;
14+
15+
template <typename T>
16+
concept PromiseRunFn =
17+
std::invocable<T, std::shared_ptr<Promise>> &&
18+
std::same_as<std::invoke_result_t<T, std::shared_ptr<Promise>>, void>;
19+
20+
class Promise {
21+
public:
22+
Promise(jsi::Runtime &runtime,
23+
std::shared_ptr<react::CallInvoker> callInvoker, jsi::Value resolver,
24+
jsi::Value rejecter);
25+
26+
Promise(const Promise &) = delete;
27+
Promise &operator=(const Promise &) = delete;
28+
29+
void resolve(jsi::Value &&result);
30+
void reject(std::string error);
31+
32+
/**
33+
Creates a new promise and runs the supplied "run" function that takes this
34+
promise. We use a template for the function type to not use std::function
35+
and be able to bind a lambda.
36+
*/
37+
template <PromiseRunFn Fn>
38+
static jsi::Value
39+
createPromise(jsi::Runtime &runtime,
40+
std::shared_ptr<react::CallInvoker> callInvoker, Fn &&run) {
41+
// Get Promise ctor from global
42+
auto promiseCtor =
43+
runtime.global().getPropertyAsFunction(runtime, "Promise");
44+
45+
auto promiseCallback = jsi::Function::createFromHostFunction(
46+
runtime, jsi::PropNameID::forUtf8(runtime, "PromiseCallback"), 2,
47+
[run = std::move(run),
48+
callInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
49+
const jsi::Value *arguments, size_t count) -> jsi::Value {
50+
// Call function
51+
auto promise = std::make_shared<Promise>(
52+
runtime, callInvoker, arguments[0].asObject(runtime),
53+
arguments[1].asObject(runtime));
54+
run(promise);
55+
56+
return jsi::Value::undefined();
57+
});
58+
59+
return promiseCtor.callAsConstructor(runtime, promiseCallback);
60+
}
61+
62+
private:
63+
jsi::Runtime &runtime;
64+
std::shared_ptr<react::CallInvoker> callInvoker;
65+
jsi::Value _resolver;
66+
jsi::Value _rejecter;
67+
};
68+
69+
} // namespace rnexecutorch

common/rnexecutorch/models/BaseModel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <rnexecutorch/models/BaseModel.h>
1+
#include "BaseModel.h"
22

33
#include <rnexecutorch/Log.h>
44

common/rnexecutorch/models/StyleTransfer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#include "StyleTransfer.h"
22

3+
#include <rnexecutorch/Log.h>
4+
#include <rnexecutorch/data_processing/ImageProcessing.h>
5+
36
#include <span>
47

58
#include <executorch/extension/tensor/tensor.h>
69
#include <opencv2/opencv.hpp>
710

8-
#include <rnexecutorch/Log.h>
9-
#include <rnexecutorch/data_processing/ImageProcessing.h>
10-
1111
namespace rnexecutorch {
1212
using namespace facebook;
1313
using executorch::extension::Module;

0 commit comments

Comments
 (0)