diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt b/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt
new file mode 100644
index 0000000000..c18fa8ed32
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt
@@ -0,0 +1,81 @@
+package com.swmansion.rnexecutorch
+
+import android.util.Log
+import com.facebook.react.bridge.Promise
+import com.facebook.react.bridge.ReactApplicationContext
+import com.facebook.react.bridge.ReadableArray
+import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
+import com.swmansion.rnexecutorch.utils.ETError
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import org.opencv.android.OpenCVLoader
+
+/**
+ * React Native module that loads an [ImageSegmentationModel] and runs it on images,
+ * settling the supplied [Promise] with the result or the failure.
+ */
+class ImageSegmentation(
+  reactContext: ReactApplicationContext,
+) : NativeImageSegmentationSpec(reactContext) {
+  private lateinit var model: ImageSegmentationModel
+
+  companion object {
+    const val NAME = "ImageSegmentation"
+
+    init {
+      // Initialise OpenCV once, when the module class is first loaded; only the outcome is logged.
+      if (!OpenCVLoader.initLocal()) {
+        Log.d("rn_executorch", "OpenCV not loaded")
+      } else {
+        Log.d("rn_executorch", "OpenCV loaded")
+      }
+    }
+  }
+
+  /**
+   * Creates a fresh [ImageSegmentationModel] and loads it from [modelSource].
+   *
+   * Resolves [promise] with `0` on success. On failure the promise is rejected with the
+   * exception text as the code and [ETError.InvalidModelSource] as the message.
+   */
+  override fun loadModule(
+    modelSource: String,
+    promise: Promise,
+  ) {
+    try {
+      model = ImageSegmentationModel(reactApplicationContext)
+      model.loadModel(modelSource)
+      promise.resolve(0)
+    } catch (e: Exception) {
+      // Exception.message is nullable; `!!` here would throw inside the handler and leave
+      // the promise unsettled, so fall back to the exception's string form.
+      promise.reject(e.message ?: e.toString(), ETError.InvalidModelSource.toString())
+    }
+  }
+
+  /**
+   * Reads the image identified by [input] and runs the loaded model on it, forwarding
+   * [classesOfInterest] and [resize] to the model unchanged.
+   *
+   * Resolves [promise] with the model output. Any exception (including calling this before
+   * [loadModule] has assigned the model) rejects the promise with the exception text.
+   */
+  override fun forward(
+    input: String,
+    classesOfInterest: ReadableArray,
+    resize: Boolean,
+    promise: Promise,
+  ) {
+    try {
+      val output =
+        model.runModel(Triple(ImageProcessor.readImage(input), classesOfInterest, resize))
+      promise.resolve(output)
+    } catch (e: Exception) {
+      // Same null-message guard as in loadModule; code and message carry the same text.
+      val reason = e.message ?: e.toString()
+      promise.reject(reason, reason)
+    }
+  }
+
+  /** Name under which this module is registered. */
+  override fun getName(): String = NAME
+}
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
index 
c88e3870a0..3c78d4d7fa 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -30,6 +30,8 @@ class RnExecutorchPackage : TurboReactPackage() { OCR(reactContext) } else if (name == VerticalOCR.NAME) { VerticalOCR(reactContext) + } else if (name == ImageSegmentation.NAME) { + ImageSegmentation(reactContext) } else { null } @@ -115,6 +117,13 @@ class RnExecutorchPackage : TurboReactPackage() { false, // isCxxModule true, ) + + moduleInfos[ImageSegmentation.NAME] = ReactModuleInfo( + ImageSegmentation.NAME, ImageSegmentation.NAME, false, // canOverrideExistingModule + false, // needsEagerInit + false, // isCxxModule + true + ) moduleInfos } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt b/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt index 54132b88b1..224794e17f 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt @@ -3,7 +3,7 @@ package com.swmansion.rnexecutorch import android.util.Log import com.facebook.react.bridge.Promise import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.StyleTransferModel +import com.swmansion.rnexecutorch.models.styletransfer.StyleTransferModel import com.swmansion.rnexecutorch.utils.ETError import com.swmansion.rnexecutorch.utils.ImageProcessor import org.opencv.android.OpenCVLoader diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt index b60b0998c4..776f9a5397 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt @@ -3,6 +3,7 @@ package 
com.swmansion.rnexecutorch.models.classification import com.facebook.react.bridge.ReactApplicationContext import com.swmansion.rnexecutorch.models.BaseModel import com.swmansion.rnexecutorch.utils.ImageProcessor +import com.swmansion.rnexecutorch.utils.softmax import org.opencv.core.Mat import org.opencv.core.Size import org.opencv.imgproc.Imgproc diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt new file mode 100644 index 0000000000..7ba7fcb5c1 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt @@ -0,0 +1,26 @@ +package com.swmansion.rnexecutorch.models.imagesegmentation + +val deeplabv3_resnet50_labels: Array = + arrayOf( + "BACKGROUND", + "AEROPLANE", + "BICYCLE", + "BIRD", + "BOAT", + "BOTTLE", + "BUS", + "CAR", + "CAT", + "CHAIR", + "COW", + "DININGTABLE", + "DOG", + "HORSE", + "MOTORBIKE", + "PERSON", + "POTTEDPLANT", + "SHEEP", + "SOFA", + "TRAIN", + "TVMONITOR", + ) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt new file mode 100644 index 0000000000..36c1594b49 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt @@ -0,0 +1,139 @@ +package com.swmansion.rnexecutorch.models.imagesegmentation + +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.ReadableArray +import com.facebook.react.bridge.WritableMap +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.utils.ArrayUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import com.swmansion.rnexecutorch.utils.softmax +import org.opencv.core.CvType +import 
org.opencv.core.Mat +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import org.pytorch.executorch.EValue + +class ImageSegmentationModel( + reactApplicationContext: ReactApplicationContext, +) : BaseModel, WritableMap>(reactApplicationContext) { + private lateinit var originalSize: Size + + private fun getModelImageSize(): Size { + val inputShape = module.getInputShape(0) + val width = inputShape[inputShape.lastIndex] + val height = inputShape[inputShape.lastIndex - 1] + + return Size(height.toDouble(), width.toDouble()) + } + + fun preprocess(input: Mat): EValue { + originalSize = input.size() + Imgproc.resize(input, input, getModelImageSize()) + return ImageProcessor.matToEValue(input, module.getInputShape(0)) + } + + private fun extractResults( + result: FloatArray, + numLabels: Int, + resize: Boolean, + ): List { + val modelSize = getModelImageSize() + val numModelPixels = (modelSize.height * modelSize.width).toInt() + + val extractedLabelScores = mutableListOf() + + for (label in 0.., + numLabels: Int, + outputSize: Size, + ): IntArray { + val numPixels = (outputSize.height * outputSize.width).toInt() + val argMax = IntArray(numPixels) + for (pixel in 0..() + for (buffer in labelScores) { + scores.add(buffer[pixel]) + } + val adjustedScores = softmax(scores.toTypedArray()) + for (label in 0.., + classesOfInterest: ReadableArray, + resize: Boolean, + ): WritableMap { + val outputData = output[0].toTensor().dataAsFloatArray + val modelSize = getModelImageSize() + val numLabels = deeplabv3_resnet50_labels.size + + require(outputData.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." 
} + + val outputSize = if (resize) originalSize else modelSize + + val extractedResults = extractResults(outputData, numLabels, resize) + + val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize) + + val labelSet = mutableSetOf() + // Filter by the label set when base class changed + for (i in 0..): WritableMap { + val modelInput = preprocess(input.first) + val modelOutput = forward(modelInput) + return postprocess(modelOutput, input.second, input.third) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt similarity index 92% rename from android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt rename to android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt index 72d3bc6d36..4019015dd8 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt @@ -1,4 +1,4 @@ -package com.swmansion.rnexecutorch.models +package com.swmansion.rnexecutorch.models.styletransfer import com.facebook.react.bridge.ReactApplicationContext import com.swmansion.rnexecutorch.utils.ImageProcessor @@ -6,6 +6,7 @@ import org.opencv.core.Mat import org.opencv.core.Size import org.opencv.imgproc.Imgproc import org.pytorch.executorch.EValue +import com.swmansion.rnexecutorch.models.BaseModel class StyleTransferModel( reactApplicationContext: ReactApplicationContext, diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/Utils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt similarity index 77% rename from android/src/main/java/com/swmansion/rnexecutorch/models/classification/Utils.kt rename to android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt index e919950a0a..603699e35f 100644 --- 
a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/Utils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt @@ -1,4 +1,4 @@ -package com.swmansion.rnexecutorch.models.classification +package com.swmansion.rnexecutorch.utils fun softmax(x: Array): Array { val max = x.maxOrNull()!! diff --git a/src/modules/computer_vision/ImageSegmentationModule.ts b/src/modules/computer_vision/ImageSegmentationModule.ts index 1d078c1cfd..006e23f0a0 100644 --- a/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/src/modules/computer_vision/ImageSegmentationModule.ts @@ -1,21 +1,32 @@ import { BaseModule } from '../BaseModule'; import { _ImageSegmentationModule } from '../../native/RnExecutorchModules'; import { getError } from '../../Error'; +import { DeeplabLabel } from '../../types/image_segmentation'; export class ImageSegmentationModule extends BaseModule { static module = new _ImageSegmentationModule(); static async forward( input: string, - classesOfInterest: string[], - resize: boolean + classesOfInterest?: DeeplabLabel[], + resize?: boolean ) { try { - return await (this.module.forward( + const stringDict = await (this.module.forward( input, - classesOfInterest, - resize + (classesOfInterest || []).map((label) => DeeplabLabel[label]), + resize || false ) as ReturnType<_ImageSegmentationModule['forward']>); + + let enumDict: { [key in DeeplabLabel]?: number[] } = {}; + + for (const key in stringDict) { + if (key in DeeplabLabel) { + const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel]; + enumDict[enumKey] = stringDict[key]; + } + } + return enumDict; } catch (e) { throw new Error(getError(e)); }