diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt b/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt deleted file mode 100644 index c18fa8ed32..0000000000 --- a/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt +++ /dev/null @@ -1,58 +0,0 @@ -package com.swmansion.rnexecutorch - -import android.util.Log -import com.facebook.react.bridge.Promise -import com.facebook.react.bridge.ReactApplicationContext -import com.facebook.react.bridge.ReadableArray -import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel -import com.swmansion.rnexecutorch.utils.ETError -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.android.OpenCVLoader - -class ImageSegmentation( - reactContext: ReactApplicationContext, -) : NativeImageSegmentationSpec(reactContext) { - private lateinit var model: ImageSegmentationModel - - companion object { - const val NAME = "ImageSegmentation" - - init { - if (!OpenCVLoader.initLocal()) { - Log.d("rn_executorch", "OpenCV not loaded") - } else { - Log.d("rn_executorch", "OpenCV loaded") - } - } - } - - override fun loadModule( - modelSource: String, - promise: Promise, - ) { - try { - model = ImageSegmentationModel(reactApplicationContext) - model.loadModel(modelSource) - promise.resolve(0) - } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelSource.toString()) - } - } - - override fun forward( - input: String, - classesOfInterest: ReadableArray, - resize: Boolean, - promise: Promise, - ) { - try { - val output = - model.runModel(Triple(ImageProcessor.readImage(input), classesOfInterest, resize)) - promise.resolve(output) - } catch (e: Exception) { - promise.reject(e.message!!, e.message) - } - } - - override fun getName(): String = NAME -} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index 
f90a182b5b..ee289e7a1d 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -28,8 +28,6 @@ class RnExecutorchPackage : TurboReactPackage() { OCR(reactContext) } else if (name == VerticalOCR.NAME) { VerticalOCR(reactContext) - } else if (name == ImageSegmentation.NAME) { - ImageSegmentation(reactContext) } else if (name == ETInstaller.NAME) { ETInstaller(reactContext) } else if (name == Tokenizer.NAME) { @@ -119,17 +117,6 @@ class RnExecutorchPackage : TurboReactPackage() { true, ) - moduleInfos[ImageSegmentation.NAME] = - ReactModuleInfo( - ImageSegmentation.NAME, - ImageSegmentation.NAME, - false, // canOverrideExistingModule - false, // needsEagerInit - true, // hasConstants - false, // isCxxModule - true, - ) - moduleInfos[Tokenizer.NAME] = ReactModuleInfo( Tokenizer.NAME, diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt deleted file mode 100644 index 7ba7fcb5c1..0000000000 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt +++ /dev/null @@ -1,26 +0,0 @@ -package com.swmansion.rnexecutorch.models.imagesegmentation - -val deeplabv3_resnet50_labels: Array = - arrayOf( - "BACKGROUND", - "AEROPLANE", - "BICYCLE", - "BIRD", - "BOAT", - "BOTTLE", - "BUS", - "CAR", - "CAT", - "CHAIR", - "COW", - "DININGTABLE", - "DOG", - "HORSE", - "MOTORBIKE", - "PERSON", - "POTTEDPLANT", - "SHEEP", - "SOFA", - "TRAIN", - "TVMONITOR", - ) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt deleted file mode 100644 index 36c1594b49..0000000000 --- 
a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt +++ /dev/null @@ -1,139 +0,0 @@ -package com.swmansion.rnexecutorch.models.imagesegmentation - -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.ReactApplicationContext -import com.facebook.react.bridge.ReadableArray -import com.facebook.react.bridge.WritableMap -import com.swmansion.rnexecutorch.models.BaseModel -import com.swmansion.rnexecutorch.utils.ArrayUtils -import com.swmansion.rnexecutorch.utils.ImageProcessor -import com.swmansion.rnexecutorch.utils.softmax -import org.opencv.core.CvType -import org.opencv.core.Mat -import org.opencv.core.Size -import org.opencv.imgproc.Imgproc -import org.pytorch.executorch.EValue - -class ImageSegmentationModel( - reactApplicationContext: ReactApplicationContext, -) : BaseModel, WritableMap>(reactApplicationContext) { - private lateinit var originalSize: Size - - private fun getModelImageSize(): Size { - val inputShape = module.getInputShape(0) - val width = inputShape[inputShape.lastIndex] - val height = inputShape[inputShape.lastIndex - 1] - - return Size(height.toDouble(), width.toDouble()) - } - - fun preprocess(input: Mat): EValue { - originalSize = input.size() - Imgproc.resize(input, input, getModelImageSize()) - return ImageProcessor.matToEValue(input, module.getInputShape(0)) - } - - private fun extractResults( - result: FloatArray, - numLabels: Int, - resize: Boolean, - ): List { - val modelSize = getModelImageSize() - val numModelPixels = (modelSize.height * modelSize.width).toInt() - - val extractedLabelScores = mutableListOf() - - for (label in 0.., - numLabels: Int, - outputSize: Size, - ): IntArray { - val numPixels = (outputSize.height * outputSize.width).toInt() - val argMax = IntArray(numPixels) - for (pixel in 0..() - for (buffer in labelScores) { - scores.add(buffer[pixel]) - } - val adjustedScores = softmax(scores.toTypedArray()) - for (label in 0.., - classesOfInterest: 
ReadableArray, - resize: Boolean, - ): WritableMap { - val outputData = output[0].toTensor().dataAsFloatArray - val modelSize = getModelImageSize() - val numLabels = deeplabv3_resnet50_labels.size - - require(outputData.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." } - - val outputSize = if (resize) originalSize else modelSize - - val extractedResults = extractResults(outputData, numLabels, resize) - - val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize) - - val labelSet = mutableSetOf() - // Filter by the label set when base class changed - for (i in 0..): WritableMap { - val modelInput = preprocess(input.first) - val modelOutput = forward(modelInput) - return postprocess(modelOutput, input.second, input.third) - } -} diff --git a/common/rnexecutorch/RnExecutorchInstaller.cpp b/common/rnexecutorch/RnExecutorchInstaller.cpp index 5a655be862..bdfdab231a 100644 --- a/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -1,6 +1,7 @@ #include "RnExecutorchInstaller.h" #include +#include #include namespace rnexecutorch { @@ -19,5 +20,10 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadStyleTransfer", RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, "loadStyleTransfer")); + + jsiRuntime->global().setProperty( + *jsiRuntime, "loadImageSegmentation", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadImageSegmentation")); } } // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/RnExecutorchInstaller.h b/common/rnexecutorch/RnExecutorchInstaller.h index 8d8a712ba9..3acb45c9b1 100644 --- a/common/rnexecutorch/RnExecutorchInstaller.h +++ b/common/rnexecutorch/RnExecutorchInstaller.h @@ -49,7 +49,7 @@ class RnExecutorchInstaller { jsiconversion::getValue(args[0], runtime); auto modelImplementationPtr = - std::make_shared(source, &runtime); + 
std::make_shared(source, jsCallInvoker); auto modelHostObject = std::make_shared>( modelImplementationPtr, jsCallInvoker); diff --git a/common/rnexecutorch/data_processing/Numerical.cpp b/common/rnexecutorch/data_processing/Numerical.cpp new file mode 100644 index 0000000000..618683f007 --- /dev/null +++ b/common/rnexecutorch/data_processing/Numerical.cpp @@ -0,0 +1,19 @@ +#include "Numerical.h" + +#include +#include + +namespace rnexecutorch::numerical { +void softmax(std::vector &v) { + float max = *std::max_element(v.begin(), v.end()); + + float sum = 0.0f; + for (float &x : v) { + x = std::exp(x - max); + sum += x; + } + for (float &x : v) { + x /= sum; + } +} +} // namespace rnexecutorch::numerical \ No newline at end of file diff --git a/common/rnexecutorch/data_processing/Numerical.h b/common/rnexecutorch/data_processing/Numerical.h new file mode 100644 index 0000000000..66d96b51ee --- /dev/null +++ b/common/rnexecutorch/data_processing/Numerical.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace rnexecutorch::numerical { +void softmax(std::vector &v); +} // namespace rnexecutorch::numerical \ No newline at end of file diff --git a/common/rnexecutorch/host_objects/JsiConversions.h b/common/rnexecutorch/host_objects/JsiConversions.h index c9dad8f834..4cd1c43dcd 100644 --- a/common/rnexecutorch/host_objects/JsiConversions.h +++ b/common/rnexecutorch/host_objects/JsiConversions.h @@ -1,8 +1,10 @@ #pragma once -#include +#include #include +#include + namespace rnexecutorch::jsiconversion { using namespace facebook; @@ -43,19 +45,33 @@ getValue>(const jsi::Value &val, return result; } +// C++ set from JS array. Set with heterogenerous look-up (adding std::less<> +// enables querying with std::string_view). 
+template <> +inline std::set> +getValue>>(const jsi::Value &val, + jsi::Runtime &runtime) { + + jsi::Array array = val.asObject(runtime).asArray(runtime); + size_t length = array.size(runtime); + std::set> result; + + for (size_t i = 0; i < length; ++i) { + jsi::Value element = array.getValueAtIndex(runtime, i); + result.insert(getValue(element, runtime)); + } + return result; +} + // Conversion from C++ types to jsi -------------------------------------------- // Implementation functions might return any type, but in a promise we can only // return jsi::Value or jsi::Object. For each type being returned // we add a function here. -// Identity function for the sake of completeness -inline jsi::Value getJsiValue(jsi::Value &&value, jsi::Runtime &runtime) { - return std::move(value); -} - -inline jsi::Value getJsiValue(jsi::Object &&value, jsi::Runtime &runtime) { - return jsi::Value(std::move(value)); +inline jsi::Value getJsiValue(std::shared_ptr valuePtr, + jsi::Runtime &runtime) { + return std::move(*valuePtr); } inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) { diff --git a/common/rnexecutorch/host_objects/ModelHostObject.h b/common/rnexecutorch/host_objects/ModelHostObject.h index 922c07b825..a81aebcce7 100644 --- a/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/common/rnexecutorch/host_objects/ModelHostObject.h @@ -55,9 +55,10 @@ template class ModelHostObject : public JsiHostObject { try { auto result = std::apply(std::bind_front(FnPtr, model), argsConverted); - - callInvoker->invokeAsync([promise, result = std::move(result)]( - jsi::Runtime &runtime) { + // The result is copied. It should either be quickly copiable, + // or passed with a shared_ptr. 
+ callInvoker->invokeAsync([promise, + result](jsi::Runtime &runtime) { promise->resolve( jsiconversion::getJsiValue(std::move(result), runtime)); }); diff --git a/common/rnexecutorch/jsi/OwningArrayBuffer.h b/common/rnexecutorch/jsi/OwningArrayBuffer.h new file mode 100644 index 0000000000..51e9b63e49 --- /dev/null +++ b/common/rnexecutorch/jsi/OwningArrayBuffer.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace rnexecutorch { + +using namespace facebook; + +class OwningArrayBuffer : public jsi::MutableBuffer { +public: + OwningArrayBuffer(const size_t size) : size_(size) { + data_ = new uint8_t[size]; + } + ~OwningArrayBuffer() override { delete[] data_; } + + OwningArrayBuffer(const OwningArrayBuffer &) = delete; + OwningArrayBuffer(OwningArrayBuffer &&) = delete; + OwningArrayBuffer &operator=(const OwningArrayBuffer &) = delete; + OwningArrayBuffer &operator=(OwningArrayBuffer &&) = delete; + + [[nodiscard]] size_t size() const override { return size_; } + uint8_t *data() override { return data_; } + +private: + uint8_t *data_; + const size_t size_; +}; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/models/BaseModel.cpp b/common/rnexecutorch/models/BaseModel.cpp index 5da87eeee8..bb9073640f 100644 --- a/common/rnexecutorch/models/BaseModel.cpp +++ b/common/rnexecutorch/models/BaseModel.cpp @@ -4,14 +4,15 @@ namespace rnexecutorch { +using namespace facebook; using ::executorch::extension::Module; using ::executorch::runtime::Error; BaseModel::BaseModel(const std::string &modelSource, - facebook::jsi::Runtime *runtime) + std::shared_ptr callInvoker) : module(std::make_unique( modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)), - runtime(runtime) { + callInvoker(callInvoker) { Error loadError = module->load(); if (loadError != Error::Ok) { throw std::runtime_error("Couldn't load the model, error: " + diff --git a/common/rnexecutorch/models/BaseModel.h b/common/rnexecutorch/models/BaseModel.h index 
8c486f38a2..492ae8c044 100644 --- a/common/rnexecutorch/models/BaseModel.h +++ b/common/rnexecutorch/models/BaseModel.h @@ -2,19 +2,25 @@ #include +#include #include #include namespace rnexecutorch { +using namespace facebook; class BaseModel { public: - BaseModel(const std::string &modelSource, facebook::jsi::Runtime *runtime); + BaseModel(const std::string &modelSource, + std::shared_ptr callInvoker); std::vector> getInputShape(); protected: std::unique_ptr module; - facebook::jsi::Runtime *runtime; + // If possible, models should not use the JS runtime to keep JSI internals + // away from logic, however, sometimes this would incur too big of a penalty + // (unnecessary copies instead of working on JS memory). In this case + // CallInvoker can be used to get jsi::Runtime, and use it in a safe manner. + std::shared_ptr callInvoker; }; - } // namespace rnexecutorch \ No newline at end of file diff --git a/ios/RnExecutorch/models/image_segmentation/Constants.mm b/common/rnexecutorch/models/image_segmentation/Constants.h similarity index 62% rename from ios/RnExecutorch/models/image_segmentation/Constants.mm rename to common/rnexecutorch/models/image_segmentation/Constants.h index f28693c75f..a6d69e1c29 100644 --- a/ios/RnExecutorch/models/image_segmentation/Constants.mm +++ b/common/rnexecutorch/models/image_segmentation/Constants.h @@ -1,8 +1,13 @@ -#import "Constants.h" +#pragma once -const std::vector deeplabv3_resnet50_labels = { +#include +#include + +namespace rnexecutorch { +inline constexpr std::array deeplabv3_resnet50_labels = { "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", "TVMONITOR"}; +} \ No newline at end of file diff --git a/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp b/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp new file mode 100644 index 
0000000000..f2f595745b --- /dev/null +++ b/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp @@ -0,0 +1,186 @@ +#include "ImageSegmentation.h" + +#include + +#include + +#include +#include +#include +#include + +namespace rnexecutorch { + +ImageSegmentation::ImageSegmentation( + const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + auto inputTensors = getInputShape(); + if (inputTensors.size() == 0) { + throw std::runtime_error("Model seems to not take any input tensors."); + } + std::vector modelInputShape = inputTensors[0]; + if (modelInputShape.size() < 2) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unexpected model input size, expected at least 2 dimensions " + "but got: %zu.", + modelInputShape.size()); + throw std::runtime_error(errorMessage); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); + numModelPixels = modelImageSize.area(); +} + +std::shared_ptr +ImageSegmentation::forward(std::string imageSource, + std::set> classesOfInterest, + bool resize) { + auto [inputTensor, originalSize] = preprocess(imageSource); + + auto forwardResult = module->forward(inputTensor); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Failed to forward, error: " + + std::to_string(static_cast(forwardResult.error()))); + } + + return postprocess(forwardResult->at(0).toTensor(), originalSize, + classesOfInterest, resize); +} + +std::pair +ImageSegmentation::preprocess(const std::string &imageSource) { + cv::Mat input = imageprocessing::readImage(imageSource); + cv::Size inputSize = input.size(); + + cv::resize(input, input, modelImageSize); + + std::vector inputVector = imageprocessing::colorMatToVector(input); + return { + executorch::extension::make_tensor_ptr(getInputShape()[0], inputVector), + inputSize}; +} + +std::shared_ptr ImageSegmentation::postprocess( + const Tensor &tensor,
cv::Size originalSize, + std::set> classesOfInterest, bool resize) { + + auto dataPtr = static_cast(tensor.const_data_ptr()); + auto resultData = std::span(dataPtr, tensor.numel()); + + // We copy the ET-owned data to jsi array buffers that can be directly + // returned to JS + std::vector> resultClasses; + resultClasses.reserve(numClasses); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + auto classBuffer = + std::make_shared(numModelPixels * sizeof(float)); + resultClasses.push_back(classBuffer); + std::memcpy(classBuffer->data(), &resultData[cl * numModelPixels], + numModelPixels * sizeof(float)); + } + + // Apply softmax per each pixel across all classes + for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { + std::vector classValues(numClasses); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + classValues[cl] = + reinterpret_cast(resultClasses[cl]->data())[pixel]; + } + numerical::softmax(classValues); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + reinterpret_cast(resultClasses[cl]->data())[pixel] = + classValues[cl]; + } + } + + // Calculate the maximum class for each pixel + auto argmax = + std::make_shared(numModelPixels * sizeof(int32_t)); + for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { + float max = reinterpret_cast(resultClasses[0]->data())[pixel]; + int maxInd = 0; + for (int cl = 1; cl < numClasses; ++cl) { + if (reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { + maxInd = cl; + max = reinterpret_cast(resultClasses[cl]->data())[pixel]; + } + } + reinterpret_cast(argmax->data())[pixel] = maxInd; + } + + auto buffersToReturn = std::make_shared>>(); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + if (classesOfInterest.contains(deeplabv3_resnet50_labels[cl])) { + (*buffersToReturn)[deeplabv3_resnet50_labels[cl]] = resultClasses[cl]; + } + } + + // Resize selected classes and argmax + if (resize) { + cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data()); + cv::resize(argmaxMat, argmaxMat, 
originalSize, 0, 0, + cv::InterpolationFlags::INTER_NEAREST); + argmax = std::make_shared(originalSize.area() * + sizeof(int32_t)); + std::memcpy(argmax->data(), argmaxMat.data, + originalSize.area() * sizeof(int32_t)); + + for (auto &[label, arrayBuffer] : *buffersToReturn) { + cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data()); + cv::resize(classMat, classMat, originalSize); + arrayBuffer = std::make_shared(originalSize.area() * + sizeof(float)); + std::memcpy(arrayBuffer->data(), classMat.data, + originalSize.area() * sizeof(float)); + } + } + return populateDictionary(argmax, buffersToReturn); +} + +std::shared_ptr ImageSegmentation::populateDictionary( + std::shared_ptr argmax, + std::shared_ptr>> + classesToOutput) { + // Synchronize the invoked thread to return when the dict is constructed + auto promisePtr = std::make_shared>(); + std::future doneFuture = promisePtr->get_future(); + + std::shared_ptr dictPtr = nullptr; + callInvoker->invokeAsync( + [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { + dictPtr = std::make_shared(runtime); + auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); + + auto int32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = + int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) + .getObject(runtime); + dictPtr->setProperty(runtime, "ARGMAX", int32Array); + + for (auto &[classLabel, owningBuffer] : *classesToOutput) { + auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); + + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) + .getObject(runtime); + + dictPtr->setProperty( + runtime, jsi::String::createFromAscii(runtime, classLabel.data()), + float32Array); + } + promisePtr->set_value(); + }); + + doneFuture.wait(); + return dictPtr; +} + +} // namespace rnexecutorch \ No newline at end of file 
diff --git a/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h b/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h new file mode 100644 index 0000000000..ce919a7539 --- /dev/null +++ b/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace rnexecutorch { +using namespace facebook; + +using executorch::aten::Tensor; +using executorch::extension::TensorPtr; + +class ImageSegmentation : public BaseModel { +public: + ImageSegmentation(const std::string &modelSource, + std::shared_ptr callInvoker); + std::shared_ptr + forward(std::string imageSource, + std::set> classesOfInterest, bool resize); + +private: + std::pair preprocess(const std::string &imageSource); + std::shared_ptr + postprocess(const Tensor &tensor, cv::Size originalSize, + std::set> classesOfInterest, + bool resize); + std::shared_ptr populateDictionary( + std::shared_ptr argmax, + std::shared_ptr>> + classesToOutput); + + static constexpr std::size_t numClasses{deeplabv3_resnet50_labels.size()}; + cv::Size modelImageSize; + std::size_t numModelPixels; +}; +} // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 54aa6d318a..7c5234aa82 100644 --- a/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -15,9 +15,21 @@ using executorch::extension::TensorPtr; using executorch::runtime::Error; StyleTransfer::StyleTransfer(const std::string &modelSource, - jsi::Runtime *runtime) - : BaseModel(modelSource, runtime) { - std::vector modelInputShape = getInputShape()[0]; + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + auto inputTensors = getInputShape(); + if (inputTensors.size() == 0) { + throw 
std::runtime_error("Model seems to not take any input tensors."); + } + std::vector modelInputShape = inputTensors[0]; + if (modelInputShape.size() < 2) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unexpected model input size, expected at least 2 dimensions " + "but got: %zu.", + modelInputShape.size()); + throw std::runtime_error(errorMessage); + } modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], modelInputShape[modelInputShape.size() - 2]); } diff --git a/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/common/rnexecutorch/models/style_transfer/StyleTransfer.h index d1da39907b..809cac804d 100644 --- a/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -17,7 +17,8 @@ using executorch::extension::TensorPtr; class StyleTransfer : public BaseModel { public: - StyleTransfer(const std::string &modelSource, jsi::Runtime *runtime); + StyleTransfer(const std::string &modelSource, + std::shared_ptr callInvoker); std::string forward(std::string imageSource); private: diff --git a/examples/computer-vision/ios/Podfile.lock b/examples/computer-vision/ios/Podfile.lock index 036a62c773..b11ac3d434 100644 --- a/examples/computer-vision/ios/Podfile.lock +++ b/examples/computer-vision/ios/Podfile.lock @@ -43,7 +43,7 @@ PODS: - hermes-engine (0.76.9): - hermes-engine/Pre-built (= 0.76.9) - hermes-engine/Pre-built (0.76.9) - - opencv-rne (0.1.0) + - opencv-rne (4.11.0) - RCT-Folly (2024.10.14.00): - boost - DoubleConversion @@ -1326,7 +1326,7 @@ PODS: - DoubleConversion - glog - hermes-engine - - opencv-rne (~> 0.1.0) + - opencv-rne (~> 4.11.0) - RCT-Folly (= 2024.10.14.00) - RCTRequired - RCTTypeSafety @@ -1343,6 +1343,7 @@ PODS: - ReactCodegen - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core + - sqlite3 - Yoga - react-native-image-picker (7.2.3): - DoubleConversion @@ -1857,6 +1858,9 @@ PODS: - ReactCommon/turbomodule/core
- Yoga - SocketRocket (0.7.1) + - sqlite3 (3.49.1): + - sqlite3/common (= 3.49.1) + - sqlite3/common (3.49.1) - Yoga (0.0.0) DEPENDENCIES: @@ -1944,6 +1948,7 @@ SPEC REPOS: trunk: - opencv-rne - SocketRocket + - sqlite3 EXTERNAL SOURCES: boost: @@ -2117,7 +2122,7 @@ SPEC CHECKSUMS: fmt: 01b82d4ca6470831d1cc0852a1af644be019e8f6 glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a hermes-engine: 9e868dc7be781364296d6ee2f56d0c1a9ef0bb11 - opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f + opencv-rne: 2305807573b6e29c8c87e3416ab096d09047a7a0 RCT-Folly: ea9d9256ba7f9322ef911169a9f696e5857b9e17 RCTDeprecation: ebe712bb05077934b16c6bf25228bdec34b64f83 RCTRequired: ca91e5dd26b64f577b528044c962baf171c6b716 @@ -2181,6 +2186,7 @@ SPEC CHECKSUMS: RNReanimated: 2e5069649cbab2c946652d3b97589b2ae0526220 RNSVG: b889dc9c1948eeea0576a16cc405c91c37a12c19 SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748 + sqlite3: fc1400008a9b3525f5914ed715a5d1af0b8f4983 Yoga: feb4910aba9742cfedc059e2b2902e22ffe9954a PODFILE CHECKSUM: d2d76566c3147849493ab633854730a1f661227b diff --git a/examples/computer-vision/screens/ImageSegmentationScreen.tsx b/examples/computer-vision/screens/ImageSegmentationScreen.tsx index 3a43236b98..6c3ce8e600 100644 --- a/examples/computer-vision/screens/ImageSegmentationScreen.tsx +++ b/examples/computer-vision/screens/ImageSegmentationScreen.tsx @@ -3,8 +3,8 @@ import { BottomBar } from '../components/BottomBar'; import { getImage } from '../utils'; import { useImageSegmentation, - DeeplabLabel, DEEPLAB_V3_RESNET50, + DeeplabLabel, } from 'react-native-executorch'; import { Canvas, diff --git a/examples/llm-tool-calling/ios/Podfile.lock b/examples/llm-tool-calling/ios/Podfile.lock index a60d10dced..7bca4faae2 100644 --- a/examples/llm-tool-calling/ios/Podfile.lock +++ b/examples/llm-tool-calling/ios/Podfile.lock @@ -47,7 +47,7 @@ PODS: - hermes-engine (0.76.9): - hermes-engine/Pre-built (= 0.76.9) - hermes-engine/Pre-built (0.76.9) - - opencv-rne (0.1.0) + - 
opencv-rne (4.11.0) - RCT-Folly (2024.10.14.00): - boost - DoubleConversion @@ -1330,7 +1330,7 @@ PODS: - DoubleConversion - glog - hermes-engine - - opencv-rne (~> 0.1.0) + - opencv-rne (~> 4.11.0) - RCT-Folly (= 2024.10.14.00) - RCTRequired - RCTTypeSafety @@ -1347,6 +1347,7 @@ PODS: - ReactCodegen - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core + - sqlite3 - Yoga - react-native-safe-area-context (4.12.0): - DoubleConversion @@ -1817,6 +1818,9 @@ PODS: - ReactCommon/turbomodule/core - Yoga - SocketRocket (0.7.1) + - sqlite3 (3.49.1): + - sqlite3/common (= 3.49.1) + - sqlite3/common (3.49.1) - Yoga (0.0.0) DEPENDENCIES: @@ -1904,6 +1908,7 @@ SPEC REPOS: trunk: - opencv-rne - SocketRocket + - sqlite3 EXTERNAL SOURCES: boost: @@ -2079,8 +2084,8 @@ SPEC CHECKSUMS: fmt: 01b82d4ca6470831d1cc0852a1af644be019e8f6 glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a hermes-engine: 9e868dc7be781364296d6ee2f56d0c1a9ef0bb11 - opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f - RCT-Folly: 7b4f73a92ad9571b9dbdb05bb30fad927fa971e1 + opencv-rne: 2305807573b6e29c8c87e3416ab096d09047a7a0 + RCT-Folly: ea9d9256ba7f9322ef911169a9f696e5857b9e17 RCTDeprecation: ebe712bb05077934b16c6bf25228bdec34b64f83 RCTRequired: ca91e5dd26b64f577b528044c962baf171c6b716 RCTTypeSafety: e7678bd60850ca5a41df9b8dc7154638cb66871f @@ -2141,7 +2146,8 @@ SPEC CHECKSUMS: RNReanimated: b95559eb62609b22b99f6e7f20cb892c20b393dc RNSVG: 81d52481cde97ce0dcc81a55b0310723817088d0 SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748 - Yoga: 40f19fff64dce86773bf8b602c7070796c007970 + sqlite3: fc1400008a9b3525f5914ed715a5d1af0b8f4983 + Yoga: feb4910aba9742cfedc059e2b2902e22ffe9954a PODFILE CHECKSUM: 9378e690c7b699685381b113789d682762e327e8 diff --git a/examples/speech-to-text/ios/Podfile.lock b/examples/speech-to-text/ios/Podfile.lock index 37238a6a62..490aee02eb 100644 --- a/examples/speech-to-text/ios/Podfile.lock +++ b/examples/speech-to-text/ios/Podfile.lock @@ -43,7 +43,7 @@ PODS: - hermes-engine 
(0.76.9): - hermes-engine/Pre-built (= 0.76.9) - hermes-engine/Pre-built (0.76.9) - - opencv-rne (0.1.0) + - opencv-rne (4.11.0) - RCT-Folly (2024.10.14.00): - boost - DoubleConversion @@ -1326,7 +1326,7 @@ PODS: - DoubleConversion - glog - hermes-engine - - opencv-rne (~> 0.1.0) + - opencv-rne (~> 4.11.0) - RCT-Folly (= 2024.10.14.00) - RCTRequired - RCTTypeSafety @@ -1343,6 +1343,7 @@ PODS: - ReactCodegen - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core + - sqlite3 - Yoga - react-native-image-picker (7.2.3): - DoubleConversion @@ -1859,6 +1860,9 @@ PODS: - ReactCommon/turbomodule/core - Yoga - SocketRocket (0.7.1) + - sqlite3 (3.49.1): + - sqlite3/common (= 3.49.1) + - sqlite3/common (3.49.1) - Yoga (0.0.0) DEPENDENCIES: @@ -1948,6 +1952,7 @@ SPEC REPOS: trunk: - opencv-rne - SocketRocket + - sqlite3 EXTERNAL SOURCES: boost: @@ -2125,7 +2130,7 @@ SPEC CHECKSUMS: fmt: 01b82d4ca6470831d1cc0852a1af644be019e8f6 glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a hermes-engine: 9e868dc7be781364296d6ee2f56d0c1a9ef0bb11 - opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f + opencv-rne: 2305807573b6e29c8c87e3416ab096d09047a7a0 RCT-Folly: ea9d9256ba7f9322ef911169a9f696e5857b9e17 RCTDeprecation: ebe712bb05077934b16c6bf25228bdec34b64f83 RCTRequired: ca91e5dd26b64f577b528044c962baf171c6b716 @@ -2191,6 +2196,7 @@ SPEC CHECKSUMS: RNReanimated: 2e5069649cbab2c946652d3b97589b2ae0526220 RNSVG: b889dc9c1948eeea0576a16cc405c91c37a12c19 SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748 + sqlite3: fc1400008a9b3525f5914ed715a5d1af0b8f4983 Yoga: feb4910aba9742cfedc059e2b2902e22ffe9954a PODFILE CHECKSUM: 8264e1ef5c1c85c206e4efb2c2c7e7b66ab269ed diff --git a/examples/text-embeddings/ios/Podfile.lock b/examples/text-embeddings/ios/Podfile.lock index a0c9f4f036..86cf9ca44e 100644 --- a/examples/text-embeddings/ios/Podfile.lock +++ b/examples/text-embeddings/ios/Podfile.lock @@ -42,7 +42,7 @@ PODS: - hermes-engine (0.76.8): - hermes-engine/Pre-built (= 0.76.8) - 
hermes-engine/Pre-built (0.76.8) - - opencv-rne (0.1.0) + - opencv-rne (4.11.0) - RCT-Folly (2024.01.01.00): - boost - DoubleConversion @@ -1282,7 +1282,7 @@ PODS: - DoubleConversion - glog - hermes-engine - - opencv-rne (~> 0.1.0) + - opencv-rne (~> 4.11.0) - RCT-Folly (= 2024.01.01.00) - RCTRequired - RCTTypeSafety @@ -1299,6 +1299,7 @@ PODS: - ReactCodegen - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core + - sqlite3 - Yoga - React-nativeconfig (0.76.8) - React-NativeModulesApple (0.76.8): @@ -1568,6 +1569,9 @@ PODS: - React-perflogger (= 0.76.8) - React-utils (= 0.76.8) - SocketRocket (0.7.1) + - sqlite3 (3.49.1): + - sqlite3/common (= 3.49.1) + - sqlite3/common (3.49.1) - Yoga (0.0.0) DEPENDENCIES: @@ -1649,6 +1653,7 @@ SPEC REPOS: trunk: - opencv-rne - SocketRocket + - sqlite3 EXTERNAL SOURCES: boost: @@ -1809,7 +1814,7 @@ SPEC CHECKSUMS: fmt: 10c6e61f4be25dc963c36bd73fc7b1705fe975be glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a hermes-engine: ea89b864870ef107096c440c56eb6cba409b2689 - opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f + opencv-rne: 2305807573b6e29c8c87e3416ab096d09047a7a0 RCT-Folly: 84578c8756030547307e4572ab1947de1685c599 RCTDeprecation: 7fa7002418c68d8ff065b29e9e9cfd8d904d6c64 RCTRequired: cabedb3345dcfd519a89098b8a320969e2cb961e @@ -1868,6 +1873,7 @@ SPEC CHECKSUMS: ReactCodegen: 6d884ae1e7d4a51a4ca6d3a1a428a89daa8335fd ReactCommon: 2b7118eace1eab072a4ccdae57303eaefc2a3941 SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748 + sqlite3: fc1400008a9b3525f5914ed715a5d1af0b8f4983 Yoga: 9f2ca179441625f0b05abb2a72517acdb35b36bd PODFILE CHECKSUM: 68dda3cf67ac49f79b776f2f27f38a3d7967d74c diff --git a/ios/RnExecutorch/ImageSegmentation.h b/ios/RnExecutorch/ImageSegmentation.h deleted file mode 100644 index 59ed56a45c..0000000000 --- a/ios/RnExecutorch/ImageSegmentation.h +++ /dev/null @@ -1,5 +0,0 @@ -#import - -@interface ImageSegmentation : NSObject - -@end \ No newline at end of file diff --git 
a/ios/RnExecutorch/ImageSegmentation.mm b/ios/RnExecutorch/ImageSegmentation.mm deleted file mode 100644 index d64a73abc9..0000000000 --- a/ios/RnExecutorch/ImageSegmentation.mm +++ /dev/null @@ -1,60 +0,0 @@ -#import "ImageSegmentation.h" -#import "ImageProcessor.h" -#import "models/image_segmentation/ImageSegmentationModel.h" - -@implementation ImageSegmentation { - ImageSegmentationModel *model; -} - -RCT_EXPORT_MODULE() - -- (void)releaseResources { - model = nil; -} - -- (void)loadModule:(NSString *)modelSource - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - model = [[ImageSegmentationModel alloc] init]; - - NSNumber *errorCode = [model loadModel:modelSource]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - reject(@"init_module_error", - [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]], - nil); - return; - } - - resolve(@0); -} - -- (void)forward:(NSString *)input - classesOfInterest:(NSArray *)classesOfInterest - resize:(BOOL)resize - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - - @try { - cv::Mat image = [ImageProcessor readImage:input]; - NSDictionary *result = [model runModel:image - returnClasses:classesOfInterest - resize:resize]; - - resolve(result); - return; - } @catch (NSException *exception) { - NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason); - reject(@"forward_error", - [NSString stringWithFormat:@"%@", exception.reason], nil); - return; - } -} - -- (std::shared_ptr)getTurboModule: - (const facebook::react::ObjCTurboModule::InitParams &)params { - return std::make_shared( - params); -} - -@end diff --git a/ios/RnExecutorch/models/image_segmentation/Constants.h b/ios/RnExecutorch/models/image_segmentation/Constants.h deleted file mode 100644 index f3d75cd80e..0000000000 --- a/ios/RnExecutorch/models/image_segmentation/Constants.h +++ /dev/null @@ -1,4 +0,0 @@ -#import -#import - -extern const std::vector 
deeplabv3_resnet50_labels; diff --git a/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h b/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h deleted file mode 100644 index 8d1c76edf0..0000000000 --- a/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h +++ /dev/null @@ -1,10 +0,0 @@ -#import "../BaseModel.h" -#import "opencv2/opencv.hpp" - -@interface ImageSegmentationModel : BaseModel -- (cv::Size)getModelImageSize; -- (NSDictionary *)runModel:(cv::Mat &)input - returnClasses:(NSArray *)classesOfInterest - resize:(BOOL)resize; - -@end diff --git a/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm b/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm deleted file mode 100644 index 4ee8c440ed..0000000000 --- a/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm +++ /dev/null @@ -1,146 +0,0 @@ -#import "ImageSegmentationModel.h" -#import "../../utils/Conversions.h" -#import "../../utils/ImageProcessor.h" -#import "../../utils/Numerical.h" -#import "Constants.h" -#import - -@interface ImageSegmentationModel () -- (NSArray *)preprocess:(cv::Mat &)input; -- (NSDictionary *)postprocess:(NSArray *)output - returnClasses:(NSArray *)classesOfInterest - resize:(BOOL)resize; -@end - -@implementation ImageSegmentationModel { - cv::Size originalSize; -} - -- (cv::Size)getModelImageSize { - NSArray *inputShape = [module getInputShape:@0]; - NSNumber *widthNumber = inputShape.lastObject; - NSNumber *heightNumber = inputShape[inputShape.count - 2]; - - int height = [heightNumber intValue]; - int width = [widthNumber intValue]; - - return cv::Size(height, width); -} - -- (NSArray *)preprocess:(cv::Mat &)input { - originalSize = cv::Size(input.cols, input.rows); - - cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat output; - cv::resize(input, output, modelImageSize); - - NSArray *modelInput = [ImageProcessor matToNSArray:output]; - return modelInput; -} - 
-std::vector extractResults(NSArray *result, std::size_t numLabels, - cv::Size modelImageSize, - cv::Size originalSize, BOOL resize) { - std::size_t numModelPixels = modelImageSize.height * modelImageSize.width; - - std::vector resizedLabelScores(numLabels); - for (std::size_t label = 0; label < numLabels; ++label) { - cv::Mat labelMat = cv::Mat(modelImageSize, CV_64F); - - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - int row = pixel / modelImageSize.width; - int col = pixel % modelImageSize.width; - labelMat.at(row, col) = - [result[label * numModelPixels + pixel] doubleValue]; - } - - if (resize) { - cv::resize(labelMat, resizedLabelScores[label], originalSize); - } else { - resizedLabelScores[label] = std::move(labelMat); - } - } - return resizedLabelScores; -} - -void adjustScoresPerPixel(std::vector &labelScores, cv::Mat &argMax, - cv::Size outputSize, std::size_t numLabels) { - std::size_t numOutputPixels = outputSize.height * outputSize.width; - for (std::size_t pixel = 0; pixel < numOutputPixels; ++pixel) { - int row = pixel / outputSize.width; - int col = pixel % outputSize.width; - std::vector scores; - scores.reserve(numLabels); - for (const auto &mat : labelScores) { - scores.push_back(mat.at(row, col)); - } - - std::vector adjustedScores = softmax(scores); - - for (std::size_t label = 0; label < numLabels; ++label) { - labelScores[label].at(row, col) = adjustedScores[label]; - } - - auto maxIt = std::max_element(scores.begin(), scores.end()); - argMax.at(row, col) = std::distance(scores.begin(), maxIt); - } -} - -- (NSDictionary *)postprocess:(NSArray *)output - returnClasses:(NSArray *)classesOfInterest - resize:(BOOL)resize { - cv::Size modelImageSize = [self getModelImageSize]; - - std::size_t numLabels = deeplabv3_resnet50_labels.size(); - - NSAssert((std::size_t)output.count == - numLabels * modelImageSize.height * modelImageSize.width, - @"Model generated unexpected output size."); - - // For each label extract it's matrix, - 
// and rescale it to the original size if `resize` - std::vector resizedLabelScores = - extractResults(output, numLabels, modelImageSize, originalSize, resize); - - cv::Size outputSize = resize ? originalSize : modelImageSize; - cv::Mat argMax = cv::Mat(outputSize, CV_32S); - - // For each pixel apply softmax across all the labels and calculate the argMax - adjustScoresPerPixel(resizedLabelScores, argMax, outputSize, numLabels); - - std::unordered_set labelSet; - - for (id label in classesOfInterest) { - labelSet.insert(std::string([label UTF8String])); - } - - NSMutableDictionary *result = [NSMutableDictionary dictionary]; - - // Convert to NSArray and populate the final dictionary - for (std::size_t label = 0; label < numLabels; ++label) { - if (labelSet.contains(deeplabv3_resnet50_labels[label])) { - NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str()); - NSArray *arr = simpleMatToNSArray(resizedLabelScores[label]); - result[labelString] = arr; - } - } - - result[@"ARGMAX"] = simpleMatToNSArray(argMax); - - return result; -} - -- (NSDictionary *)runModel:(cv::Mat &)input - returnClasses:(NSArray *)classesOfInterest - resize:(BOOL)resize { - NSArray *modelInput = [self preprocess:input]; - NSArray *result = [self forward:@[ modelInput ]]; - - NSDictionary *output = [self postprocess:result[0] - returnClasses:classesOfInterest - resize:resize]; - - return output; -} - -@end diff --git a/src/hooks/computer_vision/useImageSegmentation.ts b/src/hooks/computer_vision/useImageSegmentation.ts index c3f9f7a824..6b70d68a95 100644 --- a/src/hooks/computer_vision/useImageSegmentation.ts +++ b/src/hooks/computer_vision/useImageSegmentation.ts @@ -1,9 +1,18 @@ -import { useModule } from '../useModule'; +import { ResourceSource } from '../../types/common'; +import { useNonStaticModule } from '../useNonStaticModule'; import { ImageSegmentationModule } from '../../modules/computer_vision/ImageSegmentationModule'; interface Props { - modelSource: string | number; + 
modelSource: ResourceSource; + preventLoad?: boolean; } -export const useImageSegmentation = ({ modelSource }: Props) => - useModule({ module: ImageSegmentationModule, loadArgs: [modelSource] }); +export const useImageSegmentation = ({ + modelSource, + preventLoad = false, +}: Props) => + useNonStaticModule({ + module: ImageSegmentationModule, + loadArgs: [modelSource], + preventLoad: preventLoad, + }); diff --git a/src/index.tsx b/src/index.tsx index e5c9d0e78b..83de6503bb 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -5,6 +5,7 @@ import { ETInstallerNativeModule } from './native/RnExecutorchModules'; // eslint-disable no-var declare global { var loadStyleTransfer: (source: string) => Promise; + var loadImageSegmentation: (source: string) => Promise; } // eslint-disable no-var diff --git a/src/modules/computer_vision/ImageSegmentationModule.ts b/src/modules/computer_vision/ImageSegmentationModule.ts index aaf70ef7bf..2cdbf3ef63 100644 --- a/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/src/modules/computer_vision/ImageSegmentationModule.ts @@ -1,39 +1,45 @@ -import { BaseModule } from '../BaseModule'; -import { getError } from '../../Error'; -import { DeeplabLabel } from '../../types/image_segmentation'; +import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; -import { ImageSegmentationNativeModule } from '../../native/RnExecutorchModules'; +import { DeeplabLabel } from '../../types/image_segmentation'; +import { ETError, getError } from '../../Error'; -export class ImageSegmentationModule extends BaseModule { - protected static override nativeModule = ImageSegmentationNativeModule; +export class ImageSegmentationModule { + nativeModule: any = null; - static override async load(modelSource: ResourceSource) { - return await super.load(modelSource); + async load( + modelSource: ResourceSource, + onDownloadProgressCallback: (_: number) => void = () => {} + ): Promise { + const paths = 
await ResourceFetcher.fetchMultipleResources( + onDownloadProgressCallback, + modelSource + ); + this.nativeModule = global.loadImageSegmentation(paths[0] || ''); } - static override async forward( - input: string, + async forward( + imageSource: string, classesOfInterest?: DeeplabLabel[], resize?: boolean ) { - try { - const stringDict = await (this.nativeModule.forward( - input, - (classesOfInterest || []).map((label) => DeeplabLabel[label]), - resize || false - ) as ReturnType<(typeof this.nativeModule)['forward']>); + if (this.nativeModule == null) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + + const stringDict = await this.nativeModule.forward( + imageSource, + (classesOfInterest || []).map((label) => DeeplabLabel[label]), + resize || false + ); - let enumDict: { [key in DeeplabLabel]?: number[] } = {}; + let enumDict: { [key in DeeplabLabel]?: number[] } = {}; - for (const key in stringDict) { - if (key in DeeplabLabel) { - const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel]; - enumDict[enumKey] = stringDict[key]; - } + for (const key in stringDict) { + if (key in DeeplabLabel) { + const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel]; + enumDict[enumKey] = stringDict[key]; } - return enumDict; - } catch (e) { - throw new Error(getError(e)); } + return enumDict; } } diff --git a/src/native/NativeImageSegmentation.ts b/src/native/NativeImageSegmentation.ts deleted file mode 100644 index c66c874361..0000000000 --- a/src/native/NativeImageSegmentation.ts +++ /dev/null @@ -1,14 +0,0 @@ -import type { TurboModule } from 'react-native'; -import { TurboModuleRegistry } from 'react-native'; - -export interface Spec extends TurboModule { - loadModule(modelSource: string): Promise; - - forward( - input: string, - classesOfInterest: string[], - resize: boolean - ): Promise<{ [category: string]: number[] }>; -} - -export default TurboModuleRegistry.get('ImageSegmentation'); diff --git a/src/native/RnExecutorchModules.ts 
b/src/native/RnExecutorchModules.ts index e7e06154c5..2b9b245e53 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -1,6 +1,5 @@ import { Platform } from 'react-native'; import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; -import { Spec as ImageSegmentationInterface } from './NativeImageSegmentation'; import { Spec as ETModuleInterface } from './NativeETModule'; import { Spec as OCRInterface } from './NativeOCR'; import { Spec as VerticalOCRInterface } from './NativeVerticalOCR'; @@ -38,8 +37,6 @@ const ETModuleNativeModule: ETModuleInterface = returnSpecOrThrowLinkingError( ); const ClassificationNativeModule: ClassificationInterface = returnSpecOrThrowLinkingError(require('./NativeClassification').default); -const ImageSegmentationNativeModule: ImageSegmentationInterface = - returnSpecOrThrowLinkingError(require('./NativeImageSegmentation').default); const ObjectDetectionNativeModule: ObjectDetectionInterface = returnSpecOrThrowLinkingError(require('./NativeObjectDetection').default); const SpeechToTextNativeModule: SpeechToTextInterface = @@ -62,7 +59,6 @@ export { ETModuleNativeModule, ClassificationNativeModule, ObjectDetectionNativeModule, - ImageSegmentationNativeModule, SpeechToTextNativeModule, OCRNativeModule, VerticalOCRNativeModule,