Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,28 @@
#include <ReactCommon/CallInvoker.h>
#include <jsi/jsi.h>

#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/host_objects/JsiConversions.h>
#include <rnexecutorch/host_objects/ModelHostObject.h>
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
#include <rnexecutorch/metaprogramming/TypeConcepts.h>

namespace rnexecutorch {

using FetchUrlFunc_t = std::function<std::vector<std::byte>(std::string)>;
extern FetchUrlFunc_t fetchUrlFunc;

REGISTER_CONSTRUCTOR(StyleTransfer, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(ImageSegmentation, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(Classification, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(ObjectDetection, std::string,
std::shared_ptr<react::CallInvoker>);
REGISTER_CONSTRUCTOR(BaseModel, std::string,
std::shared_ptr<react::CallInvoker>);

using namespace facebook;

class RnExecutorchInstaller {
Expand All @@ -26,7 +39,10 @@ class RnExecutorchInstaller {
FetchUrlFunc_t fetchDataFromUrl);

private:
template <DerivedFromOrSameAs<BaseModel> ModelT>
template <typename ModelT>
requires meta::ValidConstructorTraits<ModelT> &&
meta::CallInvokerLastInConstructor<ModelT> &&
meta::ProvidesMemoryLowerBound<ModelT>
static jsi::Function
loadModel(jsi::Runtime *jsiRuntime,
std::shared_ptr<react::CallInvoker> jsCallInvoker,
Expand All @@ -37,20 +53,24 @@ class RnExecutorchInstaller {
0,
[jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
const jsi::Value *args, size_t count) -> jsi::Value {
if (count != 1) {
constexpr std::size_t expectedCount = std::tuple_size_v<
typename meta::ConstructorTraits<ModelT>::arg_types>;
// count doesn't account for the JSCallInvoker
if (count != expectedCount - 1) {
char errorMessage[100];
std::snprintf(
errorMessage, sizeof(errorMessage),
"Argument count mismatch, was expecting: 1 but got: %zu",
count);
"Argument count mismatch, was expecting: %zu but got: %zu",
expectedCount, count);
throw jsi::JSError(runtime, errorMessage);
}
try {
auto source =
jsiconversion::getValue<std::string>(args[0], runtime);
auto constructorArgs =
meta::createConstructorArgsWithCallInvoker<ModelT>(
args, runtime, jsCallInvoker);

auto modelImplementationPtr =
std::make_shared<ModelT>(source, jsCallInvoker);
auto modelImplementationPtr = std::make_shared<ModelT>(
std::make_from_tuple<ModelT>(constructorArgs));
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
modelImplementationPtr, jsCallInvoker);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/jsi/OwningArrayBuffer.h>

#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
#include <rnexecutorch/models/object_detection/Constants.h>
#include <rnexecutorch/models/object_detection/Utils.h>

Expand All @@ -23,7 +23,7 @@ using namespace facebook;
template <typename T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);

template <typename T>
requires IsNumeric<T>
requires meta::IsNumeric<T>
inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) {
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
"Only integral and floating-point types are supported");
Expand Down Expand Up @@ -239,24 +239,4 @@ inline jsi::Value getJsiValue(const std::vector<Detection> &detections,
return array;
}

template <typename Model, typename R, typename... Types>
constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) {
return sizeof...(Types);
}

template <typename... Types, std::size_t... I>
std::tuple<Types...> fillTupleFromArgs(std::index_sequence<I...>,
const jsi::Value *args,
jsi::Runtime &runtime) {
return std::make_tuple(getValue<Types>(args[I], runtime)...);
}

template <typename Model, typename R, typename... Types>
std::tuple<Types...> createArgsTupleFromJsi(R (Model::*f)(Types...),
const jsi::Value *args,
jsi::Runtime &runtime) {
return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
runtime);
}

} // namespace rnexecutorch::jsiconversion
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#include <ReactCommon/CallInvoker.h>

#include <rnexecutorch/Log.h>
#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/host_objects/JsiConversions.h>
#include <rnexecutorch/jsi/JsiHostObject.h>
#include <rnexecutorch/jsi/Promise.h>
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
#include <rnexecutorch/models/BaseModel.h>

namespace rnexecutorch {
Expand All @@ -22,24 +23,24 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
explicit ModelHostObject(const std::shared_ptr<Model> &model,
std::shared_ptr<react::CallInvoker> callInvoker)
: model(model), callInvoker(callInvoker) {
if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
if constexpr (meta::DerivedFromOrSameAs<Model, BaseModel>) {
addFunctions(
JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
}

if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
if constexpr (meta::DerivedFromOrSameAs<Model, BaseModel>) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
promiseHostFunction<&Model::forwardJS>,
"forward"));
}

if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
if constexpr (meta::DerivedFromOrSameAs<Model, BaseModel>) {
addFunctions(JSI_EXPORT_FUNCTION(
ModelHostObject<Model>, promiseHostFunction<&Model::getInputShape>,
"getInputShape"));
}

if constexpr (HasGenerate<Model>) {
if constexpr (meta::HasGenerate<Model>) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
promiseHostFunction<&Model::generate>,
"generate"));
Expand All @@ -54,7 +55,7 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
runtime, callInvoker,
[this, count, args, &runtime](std::shared_ptr<Promise> promise) {
constexpr std::size_t functionArgCount =
jsiconversion::getArgumentCount(FnPtr);
meta::getArgumentCount(FnPtr);
if (functionArgCount != count) {
char errorMessage[100];
std::snprintf(
Expand All @@ -67,7 +68,7 @@ template <typename Model> class ModelHostObject : public JsiHostObject {

try {
auto argsConverted =
jsiconversion::createArgsTupleFromJsi(FnPtr, args, runtime);
meta::createArgsTupleFromJsi(FnPtr, args, runtime);

// We need to dispatch a thread if we want the function to be
// asynchronous. In this thread all accesses to jsi::Runtime need to
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#pragma once

#include <concepts>
#include <jsi/jsi.h>
#include <memory>
#include <rnexecutorch/host_objects/JsiConversions.h>
#include <tuple>
#include <type_traits>

namespace facebook::react {
class CallInvoker;
}

namespace rnexecutorch {
namespace meta {

using namespace facebook;

/**
* To be able to generically invoke constructors, we need to know what types
* are the arguments that need to be passed to it. To do this, we specialize the
* ConstructorTraits struct template for each class which constructor we want to
* use. See REGISTER_CONSTRUCTOR macro for a easy way to do this.
*
* Note: a downside for this method is that we can specialize ConstructorTraits
* only for a single constructor signature per class.
*/

template <typename T> struct ConstructorTraits;

template <typename T>
concept HasConstructorTraits =
requires { typename ConstructorTraits<T>::arg_types; };

template <typename T, typename Tuple> struct is_constructible_from_tuple;

template <typename T, typename... Args>
struct is_constructible_from_tuple<T, std::tuple<Args...>>
: std::is_constructible<T, Args...> {};

template <typename T, typename Tuple>
concept ConstructibleFromTuple = is_constructible_from_tuple<T, Tuple>::value;

template <typename NotTuple>
struct last_element_is_call_invoker : std::false_type {};

template <typename... Args>
struct last_element_is_call_invoker<std::tuple<Args...>> {
private:
template <typename Last> static constexpr bool check() {
return std::is_same_v<Last, std::shared_ptr<facebook::react::CallInvoker>>;
}

template <typename First, typename Second, typename... Rest>
static constexpr bool check_last() {
return check_last<Second, Rest...>();
}

template <typename Last> static constexpr bool check_last() {
return check<Last>();
}

public:
static constexpr bool value = sizeof...(Args) > 0 && check_last<Args...>();
};

// HasConstructorTraits<T> could be removed as typename
// ConstructorTraits<T>::arg_types would still resolve the concept to false if
// it wouldn't be defined, but we keep it for readability
template <typename T>
concept ValidConstructorTraits =
HasConstructorTraits<T> &&
ConstructibleFromTuple<T, typename ConstructorTraits<T>::arg_types>;

template <typename T>
concept CallInvokerLastInConstructor =
HasConstructorTraits<T> &&
last_element_is_call_invoker<
typename ConstructorTraits<T>::arg_types>::value;

template <typename... Types, std::size_t... I>
std::tuple<Types...> fillConstructorTupleFromArgs(
std::index_sequence<I...>, const jsi::Value *args, jsi::Runtime &runtime,
std::shared_ptr<react::CallInvoker> jsCallInvoker) {
constexpr std::size_t lastIndex = sizeof...(Types) - 1;
return std::make_tuple([&]() {
if constexpr (I == lastIndex) {
return jsCallInvoker;
} else {
return jsiconversion::getValue<Types>(args[I], runtime);
}
}()...);
}

/// @brief A method that creates a tuple of arguments based on types specified
/// in a ConstructorTraits specialization. The class has to have CallInvoker as
/// the last argument in the constructor.
/// @tparam T The class for which we want to construct the tuple
/// @param args JSI args passed from JS that will be converted according to
/// getValue<T> from JsiConversions
/// @param runtime JS runtime reference
/// @param jsCallInvoker CallInvoker that will be passed to the constructed
/// object. This is the only argument that is not created from jsi::Value.
/// @return A tuple which can then be used to instantiate the class T.
template <typename T>
requires ValidConstructorTraits<T> && CallInvokerLastInConstructor<T>
auto createConstructorArgsWithCallInvoker(
const jsi::Value *args, jsi::Runtime &runtime,
std::shared_ptr<react::CallInvoker> jsCallInvoker) {
return std::apply(
[&](auto... typeWrappers) {
return fillConstructorTupleFromArgs<decltype(typeWrappers)...>(
std::index_sequence_for<decltype(typeWrappers)...>{}, args, runtime,
jsCallInvoker);
},
typename ConstructorTraits<T>::arg_types{});
}

} // namespace meta

// A helper macro to create ConstructorTraits for a class. The variadic pack
// ("...") should list the types of the constructor arguments.
// A class declaration is added so that we don't need to include the class
// definition.
#define REGISTER_CONSTRUCTOR(Class, ...) \
class Class; \
template <> struct meta::ConstructorTraits<Class> { \
using arg_types = std::tuple<__VA_ARGS__>; \
}

} // namespace rnexecutorch
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <cstddef>
#include <jsi/jsi.h>
#include <tuple>

#include <rnexecutorch/host_objects/JsiConversions.h>

namespace rnexecutorch::meta {
using namespace facebook;

template <typename Model, typename R, typename... Types>
constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) {
return sizeof...(Types);
}

template <typename... Types, std::size_t... I>
std::tuple<Types...> fillTupleFromArgs(std::index_sequence<I...>,
const jsi::Value *args,
jsi::Runtime &runtime) {
return std::make_tuple(jsiconversion::getValue<Types>(args[I], runtime)...);
}

/**
* createArgsTupleFromJsi creates a tuple that can be used as a collection of
* arguments for method supplied with a pointer. The types in the tuple are
* inferred from the method pointer.
*/

template <typename Model, typename R, typename... Types>
std::tuple<Types...> createArgsTupleFromJsi(R (Model::*f)(Types...),
const jsi::Value *args,
jsi::Runtime &runtime) {
return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
runtime);
}
} // namespace rnexecutorch::meta
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <concepts>
#include <type_traits>

namespace rnexecutorch {
namespace rnexecutorch::meta {

template <typename T, typename Base>
concept DerivedFromOrSameAs = std::is_base_of_v<Base, T>;
Expand All @@ -16,4 +16,9 @@ concept HasGenerate = requires(T t) {
template <typename T>
concept IsNumeric = std::is_arithmetic_v<T>;

} // namespace rnexecutorch
template <typename T>
concept ProvidesMemoryLowerBound = requires(T t) {
{ &T::getMemoryLowerBound };
};

} // namespace rnexecutorch::meta