diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt deleted file mode 100644 index 0c11e4f1f5..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt +++ /dev/null @@ -1,64 +0,0 @@ -package com.swmansion.rnexecutorch - -import android.util.Log -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.Promise -import com.facebook.react.bridge.ReactApplicationContext -import com.facebook.react.bridge.WritableArray -import com.swmansion.rnexecutorch.models.objectdetection.SSDLiteLargeModel -import com.swmansion.rnexecutorch.utils.ETError -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.android.OpenCVLoader - -class ObjectDetection( - reactContext: ReactApplicationContext, -) : NativeObjectDetectionSpec(reactContext) { - private lateinit var ssdLiteLarge: SSDLiteLargeModel - - companion object { - const val NAME = "ObjectDetection" - } - - init { - if (!OpenCVLoader.initLocal()) { - Log.d("rn_executorch", "OpenCV not loaded") - } else { - Log.d("rn_executorch", "OpenCV loaded") - } - } - - override fun loadModule( - modelSource: String, - promise: Promise, - ) { - try { - ssdLiteLarge = SSDLiteLargeModel(reactApplicationContext) - ssdLiteLarge.loadModel(modelSource) - promise.resolve(0) - } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelSource.toString()) - } - } - - override fun forward( - input: String, - promise: Promise, - ) { - try { - val inputImage = ImageProcessor.readImage(input) - val output = ssdLiteLarge.runModel(inputImage) - val outputWritableArray: WritableArray = Arguments.createArray() - output - .map { detection -> - detection.toWritableMap() - }.forEach { writableMap -> - outputWritableArray.pushMap(writableMap) - } - promise.resolve(outputWritableArray) - } catch (e: Exception) { - promise.reject(e.message!!, e.message) - } - } - - override fun getName(): String = NAME -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index c819c56642..61d82be832 100644 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -18,8 +18,6 @@ class RnExecutorchPackage : TurboReactPackage() { LLM(reactContext) } else if (name == ETModule.NAME) { ETModule(reactContext) - } else if (name == ObjectDetection.NAME) { - ObjectDetection(reactContext) } else if (name == SpeechToText.NAME) { SpeechToText(reactContext) } else if (name == OCR.NAME) { @@ -60,17 +58,6 @@ class RnExecutorchPackage : TurboReactPackage() { true, ) - moduleInfos[ObjectDetection.NAME] = - ReactModuleInfo( - ObjectDetection.NAME, - ObjectDetection.NAME, - false, // canOverrideExistingModule - false, // needsEagerInit - true, // hasConstants - false, // isCxxModule - true, - ) - moduleInfos[SpeechToText.NAME] = ReactModuleInfo( SpeechToText.NAME, diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/objectDetection/SSDLiteLargeModel.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/objectDetection/SSDLiteLargeModel.kt deleted file mode 100644 index 6d303f2c5b..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/objectDetection/SSDLiteLargeModel.kt +++ /dev/null @@ -1,74 +0,0 @@ -package com.swmansion.rnexecutorch.models.objectdetection - -import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.BaseModel -import com.swmansion.rnexecutorch.utils.Bbox -import com.swmansion.rnexecutorch.utils.CocoLabel -import com.swmansion.rnexecutorch.utils.Detection -import com.swmansion.rnexecutorch.utils.ImageProcessor -import com.swmansion.rnexecutorch.utils.nms -import org.opencv.core.Mat -import org.opencv.core.Size -import org.opencv.imgproc.Imgproc -import org.pytorch.executorch.EValue - -const val DETECTION_SCORE_THRESHOLD = .7f -const val IOU_THRESHOLD = .55f - -class SSDLiteLargeModel( - reactApplicationContext: ReactApplicationContext, -) : BaseModel>(reactApplicationContext) { - private var heightRatio: Float = 1.0f - private var widthRatio: Float = 1.0f - - private fun getModelImageSize(): Size { - val inputShape = module.getInputShape(0) - val width = inputShape[inputShape.lastIndex] - val height = inputShape[inputShape.lastIndex - 1] - - return Size(height.toDouble(), width.toDouble()) - } - - fun preprocess(input: Mat): EValue { - this.widthRatio = (input.size().width / getModelImageSize().width).toFloat() - this.heightRatio = (input.size().height / getModelImageSize().height).toFloat() - Imgproc.resize(input, input, getModelImageSize()) - return ImageProcessor.matToEValue(input, module.getInputShape(0)) - } - - override fun runModel(input: Mat): Array { - val modelInput = preprocess(input) - val modelOutput = forward(modelInput) - return postprocess(modelOutput) - } - - fun postprocess(output: Array): Array { - val scoresTensor = output[1].toTensor() - val numel = scoresTensor.numel() - val bboxes = output[0].toTensor().dataAsFloatArray - val scores = scoresTensor.dataAsFloatArray - val labels = output[2].toTensor().dataAsFloatArray - - val detections: MutableList = mutableListOf() - for (idx in 0 until numel.toInt()) { - val score = scores[idx] - if (score < DETECTION_SCORE_THRESHOLD) { - continue - } - val bbox = - Bbox( - bboxes[idx * 4 + 0] * this.widthRatio, - bboxes[idx * 4 + 1] * this.heightRatio, - bboxes[idx * 4 + 2] * this.widthRatio, - bboxes[idx * 4 + 3] * this.heightRatio, - ) - val label = labels[idx] - detections.add( - Detection(bbox, score, CocoLabel.fromId(label.toInt())!!), - ) - } - - val detectionsPostNms = nms(detections, IOU_THRESHOLD) - return detectionsPostNms.toTypedArray() - } -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt deleted file mode 100644 index 6f58f373c8..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt +++ /dev/null @@ -1,201 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.WritableMap - -fun nms( - detections: MutableList, - iouThreshold: Float, -): List { - if (detections.isEmpty()) { - return emptyList() - } - - // Sort detections first by label, then by score (descending) - val sortedDetections = detections.sortedWith(compareBy({ it.label }, { -it.score })) - - val result = mutableListOf() - - // Process NMS for each label group - var i = 0 - while (i < sortedDetections.size) { - val currentLabel = sortedDetections[i].label - - // Collect detections for the current label - val labelDetections = mutableListOf() - while (i < sortedDetections.size && sortedDetections[i].label == currentLabel) { - labelDetections.add(sortedDetections[i]) - i++ - } - - // Filter out detections with high IoU - val filteredLabelDetections = mutableListOf() - while (labelDetections.isNotEmpty()) { - val current = labelDetections.removeAt(0) - filteredLabelDetections.add(current) - - // Remove detections that overlap with the current detection above the IoU threshold - val iterator = labelDetections.iterator() - while (iterator.hasNext()) { - val other = iterator.next() - if (calculateIoU(current.bbox, other.bbox) > iouThreshold) { - iterator.remove() // Remove detection if IoU is above threshold - } - } - } - - // Add the filtered detections to the result - result.addAll(filteredLabelDetections) - } - - return result -} - -fun calculateIoU( - bbox1: Bbox, - bbox2: Bbox, -): Float { - val x1 = maxOf(bbox1.x1, bbox2.x1) - val y1 = maxOf(bbox1.y1, bbox2.y1) - val x2 = minOf(bbox1.x2, bbox2.x2) - val y2 = minOf(bbox1.y2, bbox2.y2) - - val intersectionArea = maxOf(0f, x2 - x1) * maxOf(0f, y2 - y1) - val bbox1Area = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1) - val bbox2Area = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1) - - val unionArea = bbox1Area + bbox2Area - intersectionArea - return if (unionArea == 0f) 0f else intersectionArea / unionArea -} - -data class Bbox( - val x1: Float, - val y1: Float, - val x2: Float, - val y2: Float, -) { - fun toWritableMap(): WritableMap { - val map = Arguments.createMap() - map.putDouble("x1", x1.toDouble()) - map.putDouble("x2", x2.toDouble()) - map.putDouble("y1", y1.toDouble()) - map.putDouble("y2", y2.toDouble()) - return map - } -} - -data class Detection( - val bbox: Bbox, - val score: Float, - val label: CocoLabel, -) { - fun toWritableMap(): WritableMap { - val map = Arguments.createMap() - map.putMap("bbox", bbox.toWritableMap()) - map.putDouble("score", score.toDouble()) - map.putString("label", label.name) - return map - } -} - -enum class CocoLabel( - val id: Int, -) { - PERSON(1), - BICYCLE(2), - CAR(3), - MOTORCYCLE(4), - AIRPLANE(5), - BUS(6), - TRAIN(7), - TRUCK(8), - BOAT(9), - TRAFFIC_LIGHT(10), - FIRE_HYDRANT(11), - STREET_SIGN(12), - STOP_SIGN(13), - PARKING(14), - BENCH(15), - BIRD(16), - CAT(17), - DOG(18), - HORSE(19), - SHEEP(20), - COW(21), - ELEPHANT(22), - BEAR(23), - ZEBRA(24), - GIRAFFE(25), - HAT(26), - BACKPACK(27), - UMBRELLA(28), - SHOE(29), - EYE(30), - HANDBAG(31), - TIE(32), - SUITCASE(33), - FRISBEE(34), - SKIS(35), - SNOWBOARD(36), - SPORTS(37), - KITE(38), - BASEBALL(39), - SKATEBOARD(41), - SURFBOARD(42), - TENNIS_RACKET(43), - BOTTLE(44), - PLATE(45), - WINE_GLASS(46), - CUP(47), - FORK(48), - KNIFE(49), - SPOON(50), - BOWL(51), - BANANA(52), - APPLE(53), - SANDWICH(54), - ORANGE(55), - BROCCOLI(56), - CARROT(57), - HOT_DOG(58), - PIZZA(59), - DONUT(60), - CAKE(61), - CHAIR(62), - COUCH(63), - POTTED_PLANT(64), - BED(65), - MIRROR(66), - DINING_TABLE(67), - WINDOW(68), - DESK(69), - TOILET(70), - DOOR(71), - TV(72), - LAPTOP(73), - MOUSE(74), - REMOTE(75), - KEYBOARD(76), - CELL_PHONE(77), - MICROWAVE(78), - OVEN(79), - TOASTER(80), - SINK(81), - REFRIGERATOR(82), - BLENDER(83), - BOOK(84), - CLOCK(85), - VASE(86), - SCISSORS(87), - TEDDY_BEAR(88), - HAIR_DRIER(89), - TOOTHBRUSH(90), - HAIR_BRUSH(91), - ; - - companion object { - private val idToLabelMap = values().associateBy(CocoLabel::id) - - fun fromId(id: Int): CocoLabel? = idToLabelMap[id] - } -} diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 532cbc1b1b..dcff11d2f6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace rnexecutorch { @@ -31,5 +32,10 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadClassification", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadClassification")); + + jsiRuntime->global().setProperty( + *jsiRuntime, "loadObjectDetection", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadObjectDetection")); } } // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp index 20461d1690..932942b427 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -111,6 +111,7 @@ cv::Mat readImage(const std::string &imageURI) { throw std::runtime_error("Read image error: invalid argument"); } + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); return image; } diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 2529555e34..c10f3c1824 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -6,6 +6,9 @@ #include +#include +#include + namespace rnexecutorch::jsiconversion { using namespace facebook; @@ -90,6 +93,27 @@ getJsiValue(const std::unordered_map &map, return mapObj; } +inline jsi::Value getJsiValue(const std::vector &detections, + jsi::Runtime &runtime) { + jsi::Array array(runtime, detections.size()); + for (std::size_t i = 0; i < detections.size(); ++i) { + jsi::Object detection(runtime); + jsi::Object bbox(runtime); + bbox.setProperty(runtime, "x1", detections[i].x1); + bbox.setProperty(runtime, "y1", detections[i].y1); + bbox.setProperty(runtime, "x2", detections[i].x2); + bbox.setProperty(runtime, "y2", detections[i].y2); + + detection.setProperty(runtime, "bbox", bbox); + detection.setProperty(runtime, "label", + jsi::String::createFromAscii( + runtime, cocoLabelsMap.at(detections[i].label))); + detection.setProperty(runtime, "score", detections[i].score); + array.setValueAtIndex(runtime, i, detection); + } + return array; +} + template constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) { return sizeof...(Types); diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h b/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h index 51e9b63e49..d9a6e1229d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h @@ -6,6 +6,17 @@ namespace rnexecutorch { using namespace facebook; +/** + * JSI offers the MutableBuffer as an interface for accessing native memory + * directly from JS. A class inheriting from the MutableBuffer could be used to + * access memory which is owned by C++ or memory which should be freed once JS + * is done with it. OwningArrayBuffer is an example of the latter, memory is + * allocated on creation and freed on deletion. JS holds a pointer to all + * MutableBuffers via a shared_ptr, so the destructor will be called only when + * no reference to it is kept. For a handy JS access to the data, MutableBuffers + * can be inspected via a data view, such as Float32Array. See + * ImageSegmentation.cpp for an example usage. + */ class OwningArrayBuffer : public jsi::MutableBuffer { public: OwningArrayBuffer(const size_t size) : size_(size) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 0e32dec468..96050b254c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -44,7 +44,6 @@ Classification::forward(std::string imageSource) { TensorPtr Classification::preprocess(const std::string &imageSource) { cv::Mat image = imageprocessing::readImage(imageSource); - cv::cvtColor(image, image, cv::COLOR_BGR2RGB); cv::resize(image, image, modelImageSize); return imageprocessing::getTensorFromMatrix(getInputShape()[0], image); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h index a6d69e1c29..a27303211a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h @@ -10,4 +10,4 @@ inline constexpr std::array deeplabv3_resnet50_labels = { "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", "TVMONITOR"}; -} \ No newline at end of file +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp index fa67844886..3ef5686750 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp @@ -53,7 +53,6 @@ ImageSegmentation::forward(std::string imageSource, std::pair ImageSegmentation::preprocess(const std::string &imageSource) { cv::Mat input = imageprocessing::readImage(imageSource); - cv::cvtColor(input, input, cv::COLOR_BGR2RGB); cv::Size inputSize = input.size(); cv::resize(input, input, modelImageSize); diff --git a/packages/react-native-executorch/ios/RnExecutorch/utils/Constants.mm b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h similarity index 94% rename from packages/react-native-executorch/ios/RnExecutorch/utils/Constants.mm rename to packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h index e93359f91f..5665337886 100644 --- a/packages/react-native-executorch/ios/RnExecutorch/utils/Constants.mm +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h @@ -1,5 +1,9 @@ -#include "Constants.h" +#pragma once +#include +#include + +namespace rnexecutorch { const std::unordered_map cocoLabelsMap = { {1, "PERSON"}, {2, "BICYCLE"}, {3, "CAR"}, {4, "MOTORCYCLE"}, {5, "AIRPLANE"}, {6, "BUS"}, @@ -32,3 +36,4 @@ {85, "CLOCK"}, {86, "VASE"}, {87, "SCISSORS"}, {88, "TEDDY_BEAR"}, {89, "HAIR_DRIER"}, {90, "TOOTHBRUSH"}, {91, "HAIR_BRUSH"}}; +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp new file mode 100644 index 0000000000..7d93994191 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -0,0 +1,91 @@ +#include "ObjectDetection.h" + +#include + +namespace rnexecutorch { + +ObjectDetection::ObjectDetection( + const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + auto inputTensors = getInputShape(); + if (inputTensors.size() == 0) { + throw std::runtime_error("Model seems to not take any input tensors."); + } + std::vector modelInputShape = inputTensors[0]; + if (modelInputShape.size() < 2) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unexpected model input size, expected at least 2 dimentions " + "but got: %zu.", + modelInputShape.size()); + throw std::runtime_error(errorMessage); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); +} + +std::pair +ObjectDetection::preprocess(const std::string &imageSource) { + cv::Mat image = imageprocessing::readImage(imageSource); + auto originalSize = image.size(); + cv::resize(image, image, modelImageSize); + + return {imageprocessing::getTensorFromMatrix(getInputShape()[0], image), + originalSize}; +} + +std::vector +ObjectDetection::postprocess(const std::vector &tensors, + cv::Size originalSize, double detectionThreshold) { + float widthRatio = + static_cast(originalSize.width) / modelImageSize.width; + float heightRatio = + static_cast(originalSize.height) / modelImageSize.height; + + std::vector detections; + auto bboxTensor = tensors.at(0).toTensor(); + std::span bboxes( + static_cast(bboxTensor.const_data_ptr()), + bboxTensor.numel()); + + auto scoreTensor = tensors.at(1).toTensor(); + std::span scores( + static_cast(scoreTensor.const_data_ptr()), + scoreTensor.numel()); + + auto labelTensor = tensors.at(2).toTensor(); + std::span labels( + static_cast(labelTensor.const_data_ptr()), + labelTensor.numel()); + + for (std::size_t i = 0; i < scores.size(); ++i) { + if (scores[i] < detectionThreshold) { + continue; + } + float x1 = bboxes[i * 4] * widthRatio; + float y1 = bboxes[i * 4 + 1] * heightRatio; + float x2 = bboxes[i * 4 + 2] * widthRatio; + float y2 = bboxes[i * 4 + 3] * heightRatio; + detections.emplace_back(x1, y1, x2, y2, static_cast(labels[i]), + scores[i]); + } + + std::vector output = nonMaxSuppression(detections); + return output; +} + +std::vector ObjectDetection::forward(std::string imageSource, + double detectionThreshold) { + auto [tensor, originalSize] = preprocess(imageSource); + + auto forwardResult = forwardET(tensor); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Failed to forward, error: " + + std::to_string(static_cast(forwardResult.error()))); + } + + return postprocess(forwardResult.get(), originalSize, detectionThreshold); +} +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h new file mode 100644 index 0000000000..8acf06ee2f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include + +namespace rnexecutorch { +using executorch::extension::TensorPtr; +using executorch::runtime::EValue; + +class ObjectDetection : public BaseModel { +public: + ObjectDetection(const std::string &modelSource, + std::shared_ptr callInvoker); + std::vector forward(std::string imageSource, + double detectionThreshold); + +private: + std::pair preprocess(const std::string &imageSource); + std::vector postprocess(const std::vector &tensors, + cv::Size originalSize, + double detectionThreshold); + + cv::Size modelImageSize{0, 0}; +}; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/ios/RnExecutorch/utils/ObjectDetectionUtils.mm b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.cpp similarity index 62% rename from packages/react-native-executorch/ios/RnExecutorch/utils/ObjectDetectionUtils.mm rename to packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.cpp index e1b8366f2d..9cedccbc2c 100644 --- a/packages/react-native-executorch/ios/RnExecutorch/utils/ObjectDetectionUtils.mm +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.cpp @@ -1,30 +1,7 @@ -#include "ObjectDetectionUtils.hpp" -#include "Constants.h" +#include "Utils.h" -NSString *floatLabelToNSString(float label) { - int intLabel = static_cast(label); - auto it = cocoLabelsMap.find(intLabel); - if (it != cocoLabelsMap.end()) { - return [NSString stringWithUTF8String:it->second.c_str()]; - } else { - return [NSString stringWithUTF8String:"unknown"]; - } -} - -NSDictionary *detectionToNSDictionary(const Detection &detection) { - return @{ - @"bbox" : @{ - @"x1" : @(detection.x1), - @"y1" : @(detection.y1), - @"x2" : @(detection.x2), - @"y2" : @(detection.y2), - }, - @"label" : floatLabelToNSString(detection.label), - @"score" : @(detection.score) - }; -} - -float iou(const Detection &a, const Detection &b) { +namespace rnexecutorch { +float intersectionOverUnion(const Detection &a, const Detection &b) { float x1 = std::max(a.x1, b.x1); float y1 = std::max(a.y1, b.y1); float x2 = std::min(a.x2, b.x2); @@ -36,10 +13,9 @@ float iou(const Detection &a, const Detection &b) { float unionArea = areaA + areaB - intersectionArea; return intersectionArea / unionArea; -}; +} -std::vector nms(std::vector detections, - float iouThreshold) { +std::vector nonMaxSuppression(std::vector detections) { if (detections.empty()) { return {}; } @@ -70,8 +46,9 @@ float iou(const Detection &a, const Detection &b) { filteredLabelDetections.push_back(current); labelDetections.erase( std::remove_if(labelDetections.begin(), labelDetections.end(), - [&](const Detection &other) { - return iou(current, other) > iouThreshold; + [¤t](const Detection &other) { + return intersectionOverUnion(current, other) > + iouThreshold; }), labelDetections.end()); } @@ -80,3 +57,5 @@ float iou(const Detection &a, const Detection &b) { } return result; } + +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.h new file mode 100644 index 0000000000..b226da95a4 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace rnexecutorch { +struct Detection { + float x1; + float y1; + float x2; + float y2; + int label; + float score; +}; + +inline constexpr float iouThreshold = 0.55; + +float intersectionOverUnion(const Detection &a, const Detection &b); +std::vector nonMaxSuppression(std::vector detections); +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 0dff20b671..dbc9b1b330 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -36,7 +36,6 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, std::pair StyleTransfer::preprocess(const std::string &imageSource) { cv::Mat image = imageprocessing::readImage(imageSource); - cv::cvtColor(image, image, cv::COLOR_BGR2RGB); auto originalSize = image.size(); cv::resize(image, image, modelImageSize); diff --git a/packages/react-native-executorch/ios/RnExecutorch/ObjectDetection.h b/packages/react-native-executorch/ios/RnExecutorch/ObjectDetection.h deleted file mode 100644 index d41fccb8a4..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/ObjectDetection.h +++ /dev/null @@ -1,5 +0,0 @@ -#import - -@interface ObjectDetection : NSObject - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/ObjectDetection.mm b/packages/react-native-executorch/ios/RnExecutorch/ObjectDetection.mm deleted file mode 100644 index 265d7e19ad..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/ObjectDetection.mm +++ /dev/null @@ -1,56 +0,0 @@ -#import "ObjectDetection.h" -#import "models/object_detection/SSDLiteLargeModel.hpp" -#import "utils/ImageProcessor.h" - -@implementation ObjectDetection { - SSDLiteLargeModel *model; -} - -RCT_EXPORT_MODULE() - -- (void)releaseResources { - model = nil; -} - -- (void)loadModule:(NSString *)modelSource - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - model = [[SSDLiteLargeModel alloc] init]; - - NSNumber *errorCode = [model loadModel:modelSource]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - NSError *error = [NSError - errorWithDomain:@"StyleTransferErrorDomain" - code:[errorCode intValue] - userInfo:@{ - NSLocalizedDescriptionKey : [NSString - stringWithFormat:@"%ld", (long)[errorCode longValue]] - }]; - reject(@"init_module_error", error.localizedDescription, error); - return; - } - - resolve(@0); -} - -- (void)forward:(NSString *)input - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - @try { - cv::Mat image = [ImageProcessor readImage:input]; - NSArray *result = [model runModel:image]; - resolve(result); - } @catch (NSException *exception) { - reject(@"forward_error", - [NSString stringWithFormat:@"%@", exception.reason], nil); - } -} - -- (std::shared_ptr)getTurboModule: - (const facebook::react::ObjCTurboModule::InitParams &)params { - return std::make_shared( - params); -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/image_segmentation/Constants.mm b/packages/react-native-executorch/ios/RnExecutorch/models/image_segmentation/Constants.mm deleted file mode 100644 index a6d69e1c29..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/image_segmentation/Constants.mm +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include -#include - -namespace rnexecutorch { -inline constexpr std::array deeplabv3_resnet50_labels = { - "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", - "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", - "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", - "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", - "TVMONITOR"}; -} \ No newline at end of file diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.hpp b/packages/react-native-executorch/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.hpp deleted file mode 100644 index 34ea34c0a0..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.hpp +++ /dev/null @@ -1,11 +0,0 @@ -#import "../BaseModel.h" -#import -#include - -@interface SSDLiteLargeModel : BaseModel - -- (NSArray *)runModel:(cv::Mat)input; -- (NSArray *)preprocess:(cv::Mat)input; -- (NSArray *)postprocess:(NSArray *)input; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm b/packages/react-native-executorch/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm deleted file mode 100644 index 57e60dac40..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm +++ /dev/null @@ -1,64 +0,0 @@ -#include "SSDLiteLargeModel.hpp" -#include "../../utils/ObjectDetectionUtils.hpp" -#include "ImageProcessor.h" - -float constexpr iouThreshold = 0.55; -float constexpr detectionThreshold = 0.7; -int constexpr modelInputWidth = 320; -int constexpr modelInputHeight = 320; - -@implementation SSDLiteLargeModel - -- (NSArray *)preprocess:(cv::Mat)input { - cv::resize(input, input, cv::Size(modelInputWidth, modelInputHeight)); - NSArray *modelInput = [ImageProcessor matToNSArray:input]; - return modelInput; -} - -- (NSArray *)postprocess:(NSArray *)input - widthRatio:(float)widthRatio - heightRatio:(float)heightRatio { - NSArray *bboxes = [input objectAtIndex:0]; - NSArray *scores = [input objectAtIndex:1]; - NSArray *labels = [input objectAtIndex:2]; - - std::vector detections; - - for (NSUInteger idx = 0; idx < scores.count; idx++) { - float score = [scores[idx] floatValue]; - float label = [labels[idx] floatValue]; - if (score < detectionThreshold) { - continue; - } - float x1 = [bboxes[idx * 4] floatValue] * widthRatio; - float y1 = [bboxes[idx * 4 + 1] floatValue] * heightRatio; - float x2 = [bboxes[idx * 4 + 2] floatValue] * widthRatio; - float y2 = [bboxes[idx * 4 + 3] floatValue] * heightRatio; - - Detection det = {x1, y1, x2, y2, label, score}; - detections.push_back(det); - } - std::vector nms_output = nms(detections, iouThreshold); - - NSMutableArray *output = [NSMutableArray array]; - for (Detection &detection : nms_output) { - [output addObject:detectionToNSDictionary(detection)]; - } - - return output; -} - -- (NSArray *)runModel:(cv::Mat)input { - cv::Size size = input.size(); - int inputImageWidth = size.width; - int inputImageHeight = size.height; - NSArray *modelInput = [self preprocess:input]; - NSArray *forwardResult = [self forward:@[ modelInput ]]; - NSArray *output = - [self postprocess:forwardResult - widthRatio:inputImageWidth / (float)modelInputWidth - heightRatio:inputImageHeight / (float)modelInputHeight]; - return output; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/utils/Constants.h b/packages/react-native-executorch/ios/RnExecutorch/utils/Constants.h deleted file mode 100644 index f7fe57e609..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/utils/Constants.h +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef Constants_h -#define Constants_h - -#include - -extern const std::unordered_map cocoLabelsMap; - -#endif /* Constants_h */ diff --git a/packages/react-native-executorch/ios/RnExecutorch/utils/ObjectDetectionUtils.hpp b/packages/react-native-executorch/ios/RnExecutorch/utils/ObjectDetectionUtils.hpp deleted file mode 100644 index 32652e3ace..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/utils/ObjectDetectionUtils.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef ObjectDetectionUtils_hpp -#define ObjectDetectionUtils_hpp - -#import -#include -#include - -struct Detection { - float x1; - float y1; - float x2; - float y2; - float label; - float score; -}; - -NSString *floatLabelToNSString(float label); -NSDictionary *detectionToNSDictionary(const Detection &detection); -float iou(const Detection &a, const Detection &b); -std::vector nms(std::vector detections, - float iouThreshold); - -#endif /* ObjectDetectionUtils_hpp */ diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts index 1c36bebc20..4d674f3e6b 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts @@ -1,5 +1,5 @@ import { ResourceSource } from '../../types/common'; -import { useModule } from '../useModule'; +import { useNonStaticModule } from '../useNonStaticModule'; import { ObjectDetectionModule } from '../../modules/computer_vision/ObjectDetectionModule'; interface Props { @@ -11,8 +11,8 @@ export const useObjectDetection = ({ modelSource, preventLoad = false, }: Props) => - useModule({ + useNonStaticModule({ module: ObjectDetectionModule, loadArgs: [modelSource], - preventLoad, + preventLoad: preventLoad, }); diff --git a/packages/react-native-executorch/src/index.tsx b/packages/react-native-executorch/src/index.tsx index d226737b4f..6721f4a12c 100644 --- a/packages/react-native-executorch/src/index.tsx +++ b/packages/react-native-executorch/src/index.tsx @@ -7,6 +7,7 @@ declare global { var loadStyleTransfer: (source: string) => any; var loadImageSegmentation: (source: string) => any; var loadClassification: (source: string) => any; + var loadObjectDetection: (source: string) => any; } // eslint-disable no-var diff --git a/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts b/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts index b9606eb653..028a1550a8 100644 --- a/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts +++ b/packages/react-native-executorch/src/modules/BaseNonStaticModule.ts @@ -1,6 +1,8 @@ export class BaseNonStaticModule { nativeModule: any = null; delete() { - this.nativeModule.unload(); + if (this.nativeModule !== null) { + this.nativeModule.unload(); + } } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index 2383368c01..344943686d 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -1,17 +1,27 @@ -import { ObjectDetectionNativeModule } from '../../native/RnExecutorchModules'; +import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; -import { BaseModule } from '../BaseModule'; +import { Detection } from '../../types/objectDetection'; +import { ETError, getError } from '../../Error'; +import { BaseNonStaticModule } from '../BaseNonStaticModule'; -export class ObjectDetectionModule extends BaseModule { - protected static override nativeModule = ObjectDetectionNativeModule; - - static override async load(modelSource: ResourceSource) { - return await super.load(modelSource); +export class ObjectDetectionModule extends BaseNonStaticModule { + async load( + modelSource: ResourceSource, + onDownloadProgressCallback: (_: number) => void = () => {} + ): Promise { + const paths = await ResourceFetcher.fetchMultipleResources( + onDownloadProgressCallback, + modelSource + ); + this.nativeModule = global.loadObjectDetection(paths[0] || ''); } - static override async forward( - input: string - ): ReturnType { - return await this.nativeModule.forward(input); + async forward( + imageSource: string, + detectionThreshold: number = 0.7 + ): Promise { + if (this.nativeModule == null) + throw new Error(getError(ETError.ModuleNotLoaded)); + return await this.nativeModule.forward(imageSource, detectionThreshold); } } diff --git a/packages/react-native-executorch/src/native/NativeObjectDetection.ts b/packages/react-native-executorch/src/native/NativeObjectDetection.ts deleted file mode 100644 index 7fa3845001..0000000000 --- a/packages/react-native-executorch/src/native/NativeObjectDetection.ts +++ /dev/null @@ -1,10 +0,0 @@ -import type { TurboModule } from 'react-native'; -import { TurboModuleRegistry } from 'react-native'; -import { Detection } from '../types/objectDetection'; - -export interface Spec extends TurboModule { - loadModule(modelSource: string): Promise; - forward(input: string): Promise; -} - -export default TurboModuleRegistry.get('ObjectDetection'); diff --git a/packages/react-native-executorch/src/native/RnExecutorchModules.ts b/packages/react-native-executorch/src/native/RnExecutorchModules.ts index 9dc0350759..b767a0db5f 100644 --- a/packages/react-native-executorch/src/native/RnExecutorchModules.ts +++ b/packages/react-native-executorch/src/native/RnExecutorchModules.ts @@ -1,5 +1,4 @@ import { Platform } from 'react-native'; -import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; import { Spec as ETModuleInterface } from './NativeETModule'; import { Spec as OCRInterface } from './NativeOCR'; import { Spec as VerticalOCRInterface } from './NativeVerticalOCR'; @@ -34,8 +33,6 @@ const LLMNativeModule: LLMInterface = returnSpecOrThrowLinkingError( const ETModuleNativeModule: ETModuleInterface = returnSpecOrThrowLinkingError( require('./NativeETModule').default ); -const ObjectDetectionNativeModule: ObjectDetectionInterface = - returnSpecOrThrowLinkingError(require('./NativeObjectDetection').default); const SpeechToTextNativeModule: SpeechToTextInterface = returnSpecOrThrowLinkingError(require('./NativeSpeechToText').default); const OCRNativeModule: OCRInterface = returnSpecOrThrowLinkingError( @@ -54,7 +51,6 @@ const ETInstallerNativeModule: ETInstallerInterface = export { LLMNativeModule, ETModuleNativeModule, - ObjectDetectionNativeModule, SpeechToTextNativeModule, OCRNativeModule, VerticalOCRNativeModule,