diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index dcff11d2f6..b82d24316e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -37,5 +37,11 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadObjectDetection", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadObjectDetection")); + + jsiRuntime->global().setProperty( + *jsiRuntime, "loadExecutorchModule", + RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, + "loadExecutorchModule")); } + } // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h index f7e105909f..e68340e5ab 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h @@ -7,7 +7,7 @@ #include #include -#include +#include #include #include @@ -26,7 +26,7 @@ class RnExecutorchInstaller { FetchUrlFunc_t fetchDataFromUrl); private: - template + template ModelT> static jsi::Function loadModel(jsi::Runtime *jsiRuntime, std::shared_ptr jsCallInvoker, diff --git a/packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h new file mode 100644 index 0000000000..b7414a5a48 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/TypeConcepts.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace rnexecutorch { + +template +concept DerivedFromOrSameAs = std::is_base_of_v; + +template +concept HasGenerate = requires(T t) { + { &T::generate }; +}; + +template +concept IsNumeric = std::is_arithmetic_v; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/TypeConstraints.h b/packages/react-native-executorch/common/rnexecutorch/TypeConstraints.h deleted file mode 100644 index 45e1649a82..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/TypeConstraints.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -#include - -namespace rnexecutorch { - -template -concept DerivedFromBaseModel = std::is_base_of_v; - -} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h new file mode 100644 index 0000000000..4057950b23 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h @@ -0,0 +1,12 @@ +#pragma once + +namespace rnexecutorch { + +using executorch::aten::ScalarType; + +struct JSTensorViewIn { + void *dataPtr; + std::vector sizes; + ScalarType scalarType; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h new file mode 100644 index 0000000000..4ea4fefc1c --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch { + +using executorch::runtime::etensor::ScalarType; + +struct JSTensorViewOut { + std::shared_ptr dataPtr; + std::vector sizes; + ScalarType scalarType; + + JSTensorViewOut(std::vector sizes, ScalarType scalarType, + std::shared_ptr dataPtr) + : sizes(std::move(sizes)), scalarType(scalarType), + dataPtr(std::move(dataPtr)) {} +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index c10f3c1824..447c0439c6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -4,8 +4,13 @@ #include #include +#include #include +#include +#include +#include +#include #include #include @@ -17,9 +22,12 @@ using namespace facebook; template T getValue(const jsi::Value &val, jsi::Runtime &runtime); -template <> -inline double getValue(const jsi::Value &val, jsi::Runtime &runtime) { - return val.asNumber(); +template + requires IsNumeric +inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) { + static_assert(std::is_integral::value || std::is_floating_point::value, + "Only integral and floating-point types are supported"); + return static_cast(val.asNumber()); } template <> @@ -33,6 +41,78 @@ inline std::string getValue(const jsi::Value &val, return val.getString(runtime).utf8(runtime); } +template <> +inline JSTensorViewIn getValue(const jsi::Value &val, + jsi::Runtime &runtime) { + jsi::Object obj = val.asObject(runtime); + JSTensorViewIn tensorView; + + int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber(); + tensorView.scalarType = static_cast(scalarTypeInt); + + jsi::Value shapeValue = obj.getProperty(runtime, "sizes"); + jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime); + size_t numShapeDims = shapeArray.size(runtime); + tensorView.sizes.reserve(numShapeDims); + + for (size_t i = 0; i < numShapeDims; ++i) { + int dim = getValue(shapeArray.getValueAtIndex(runtime, i), runtime); + tensorView.sizes.push_back(static_cast(dim)); + } + + // On JS side, TensorPtr objects hold a 'data' property which should be either + // an ArrayBuffer or TypedArray + jsi::Value dataValue = obj.getProperty(runtime, "dataPtr"); + jsi::Object dataObj = dataValue.asObject(runtime); + + // Check if it's an ArrayBuffer or TypedArray + if (dataObj.isArrayBuffer(runtime)) { + jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime); + tensorView.dataPtr = arrayBuffer.data(runtime); + + } else { + // Handle typed arrays (Float32Array, Int32Array, etc.) + const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") && + dataObj.hasProperty(runtime, "byteOffset") && + dataObj.hasProperty(runtime, "byteLength") && + dataObj.hasProperty(runtime, "length"); + if (!isValidTypedArray) { + throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray"); + } + jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer"); + if (!bufferValue.isObject() || + !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { + throw jsi::JSError(runtime, + "TypedArray buffer property must be an ArrayBuffer"); + } + + jsi::ArrayBuffer arrayBuffer = + bufferValue.asObject(runtime).getArrayBuffer(runtime); + size_t byteOffset = + getValue(dataObj.getProperty(runtime, "byteOffset"), runtime); + + tensorView.dataPtr = + static_cast(arrayBuffer.data(runtime)) + byteOffset; + } + return tensorView; +} + +template <> +inline std::vector +getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { + jsi::Array array = val.asObject(runtime).asArray(runtime); + size_t length = array.size(runtime); + std::vector result; + result.reserve(length); + + for (size_t i = 0; i < length; ++i) { + jsi::Value element = array.getValueAtIndex(runtime, i); + result.push_back(getValue(element, runtime)); + } + return result; +} + template <> inline std::vector getValue>(const jsi::Value &val, @@ -78,6 +158,51 @@ inline jsi::Value getJsiValue(std::shared_ptr valuePtr, return std::move(*valuePtr); } +inline jsi::Value getJsiValue(const std::vector &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); + } + return jsi::Value(runtime, array); +} + +inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { + return jsi::Value(runtime, val); +} + +inline jsi::Value +getJsiValue(const std::vector> &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + jsi::ArrayBuffer arrayBuffer(runtime, vec[i]); + array.setValueAtIndex(runtime, i, jsi::Value(runtime, arrayBuffer)); + } + return jsi::Value(runtime, array); +} + +inline jsi::Value +getJsiValue(const std::vector> &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + jsi::Object tensorObj(runtime); + + tensorObj.setProperty(runtime, "sizes", + getJsiValue(vec[i]->sizes, runtime)); + + tensorObj.setProperty(runtime, "scalarType", + jsi::Value(static_cast(vec[i]->scalarType))); + + jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr); + tensorObj.setProperty(runtime, "dataPtr", arrayBuffer); + + array.setValueAtIndex(runtime, i, tensorObj); + } + return jsi::Value(runtime, array); +} + inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) { return jsi::String::createFromAscii(runtime, str); } diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index e233471583..db6190c30b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -8,10 +8,12 @@ #include #include -#include +#include +#include #include #include #include +#include namespace rnexecutorch { @@ -20,13 +22,28 @@ template class ModelHostObject : public JsiHostObject { explicit ModelHostObject(const std::shared_ptr &model, std::shared_ptr callInvoker) : model(model), callInvoker(callInvoker) { - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::forward>, - "forward")); - if constexpr (DerivedFromBaseModel) { + if constexpr (DerivedFromOrSameAs) { addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } + + if constexpr (DerivedFromOrSameAs) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::forwardJS>, + "forward")); + } + + if constexpr (DerivedFromOrSameAs) { + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, promiseHostFunction<&Model::getInputShape>, + "getInputShape")); + } + + if constexpr (HasGenerate) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generate>, + "generate")); + } } // A generic host function that resolves a promise with a result of a diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp index 60b6f73beb..d4f846518d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp @@ -2,12 +2,12 @@ #include -#include +#include namespace rnexecutorch { using namespace facebook; -using ::executorch::extension::Module; +using namespace executorch::extension; using ::executorch::runtime::Error; BaseModel::BaseModel(const std::string &modelSource, @@ -17,8 +17,8 @@ BaseModel::BaseModel(const std::string &modelSource, modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)) { Error loadError = module->load(); if (loadError != Error::Ok) { - throw std::runtime_error("Couldn't load the model, error: " + - std::to_string(static_cast(loadError))); + throw std::runtime_error("Failed to load model: Error " + + std::to_string(static_cast(loadError))); } // We use the size of the model .pte file as the lower bound for the memory // occupied by the ET module. This is not the whole size however, the module @@ -28,14 +28,43 @@ BaseModel::BaseModel(const std::string &modelSource, memorySizeLowerBound = std::filesystem::file_size(modelPath); } -std::vector> BaseModel::getInputShape() { +std::vector BaseModel::getInputShape(std::string method_name, + int index) { if (!module) { - throw std::runtime_error("getInputShape called on unloaded model"); + throw std::runtime_error("Model not loaded: Cannot get input shape"); } - auto method_meta = module->method_meta("forward"); + auto method_meta = module->method_meta(method_name); if (!method_meta.ok()) { - throw std::runtime_error("Failed to load forward"); + throw std::runtime_error( + "Failed to get metadata for method '" + method_name + "': Error " + + std::to_string(static_cast(method_meta.error()))); + } + + auto input_meta = method_meta->input_tensor_meta(index); + if (!input_meta.ok()) { + throw std::runtime_error( + "Failed to get metadata for input tensor at index " + + std::to_string(index) + " in method '" + method_name + "': Error " + + std::to_string(static_cast(input_meta.error()))); + } + + auto sizes = input_meta->sizes(); + std::vector input_shape(sizes.begin(), sizes.end()); + return input_shape; +} + +std::vector> +BaseModel::getAllInputShapes(std::string methodName) { + if (!module) { + throw std::runtime_error("Model not loaded: Cannot get all input shapes"); + } + + auto method_meta = module->method_meta(methodName); + if (!method_meta.ok()) { + throw std::runtime_error( + "Failed to get metadata for method '" + methodName + "': Error " + + std::to_string(static_cast(method_meta.error()))); } std::vector> output; std::size_t numInputs = method_meta->num_inputs(); @@ -43,7 +72,10 @@ std::vector> BaseModel::getInputShape() { for (std::size_t input = 0; input < numInputs; ++input) { auto input_meta = method_meta->input_tensor_meta(input); if (!input_meta.ok()) { - throw std::runtime_error("Failed to load forward input"); + throw std::runtime_error( + "Failed to get metadata for input tensor at index " + + std::to_string(input) + " in method '" + methodName + "': Error " + + std::to_string(static_cast(input_meta.error()))); } auto shape = input_meta->sizes(); output.emplace_back(std::vector(shape.begin(), shape.end())); @@ -51,15 +83,78 @@ std::vector> BaseModel::getInputShape() { return output; } -std::size_t BaseModel::getMemoryLowerBound() { return memorySizeLowerBound; } +std::vector> +BaseModel::forwardJS(const std::vector tensorViewVec) { + if (!module) { + throw std::runtime_error("Model not loaded: Cannot perform forward pass"); + } + std::vector evalues; + evalues.reserve(tensorViewVec.size()); + // Because EValue doesn't hold to the dynamic data and metadata from + // TensorPtr, we need to make sure that the TensorPtr for each EValue is valid + // as long as that EValue is in use. Therefore we create a vec solely for + // keeping references to the TensorPtr + std::vector tensorPtrs; + tensorPtrs.reserve(evalues.size()); -void BaseModel::unload() { module.reset(nullptr); } + for (size_t i = 0; i < tensorViewVec.size(); i++) { + const auto &currTensorView = tensorViewVec[i]; + auto tensorPtr = + make_tensor_ptr(currTensorView.sizes, currTensorView.dataPtr, + currTensorView.scalarType); + tensorPtrs.emplace_back(tensorPtr); + evalues.emplace_back(*tensorPtr); // Dereference TensorPtr to get Tensor, + // which implicitly converts to EValue + } + + auto result = module->forward(evalues); + if (!result.ok()) { + throw std::runtime_error("Forward pass failed: Error " + + std::to_string(static_cast(result.error()))); + } + + auto &outputs = result.get(); + std::vector> output; + output.reserve(outputs.size()); + + // Convert ET outputs to a vector of JSTensorViewOut which are later + // converted to JSI types via JsiConversions.h + for (size_t i = 0; i < outputs.size(); i++) { + auto &outputTensor = outputs[i].toTensor(); + std::vector sizes = getTensorShape(outputTensor); + size_t bufferSize = outputTensor.numel() * outputTensor.element_size(); + auto buffer = std::make_shared(bufferSize); + std::memcpy(buffer->data(), outputTensor.const_data_ptr(), bufferSize); + auto jsTensor = std::make_shared( + sizes, outputTensor.scalar_type(), buffer); + output.emplace_back(jsTensor); + } + return output; +} + +Result> BaseModel::forward(const EValue &input_evalue) { + if (!module) { + throw std::runtime_error("Model not loaded: Cannot perform forward pass"); + } + return module->forward(input_evalue); +} -Result> BaseModel::forwardET(const EValue &input_value) { +Result> +BaseModel::forward(const std::vector &input_evalues) { if (!module) { - throw std::runtime_error("Forward called on unloaded model"); + throw std::runtime_error("Model not loaded: Cannot perform forward pass"); } - return module->forward(input_value); + return module->forward(input_evalues); +} + +std::size_t BaseModel::getMemoryLowerBound() { return memorySizeLowerBound; } + +void BaseModel::unload() { module.reset(nullptr); } + +std::vector +BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) { + auto sizes = tensor.sizes(); + return std::vector(sizes.begin(), sizes.end()); } } // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index de7e57d77a..79e426f199 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -6,6 +6,8 @@ #include #include #include +#include +#include namespace rnexecutorch { using namespace facebook; @@ -15,20 +17,26 @@ class BaseModel { public: BaseModel(const std::string &modelSource, std::shared_ptr callInvoker); - std::vector> getInputShape(); std::size_t getMemoryLowerBound(); void unload(); + std::vector getInputShape(std::string method_name, int index); + std::vector> + getAllInputShapes(std::string methodName = "forward"); + std::vector> + forwardJS(std::vector tensorViewVec); protected: - Result> forwardET(const EValue &input_value); + Result> forward(const EValue &input_value); + Result> forward(const std::vector &input_value); // If possible, models should not use the JS runtime to keep JSI internals // away from logic, however, sometimes this would incur too big of a penalty // (unnecessary copies instead of working on JS memory). In this case // CallInvoker can be used to get jsi::Runtime, and use it in a safe manner. std::shared_ptr callInvoker; - std::size_t memorySizeLowerBound{0}; private: + std::size_t memorySizeLowerBound{0}; std::unique_ptr module; + std::vector getTensorShape(const executorch::aten::Tensor &tensor); }; } // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 69924f878c..c6b843f66d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -11,7 +11,7 @@ namespace rnexecutorch { Classification::Classification(const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) { - auto inputShapes = getInputShape(); + auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw std::runtime_error("Model seems to not take any input tensors."); } @@ -29,11 +29,10 @@ Classification::Classification(const std::string &modelSource, } std::unordered_map -Classification::forward(std::string imageSource) { - auto inputTensor = - imageprocessing::readImageToTensor(imageSource, getInputShape()[0]).first; - - auto forwardResult = forwardET(inputTensor); +Classification::generate(std::string imageSource) { + auto inputTensor = + imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]).first; + auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { throw std::runtime_error( "Failed to forward, error: " + diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h index 75f9834610..1b0950e9a4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h @@ -15,7 +15,7 @@ class Classification : public BaseModel { public: Classification(const std::string &modelSource, std::shared_ptr callInvoker); - std::unordered_map forward(std::string imageSource); + std::unordered_map generate(std::string imageSource); private: std::unordered_map postprocess(const Tensor &tensor); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp index 237387cbb6..66f6ea4471 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp @@ -15,7 +15,7 @@ ImageSegmentation::ImageSegmentation( const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) { - auto inputShapes = getInputShape(); + auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw std::runtime_error("Model seems to not take any input tensors."); } @@ -33,14 +33,13 @@ ImageSegmentation::ImageSegmentation( numModelPixels = modelImageSize.area(); } -std::shared_ptr -ImageSegmentation::forward(std::string imageSource, - std::set> classesOfInterest, - bool resize) { +std::shared_ptr ImageSegmentation::generate( + std::string imageSource, + std::set> classesOfInterest, bool resize) { auto [inputTensor, originalSize] = - imageprocessing::readImageToTensor(imageSource, getInputShape()[0]); + imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]); - auto forwardResult = forwardET(inputTensor); + auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { throw std::runtime_error( "Failed to forward, error: " + diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h index 42fe58d5d9..a3368ed081 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h @@ -22,8 +22,8 @@ class ImageSegmentation : public BaseModel { ImageSegmentation(const std::string &modelSource, std::shared_ptr callInvoker); std::shared_ptr - forward(std::string imageSource, - std::set> classesOfInterest, bool resize); + generate(std::string imageSource, + std::set> classesOfInterest, bool resize); private: std::shared_ptr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 023d13494b..e776612dff 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -8,7 +8,7 @@ ObjectDetection::ObjectDetection( const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) { - auto inputTensors = getInputShape(); + auto inputTensors = getAllInputShapes(); if (inputTensors.size() == 0) { throw std::runtime_error("Model seems to not take any input tensors."); } @@ -65,12 +65,12 @@ ObjectDetection::postprocess(const std::vector &tensors, return output; } -std::vector ObjectDetection::forward(std::string imageSource, - double detectionThreshold) { +std::vector ObjectDetection::generate(std::string imageSource, + double detectionThreshold) { auto [inputTensor, originalSize] = - imageprocessing::readImageToTensor(imageSource, getInputShape()[0]); + imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]); - auto forwardResult = forwardET(inputTensor); + auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { throw std::runtime_error( "Failed to forward, error: " + diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index ded334be13..2f63c7148a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -17,8 +17,8 @@ class ObjectDetection : public BaseModel { public: ObjectDetection(const std::string &modelSource, std::shared_ptr callInvoker); - std::vector forward(std::string imageSource, - double detectionThreshold); + std::vector generate(std::string imageSource, + double detectionThreshold); private: std::vector postprocess(const std::vector &tensors, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index a352984da1..52e3bbece2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -16,7 +16,7 @@ using executorch::runtime::Error; StyleTransfer::StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) { - auto inputShapes = getInputShape(); + auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw std::runtime_error("Model seems to not take any input tensors."); } @@ -41,11 +41,11 @@ std::string StyleTransfer::postprocess(const Tensor &tensor, return imageprocessing::saveToTempFile(mat); } -std::string StyleTransfer::forward(std::string imageSource) { - auto [inputTensor, originalSize] = - imageprocessing::readImageToTensor(imageSource, getInputShape()[0]); +std::string StyleTransfer::generate(std::string imageSource) { + auto [inputTensor, originalSize] = + imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]); - auto forwardResult = forwardET(inputTensor); + auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { throw std::runtime_error( "Failed to forward, error: " + diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 2e4c99f0e6..c9da1308e4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -19,7 +19,7 @@ class StyleTransfer : public BaseModel { public: StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker); - std::string forward(std::string imageSource); + std::string generate(std::string imageSource); private: std::string postprocess(const Tensor &tensor, cv::Size originalSize); diff --git a/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts b/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts index c7b1b501c9..42826a07a5 100644 --- a/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts +++ b/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts @@ -1,6 +1,6 @@ import { ExecutorchModule } from '../../modules/general/ExecutorchModule'; import { ResourceSource } from '../../types/common'; -import { useModule } from '../useModule'; +import { useNonStaticModule } from '../useNonStaticModule'; interface Props { modelSource: ResourceSource; @@ -11,7 +11,7 @@ export const useExecutorchModule = ({ modelSource, preventLoad = false, }: Props) => - useModule({ + useNonStaticModule({ module: ExecutorchModule, loadArgs: [modelSource], preventLoad, diff --git a/packages/react-native-executorch/src/index.tsx b/packages/react-native-executorch/src/index.tsx index 6721f4a12c..cb6d2ca24b 100644 --- a/packages/react-native-executorch/src/index.tsx +++ b/packages/react-native-executorch/src/index.tsx @@ -8,16 +8,20 @@ declare global { var loadImageSegmentation: (source: string) => any; var loadClassification: (source: string) => any; var loadObjectDetection: (source: string) => any; + var loadExecutorchModule: (source: string) => any; } // eslint-disable no-var - -if (global.loadStyleTransfer == null) { +if ( + global.loadStyleTransfer == null || + global.loadImageSegmentation == null || + global.loadExecutorchModule == null || + global.loadClassification == null +) { if (!ETInstallerNativeModule) { throw new Error( `Failed to install react-native-executorch: The native module could not be found.` ); } - ETInstallerNativeModule.install(); } @@ -43,14 +47,13 @@ export * from './modules/computer_vision/StyleTransferModule'; export * from './modules/computer_vision/ImageSegmentationModule'; export * from './modules/computer_vision/OCRModule'; export * from './modules/computer_vision/VerticalOCRModule'; +export * from './modules/general/ExecutorchModule'; export * from './modules/natural_language_processing/LLMModule'; export * from './modules/natural_language_processing/SpeechToTextModule'; export * from './modules/natural_language_processing/TextEmbeddingsModule'; export * from './modules/natural_language_processing/TokenizerModule'; -export * from './modules/general/ExecutorchModule'; - // utils export * from './utils/ResourceFetcher'; diff --git a/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts b/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts index 028a1550a8..9ad6b41c13 100644 --- a/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts +++ b/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts @@ -1,5 +1,23 @@ -export class BaseNonStaticModule { +import { ResourceSource } from '../types/common'; +import { TensorPtr } from '../types/common'; + +export abstract class BaseNonStaticModule { nativeModule: any = null; + + abstract load( + modelSource: ResourceSource, + onDownloadProgressCallback: (_: number) => void, + ...args: any[] + ): Promise; + + protected async forwardET(inputTensor: TensorPtr[]): Promise { + return await this.nativeModule.forward(inputTensor); + } + + async getInputShape(methodName: string, index: number): Promise { + return this.nativeModule.getInputShape(methodName, index); + } + delete() { if (this.nativeModule !== null) { this.nativeModule.unload(); diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts index 39c74034dd..1211811cf7 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts @@ -18,6 +18,6 @@ export class ClassificationModule extends BaseNonStaticModule { async forward(imageSource: string) { if (this.nativeModule == null) throw new Error(getError(ETError.ModuleNotLoaded)); - return await this.nativeModule.forward(imageSource); + return await this.nativeModule.generate(imageSource); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index 83a33d3917..d2e1757a85 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -25,7 +25,7 @@ export class ImageSegmentationModule extends BaseNonStaticModule { throw new Error(getError(ETError.ModuleNotLoaded)); } - const stringDict = await this.nativeModule.forward( + const stringDict = await this.nativeModule.generate( imageSource, (classesOfInterest || []).map((label) => DeeplabLabel[label]), resize || false diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index 344943686d..abade93bba 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -22,6 +22,6 @@ export class ObjectDetectionModule extends BaseNonStaticModule { ): Promise { if (this.nativeModule == null) throw new Error(getError(ETError.ModuleNotLoaded)); - return await this.nativeModule.forward(imageSource, detectionThreshold); + return await this.nativeModule.generate(imageSource, detectionThreshold); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 0824f74235..7ce10a7a8d 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -18,6 +18,6 @@ export class StyleTransferModule extends BaseNonStaticModule { async forward(imageSource: string): Promise { if (this.nativeModule == null) throw new Error(getError(ETError.ModuleNotLoaded)); - return await this.nativeModule.forward(imageSource); + return await this.nativeModule.generate(imageSource); } } diff --git a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts index c356c17be5..0fc19c6976 100644 --- a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts +++ b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts @@ -1,54 +1,21 @@ -import { ETError, getError } from '../../Error'; -import { ETModuleNativeModule } from '../../native/RnExecutorchModules'; +import { TensorPtr } from '../../types/common'; +import { BaseNonStaticModule } from '../BaseNonStaticModule'; import { ResourceSource } from '../../types/common'; -import { ETInput } from '../../types/common'; -import { getTypeIdentifier } from '../../types/common'; -import { BaseModule } from '../BaseModule'; - -export class ExecutorchModule extends BaseModule { - protected static override nativeModule = ETModuleNativeModule; - - static override async load(modelSource: ResourceSource) { - return await super.load(modelSource); - } - - static override async forward(input: ETInput[] | ETInput, shape: number[][]) { - if (!Array.isArray(input)) { - input = [input]; - } - - let inputTypeIdentifiers = []; - let modelInputs = []; - - for (let idx = 0; idx < input.length; idx++) { - let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput); - if (currentInputTypeIdentifier === -1) { - throw new Error(getError(ETError.InvalidArgument)); - } - inputTypeIdentifiers.push(currentInputTypeIdentifier); - modelInputs.push([...(input[idx] as unknown as number[])]); - } - - try { - return await this.nativeModule.forward( - modelInputs, - shape, - inputTypeIdentifiers - ); - } catch (e) { - throw new Error(getError(e)); - } - } - - static async loadMethod(methodName: string) { - try { - await this.nativeModule.loadMethod(methodName); - } catch (e) { - throw new Error(getError(e)); - } +import { ResourceFetcher } from '../../utils/ResourceFetcher'; + +export class ExecutorchModule extends BaseNonStaticModule { + async load( + modelSource: ResourceSource, + onDownloadProgressCallback: (_: number) => void = () => {} + ): Promise { + const paths = await ResourceFetcher.fetchMultipleResources( + onDownloadProgressCallback, + modelSource + ); + this.nativeModule = global.loadExecutorchModule(paths[0] || ''); } - static async loadForward() { - await this.loadMethod('forward'); + async forward(inputTensor: TensorPtr[]): Promise { + return await this.forwardET(inputTensor); } } diff --git a/packages/react-native-executorch/src/types/common.ts b/packages/react-native-executorch/src/types/common.ts index 688ac3869e..3b375c338f 100644 --- a/packages/react-native-executorch/src/types/common.ts +++ b/packages/react-native-executorch/src/types/common.ts @@ -15,3 +15,47 @@ export type ETInput = | BigInt64Array | Float32Array | Float64Array; + +export enum ScalarType { + BYTE = 0, + CHAR = 1, + SHORT = 2, + INT = 3, + LONG = 4, + HALF = 5, + FLOAT = 6, + DOUBLE = 7, + BOOL = 11, + QINT8 = 12, + QUINT8 = 13, + QINT32 = 14, + QUINT4X2 = 16, + QUINT2X4 = 17, + BITS16 = 22, + FLOAT8E5M2 = 23, + FLOAT8E4M3FN = 24, + FLOAT8E5M2FNUZ = 25, + FLOAT8E4M3FNUZ = 26, + UINT16 = 27, + UINT32 = 28, + UINT64 = 29, +} + +export type TensorBuffer = + | ArrayBuffer + | Float32Array + | Float64Array + | Int8Array + | Int16Array + | Int32Array + | Uint8Array + | Uint16Array + | Uint32Array + | BigInt64Array + | BigUint64Array; + +export interface TensorPtr { + dataPtr: TensorBuffer; + sizes: number[]; + scalarType: ScalarType; +}