Skip to content

Commit 2f3c32a

Browse files
authored
feat: add generic model loading (#395)
## Description In `RnExecutorchInstaller` we can use `loadModel<T>` to install to JSI a method that will load our model without writing the same code for each module. The problem is that this method expects that each module needs exactly a string and a CallInvoker in the constructor, which restricts usage of the method for modules which would have multiple sources or other arguments (e.g. OCR or encoder/decoder modules). At the same time, even if modules have different input on load, the code would look nearly identical, so a generic solution is needed for this case. This PR allows for generic loading of modules if the type info about the constructor is supplied by `REGISTER_CONSTRUCTOR` macro. This allows for a generic conversion of arguments from JSI, similar to `ModelHostObject`. The machinery required for this (`ConstructorHelpers.h`) is admittedly more complicated than in the case of `ModelHostObject` due to the fact that we need to specify the types ourselves (we cannot grab method pointers to constructors to infer the types), and due to having `CallInvoker` as the last argument. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Related issues #255 ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings
1 parent 329014a commit 2f3c32a

6 files changed

Lines changed: 214 additions & 40 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,28 @@
77
#include <ReactCommon/CallInvoker.h>
88
#include <jsi/jsi.h>
99

10-
#include <rnexecutorch/TypeConcepts.h>
1110
#include <rnexecutorch/host_objects/JsiConversions.h>
1211
#include <rnexecutorch/host_objects/ModelHostObject.h>
12+
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
13+
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
14+
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
1315

1416
namespace rnexecutorch {
1517

1618
using FetchUrlFunc_t = std::function<std::vector<std::byte>(std::string)>;
1719
extern FetchUrlFunc_t fetchUrlFunc;
1820

21+
REGISTER_CONSTRUCTOR(StyleTransfer, std::string,
22+
std::shared_ptr<react::CallInvoker>);
23+
REGISTER_CONSTRUCTOR(ImageSegmentation, std::string,
24+
std::shared_ptr<react::CallInvoker>);
25+
REGISTER_CONSTRUCTOR(Classification, std::string,
26+
std::shared_ptr<react::CallInvoker>);
27+
REGISTER_CONSTRUCTOR(ObjectDetection, std::string,
28+
std::shared_ptr<react::CallInvoker>);
29+
REGISTER_CONSTRUCTOR(BaseModel, std::string,
30+
std::shared_ptr<react::CallInvoker>);
31+
1932
using namespace facebook;
2033

2134
class RnExecutorchInstaller {
@@ -26,7 +39,10 @@ class RnExecutorchInstaller {
2639
FetchUrlFunc_t fetchDataFromUrl);
2740

2841
private:
29-
template <DerivedFromOrSameAs<BaseModel> ModelT>
42+
template <typename ModelT>
43+
requires meta::ValidConstructorTraits<ModelT> &&
44+
meta::CallInvokerLastInConstructor<ModelT> &&
45+
meta::ProvidesMemoryLowerBound<ModelT>
3046
static jsi::Function
3147
loadModel(jsi::Runtime *jsiRuntime,
3248
std::shared_ptr<react::CallInvoker> jsCallInvoker,
@@ -37,20 +53,24 @@ class RnExecutorchInstaller {
3753
0,
3854
[jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
3955
const jsi::Value *args, size_t count) -> jsi::Value {
40-
if (count != 1) {
56+
constexpr std::size_t expectedCount = std::tuple_size_v<
57+
typename meta::ConstructorTraits<ModelT>::arg_types>;
58+
// count doesn't account for the JSCallInvoker
59+
if (count != expectedCount - 1) {
4160
char errorMessage[100];
4261
std::snprintf(
4362
errorMessage, sizeof(errorMessage),
44-
"Argument count mismatch, was expecting: 1 but got: %zu",
45-
count);
63+
"Argument count mismatch, was expecting: %zu but got: %zu",
64+
expectedCount, count);
4665
throw jsi::JSError(runtime, errorMessage);
4766
}
4867
try {
49-
auto source =
50-
jsiconversion::getValue<std::string>(args[0], runtime);
68+
auto constructorArgs =
69+
meta::createConstructorArgsWithCallInvoker<ModelT>(
70+
args, runtime, jsCallInvoker);
5171

52-
auto modelImplementationPtr =
53-
std::make_shared<ModelT>(source, jsCallInvoker);
72+
auto modelImplementationPtr = std::make_shared<ModelT>(
73+
std::make_from_tuple<ModelT>(constructorArgs));
5474
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
5575
modelImplementationPtr, jsCallInvoker);
5676

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1111
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
1212

13-
#include <rnexecutorch/TypeConcepts.h>
13+
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
1414
#include <rnexecutorch/models/object_detection/Constants.h>
1515
#include <rnexecutorch/models/object_detection/Utils.h>
1616

@@ -23,7 +23,7 @@ using namespace facebook;
2323
template <typename T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);
2424

2525
template <typename T>
26-
requires IsNumeric<T>
26+
requires meta::IsNumeric<T>
2727
inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) {
2828
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
2929
"Only integral and floating-point types are supported");
@@ -239,24 +239,4 @@ inline jsi::Value getJsiValue(const std::vector<Detection> &detections,
239239
return array;
240240
}
241241

242-
template <typename Model, typename R, typename... Types>
243-
constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) {
244-
return sizeof...(Types);
245-
}
246-
247-
template <typename... Types, std::size_t... I>
248-
std::tuple<Types...> fillTupleFromArgs(std::index_sequence<I...>,
249-
const jsi::Value *args,
250-
jsi::Runtime &runtime) {
251-
return std::make_tuple(getValue<Types>(args[I], runtime)...);
252-
}
253-
254-
template <typename Model, typename R, typename... Types>
255-
std::tuple<Types...> createArgsTupleFromJsi(R (Model::*f)(Types...),
256-
const jsi::Value *args,
257-
jsi::Runtime &runtime) {
258-
return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
259-
runtime);
260-
}
261-
262242
} // namespace rnexecutorch::jsiconversion

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
#include <ReactCommon/CallInvoker.h>
99

1010
#include <rnexecutorch/Log.h>
11-
#include <rnexecutorch/TypeConcepts.h>
1211
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1312
#include <rnexecutorch/host_objects/JsiConversions.h>
1413
#include <rnexecutorch/jsi/JsiHostObject.h>
1514
#include <rnexecutorch/jsi/Promise.h>
15+
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
16+
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
1617
#include <rnexecutorch/models/BaseModel.h>
1718

1819
namespace rnexecutorch {
@@ -22,24 +23,24 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
2223
explicit ModelHostObject(const std::shared_ptr<Model> &model,
2324
std::shared_ptr<react::CallInvoker> callInvoker)
2425
: model(model), callInvoker(callInvoker) {
25-
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
26+
if constexpr (meta::DerivedFromOrSameAs<Model, BaseModel>) {
2627
addFunctions(
2728
JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
2829
}
2930

30-
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
31+
if constexpr (meta::DerivedFromOrSameAs<Model, BaseModel>) {
3132
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
3233
promiseHostFunction<&Model::forwardJS>,
3334
"forward"));
3435
}
3536

36-
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
37+
if constexpr (meta::DerivedFromOrSameAs<Model, BaseModel>) {
3738
addFunctions(JSI_EXPORT_FUNCTION(
3839
ModelHostObject<Model>, promiseHostFunction<&Model::getInputShape>,
3940
"getInputShape"));
4041
}
4142

42-
if constexpr (HasGenerate<Model>) {
43+
if constexpr (meta::HasGenerate<Model>) {
4344
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
4445
promiseHostFunction<&Model::generate>,
4546
"generate"));
@@ -54,7 +55,7 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
5455
runtime, callInvoker,
5556
[this, count, args, &runtime](std::shared_ptr<Promise> promise) {
5657
constexpr std::size_t functionArgCount =
57-
jsiconversion::getArgumentCount(FnPtr);
58+
meta::getArgumentCount(FnPtr);
5859
if (functionArgCount != count) {
5960
char errorMessage[100];
6061
std::snprintf(
@@ -67,7 +68,7 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
6768

6869
try {
6970
auto argsConverted =
70-
jsiconversion::createArgsTupleFromJsi(FnPtr, args, runtime);
71+
meta::createArgsTupleFromJsi(FnPtr, args, runtime);
7172

7273
// We need to dispatch a thread if we want the function to be
7374
// asynchronous. In this thread all accesses to jsi::Runtime need to
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#pragma once
2+
3+
#include <concepts>
4+
#include <jsi/jsi.h>
5+
#include <memory>
6+
#include <rnexecutorch/host_objects/JsiConversions.h>
7+
#include <tuple>
8+
#include <type_traits>
9+
10+
namespace facebook::react {
11+
class CallInvoker;
12+
}
13+
14+
namespace rnexecutorch {
15+
namespace meta {
16+
17+
using namespace facebook;
18+
19+
/**
20+
* To be able to generically invoke constructors, we need to know what types
21+
* are the arguments that need to be passed to it. To do this, we specialize the
22+
* ConstructorTraits struct template for each class which constructor we want to
23+
* use. See REGISTER_CONSTRUCTOR macro for a easy way to do this.
24+
*
25+
* Note: a downside for this method is that we can specialize ConstructorTraits
26+
* only for a single constructor signature per class.
27+
*/
28+
29+
template <typename T> struct ConstructorTraits;
30+
31+
template <typename T>
32+
concept HasConstructorTraits =
33+
requires { typename ConstructorTraits<T>::arg_types; };
34+
35+
template <typename T, typename Tuple> struct is_constructible_from_tuple;
36+
37+
template <typename T, typename... Args>
38+
struct is_constructible_from_tuple<T, std::tuple<Args...>>
39+
: std::is_constructible<T, Args...> {};
40+
41+
template <typename T, typename Tuple>
42+
concept ConstructibleFromTuple = is_constructible_from_tuple<T, Tuple>::value;
43+
44+
template <typename NotTuple>
45+
struct last_element_is_call_invoker : std::false_type {};
46+
47+
template <typename... Args>
48+
struct last_element_is_call_invoker<std::tuple<Args...>> {
49+
private:
50+
template <typename Last> static constexpr bool check() {
51+
return std::is_same_v<Last, std::shared_ptr<facebook::react::CallInvoker>>;
52+
}
53+
54+
template <typename First, typename Second, typename... Rest>
55+
static constexpr bool check_last() {
56+
return check_last<Second, Rest...>();
57+
}
58+
59+
template <typename Last> static constexpr bool check_last() {
60+
return check<Last>();
61+
}
62+
63+
public:
64+
static constexpr bool value = sizeof...(Args) > 0 && check_last<Args...>();
65+
};
66+
67+
// HasConstructorTraits<T> could be removed as typename
68+
// ConstructorTraits<T>::arg_types would still resolve the concept to false if
69+
// it wouldn't be defined, but we keep it for readability
70+
template <typename T>
71+
concept ValidConstructorTraits =
72+
HasConstructorTraits<T> &&
73+
ConstructibleFromTuple<T, typename ConstructorTraits<T>::arg_types>;
74+
75+
template <typename T>
76+
concept CallInvokerLastInConstructor =
77+
HasConstructorTraits<T> &&
78+
last_element_is_call_invoker<
79+
typename ConstructorTraits<T>::arg_types>::value;
80+
81+
template <typename... Types, std::size_t... I>
82+
std::tuple<Types...> fillConstructorTupleFromArgs(
83+
std::index_sequence<I...>, const jsi::Value *args, jsi::Runtime &runtime,
84+
std::shared_ptr<react::CallInvoker> jsCallInvoker) {
85+
constexpr std::size_t lastIndex = sizeof...(Types) - 1;
86+
return std::make_tuple([&]() {
87+
if constexpr (I == lastIndex) {
88+
return jsCallInvoker;
89+
} else {
90+
return jsiconversion::getValue<Types>(args[I], runtime);
91+
}
92+
}()...);
93+
}
94+
95+
/// @brief A method that creates a tuple of arguments based on types specified
96+
/// in a ConstructorTraits specialization. The class has to have CallInvoker as
97+
/// the last argument in the constructor.
98+
/// @tparam T The class for which we want to construct the tuple
99+
/// @param args JSI args passed from JS that will be converted according to
100+
/// getValue<T> from JsiConversions
101+
/// @param runtime JS runtime reference
102+
/// @param jsCallInvoker CallInvoker that will be passed to the constructed
103+
/// object. This is the only argument that is not created from jsi::Value.
104+
/// @return A tuple which can then be used to instantiate the class T.
105+
template <typename T>
106+
requires ValidConstructorTraits<T> && CallInvokerLastInConstructor<T>
107+
auto createConstructorArgsWithCallInvoker(
108+
const jsi::Value *args, jsi::Runtime &runtime,
109+
std::shared_ptr<react::CallInvoker> jsCallInvoker) {
110+
return std::apply(
111+
[&](auto... typeWrappers) {
112+
return fillConstructorTupleFromArgs<decltype(typeWrappers)...>(
113+
std::index_sequence_for<decltype(typeWrappers)...>{}, args, runtime,
114+
jsCallInvoker);
115+
},
116+
typename ConstructorTraits<T>::arg_types{});
117+
}
118+
119+
} // namespace meta
120+
121+
// A helper macro to create ConstructorTraits for a class. The variadic pack
122+
// ("...") should list the types of the constructor arguments.
123+
// A class declaration is added so that we don't need to include the class
124+
// definition.
125+
#define REGISTER_CONSTRUCTOR(Class, ...) \
126+
class Class; \
127+
template <> struct meta::ConstructorTraits<Class> { \
128+
using arg_types = std::tuple<__VA_ARGS__>; \
129+
}
130+
131+
} // namespace rnexecutorch
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
#include <jsi/jsi.h>
5+
#include <tuple>
6+
7+
#include <rnexecutorch/host_objects/JsiConversions.h>
8+
9+
namespace rnexecutorch::meta {
10+
using namespace facebook;
11+
12+
template <typename Model, typename R, typename... Types>
13+
constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) {
14+
return sizeof...(Types);
15+
}
16+
17+
template <typename... Types, std::size_t... I>
18+
std::tuple<Types...> fillTupleFromArgs(std::index_sequence<I...>,
19+
const jsi::Value *args,
20+
jsi::Runtime &runtime) {
21+
return std::make_tuple(jsiconversion::getValue<Types>(args[I], runtime)...);
22+
}
23+
24+
/**
25+
* createArgsTupleFromJsi creates a tuple that can be used as a collection of
26+
* arguments for method supplied with a pointer. The types in the tuple are
27+
* inferred from the method pointer.
28+
*/
29+
30+
template <typename Model, typename R, typename... Types>
31+
std::tuple<Types...> createArgsTupleFromJsi(R (Model::*f)(Types...),
32+
const jsi::Value *args,
33+
jsi::Runtime &runtime) {
34+
return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
35+
runtime);
36+
}
37+
} // namespace rnexecutorch::meta

packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h renamed to packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include <concepts>
44
#include <type_traits>
55

6-
namespace rnexecutorch {
6+
namespace rnexecutorch::meta {
77

88
template <typename T, typename Base>
99
concept DerivedFromOrSameAs = std::is_base_of_v<Base, T>;
@@ -16,4 +16,9 @@ concept HasGenerate = requires(T t) {
1616
template <typename T>
1717
concept IsNumeric = std::is_arithmetic_v<T>;
1818

19-
} // namespace rnexecutorch
19+
template <typename T>
20+
concept ProvidesMemoryLowerBound = requires(T t) {
21+
{ &T::getMemoryLowerBound };
22+
};
23+
24+
} // namespace rnexecutorch::meta

0 commit comments

Comments
 (0)