diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h index e68340e5ab..42d0a86bdc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h @@ -7,15 +7,28 @@ #include #include -#include #include #include +#include +#include +#include namespace rnexecutorch { using FetchUrlFunc_t = std::function(std::string)>; extern FetchUrlFunc_t fetchUrlFunc; +REGISTER_CONSTRUCTOR(StyleTransfer, std::string, + std::shared_ptr); +REGISTER_CONSTRUCTOR(ImageSegmentation, std::string, + std::shared_ptr); +REGISTER_CONSTRUCTOR(Classification, std::string, + std::shared_ptr); +REGISTER_CONSTRUCTOR(ObjectDetection, std::string, + std::shared_ptr); +REGISTER_CONSTRUCTOR(BaseModel, std::string, + std::shared_ptr); + using namespace facebook; class RnExecutorchInstaller { @@ -26,7 +39,10 @@ class RnExecutorchInstaller { FetchUrlFunc_t fetchDataFromUrl); private: - template ModelT> + template + requires meta::ValidConstructorTraits && + meta::CallInvokerLastInConstructor && + meta::ProvidesMemoryLowerBound static jsi::Function loadModel(jsi::Runtime *jsiRuntime, std::shared_ptr jsCallInvoker, @@ -37,20 +53,24 @@ class RnExecutorchInstaller { 0, [jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue, const jsi::Value *args, size_t count) -> jsi::Value { - if (count != 1) { + constexpr std::size_t expectedCount = std::tuple_size_v< + typename meta::ConstructorTraits::arg_types>; + // count doesn't account for the JSCallInvoker + if (count != expectedCount - 1) { char errorMessage[100]; std::snprintf( errorMessage, sizeof(errorMessage), - "Argument count mismatch, was expecting: 1 but got: %zu", - count); + "Argument count mismatch, was expecting: %zu but got: %zu", + expectedCount, count); throw jsi::JSError(runtime, errorMessage); } try { - auto source = - jsiconversion::getValue(args[0], runtime); + auto constructorArgs = + meta::createConstructorArgsWithCallInvoker( + args, runtime, jsCallInvoker); - auto modelImplementationPtr = - std::make_shared(source, jsCallInvoker); + auto modelImplementationPtr = std::make_shared( + std::make_from_tuple(constructorArgs)); auto modelHostObject = std::make_shared>( modelImplementationPtr, jsCallInvoker); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 447c0439c6..f723db5ce9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include @@ -23,7 +23,7 @@ using namespace facebook; template T getValue(const jsi::Value &val, jsi::Runtime &runtime); template - requires IsNumeric + requires meta::IsNumeric inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) { static_assert(std::is_integral::value || std::is_floating_point::value, "Only integral and floating-point types are supported"); @@ -239,24 +239,4 @@ inline jsi::Value getJsiValue(const std::vector &detections, return array; } -template -constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) { - return sizeof...(Types); -} - -template -std::tuple fillTupleFromArgs(std::index_sequence, - const jsi::Value *args, - jsi::Runtime &runtime) { - return std::make_tuple(getValue(args[I], runtime)...); -} - -template -std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...), - const jsi::Value *args, - jsi::Runtime &runtime) { - return fillTupleFromArgs(std::index_sequence_for{}, args, - runtime); -} - } // namespace rnexecutorch::jsiconversion \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index db6190c30b..95f5e01071 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -8,11 +8,12 @@ #include #include -#include #include #include #include #include +#include +#include #include namespace rnexecutorch { @@ -22,24 +23,24 @@ template class ModelHostObject : public JsiHostObject { explicit ModelHostObject(const std::shared_ptr &model, std::shared_ptr callInvoker) : model(model), callInvoker(callInvoker) { - if constexpr (DerivedFromOrSameAs) { + if constexpr (meta::DerivedFromOrSameAs) { addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } - if constexpr (DerivedFromOrSameAs) { + if constexpr (meta::DerivedFromOrSameAs) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::forwardJS>, "forward")); } - if constexpr (DerivedFromOrSameAs) { + if constexpr (meta::DerivedFromOrSameAs) { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, promiseHostFunction<&Model::getInputShape>, "getInputShape")); } - if constexpr (HasGenerate) { + if constexpr (meta::HasGenerate) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::generate>, "generate")); @@ -54,7 +55,7 @@ template class ModelHostObject : public JsiHostObject { runtime, callInvoker, [this, count, args, &runtime](std::shared_ptr promise) { constexpr std::size_t functionArgCount = - jsiconversion::getArgumentCount(FnPtr); + meta::getArgumentCount(FnPtr); if (functionArgCount != count) { char errorMessage[100]; std::snprintf( @@ -67,7 +68,7 @@ template class ModelHostObject : public JsiHostObject { try { auto argsConverted = - jsiconversion::createArgsTupleFromJsi(FnPtr, args, runtime); + meta::createArgsTupleFromJsi(FnPtr, args, runtime); // We need to dispatch a thread if we want the function to be // asynchronous. In this thread all accesses to jsi::Runtime need to diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h new file mode 100644 index 0000000000..c5b776dc5f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace facebook::react { +class CallInvoker; +} + +namespace rnexecutorch { +namespace meta { + +using namespace facebook; + +/** + * To be able to generically invoke constructors, we need to know what types + * are the arguments that need to be passed to it. To do this, we specialize the + * ConstructorTraits struct template for each class which constructor we want to + * use. See REGISTER_CONSTRUCTOR macro for a easy way to do this. + * + * Note: a downside for this method is that we can specialize ConstructorTraits + * only for a single constructor signature per class. + */ + +template struct ConstructorTraits; + +template +concept HasConstructorTraits = + requires { typename ConstructorTraits::arg_types; }; + +template struct is_constructible_from_tuple; + +template +struct is_constructible_from_tuple> + : std::is_constructible {}; + +template +concept ConstructibleFromTuple = is_constructible_from_tuple::value; + +template +struct last_element_is_call_invoker : std::false_type {}; + +template +struct last_element_is_call_invoker> { +private: + template static constexpr bool check() { + return std::is_same_v>; + } + + template + static constexpr bool check_last() { + return check_last(); + } + + template static constexpr bool check_last() { + return check(); + } + +public: + static constexpr bool value = sizeof...(Args) > 0 && check_last(); +}; + +// HasConstructorTraits could be removed as typename +// ConstructorTraits::arg_types would still resolve the concept to false if +// it wouldn't be defined, but we keep it for readability +template +concept ValidConstructorTraits = + HasConstructorTraits && + ConstructibleFromTuple::arg_types>; + +template +concept CallInvokerLastInConstructor = + HasConstructorTraits && + last_element_is_call_invoker< + typename ConstructorTraits::arg_types>::value; + +template +std::tuple fillConstructorTupleFromArgs( + std::index_sequence, const jsi::Value *args, jsi::Runtime &runtime, + std::shared_ptr jsCallInvoker) { + constexpr std::size_t lastIndex = sizeof...(Types) - 1; + return std::make_tuple([&]() { + if constexpr (I == lastIndex) { + return jsCallInvoker; + } else { + return jsiconversion::getValue(args[I], runtime); + } + }()...); +} + +/// @brief A method that creates a tuple of arguments based on types specified +/// in a ConstructorTraits specialization. The class has to have CallInvoker as +/// the last argument in the constructor. +/// @tparam T The class for which we want to construct the tuple +/// @param args JSI args passed from JS that will be converted according to +/// getValue from JsiConversions +/// @param runtime JS runtime reference +/// @param jsCallInvoker CallInvoker that will be passed to the constructed +/// object. This is the only argument that is not created from jsi::Value. +/// @return A tuple which can then be used to instantiate the class T. +template + requires ValidConstructorTraits && CallInvokerLastInConstructor +auto createConstructorArgsWithCallInvoker( + const jsi::Value *args, jsi::Runtime &runtime, + std::shared_ptr jsCallInvoker) { + return std::apply( + [&](auto... typeWrappers) { + return fillConstructorTupleFromArgs( + std::index_sequence_for{}, args, runtime, + jsCallInvoker); + }, + typename ConstructorTraits::arg_types{}); +} + +} // namespace meta + +// A helper macro to create ConstructorTraits for a class. The variadic pack +// ("...") should list the types of the constructor arguments. +// A class declaration is added so that we don't need to include the class +// definition. +#define REGISTER_CONSTRUCTOR(Class, ...) \ + class Class; \ + template <> struct meta::ConstructorTraits { \ + using arg_types = std::tuple<__VA_ARGS__>; \ + } + +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h new file mode 100644 index 0000000000..ff7a5fa53c --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +#include + +namespace rnexecutorch::meta { +using namespace facebook; + +template +constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) { + return sizeof...(Types); +} + +template +std::tuple fillTupleFromArgs(std::index_sequence, + const jsi::Value *args, + jsi::Runtime &runtime) { + return std::make_tuple(jsiconversion::getValue(args[I], runtime)...); +} + +/** + * createArgsTupleFromJsi creates a tuple that can be used as a collection of + * arguments for method supplied with a pointer. The types in the tuple are + * inferred from the method pointer. + */ + +template +std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...), + const jsi::Value *args, + jsi::Runtime &runtime) { + return fillTupleFromArgs(std::index_sequence_for{}, args, + runtime); +} +} // namespace rnexecutorch::meta \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h similarity index 64% rename from packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h rename to packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index b7414a5a48..ae5111ba37 100644 --- a/packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -3,7 +3,7 @@ #include #include -namespace rnexecutorch { +namespace rnexecutorch::meta { template concept DerivedFromOrSameAs = std::is_base_of_v; @@ -16,4 +16,9 @@ concept HasGenerate = requires(T t) { template concept IsNumeric = std::is_arithmetic_v; -} // namespace rnexecutorch \ No newline at end of file +template +concept ProvidesMemoryLowerBound = requires(T t) { + { &T::getMemoryLowerBound }; +}; + +} // namespace rnexecutorch::meta \ No newline at end of file