Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader

/**
 * React Native module that exposes [ImageSegmentationModel] to JavaScript.
 *
 * A model must be loaded with [loadModule] before [forward] is called; both
 * methods report their outcome through the supplied [Promise].
 */
class ImageSegmentation(
    reactContext: ReactApplicationContext,
) : NativeImageSegmentationSpec(reactContext) {
    private lateinit var model: ImageSegmentationModel

    /**
     * Creates a fresh [ImageSegmentationModel] and loads it from [modelSource].
     *
     * Resolves [promise] with `0` on success. On failure the promise is rejected
     * with the exception message as the code and `ETError.InvalidModelSource`
     * as the message.
     */
    override fun loadModule(
        modelSource: String,
        promise: Promise,
    ) {
        try {
            model = ImageSegmentationModel(reactApplicationContext)
            model.loadModel(modelSource)
            promise.resolve(0)
        } catch (e: Exception) {
            // Exceptions may carry a null message; fall back to the exception's
            // string form so the rejection itself can never throw.
            promise.reject(e.message ?: e.toString(), ETError.InvalidModelSource.toString())
        }
    }

    /**
     * Reads the image referenced by [input], runs the loaded model on it and
     * resolves [promise] with the model output.
     *
     * @param input image source handed to `ImageProcessor.readImage`
     * @param classesOfInterest labels whose score maps should be returned
     * @param resize whether results are rescaled to the original image size
     */
    override fun forward(
        input: String,
        classesOfInterest: ReadableArray,
        resize: Boolean,
        promise: Promise,
    ) {
        try {
            val image = ImageProcessor.readImage(input)
            val output = model.runModel(Triple(image, classesOfInterest, resize))
            promise.resolve(output)
        } catch (e: Exception) {
            // Same null-safe fallback as in loadModule; the reason is used as both
            // the rejection code and the rejection message.
            val reason = e.message ?: e.toString()
            promise.reject(reason, reason)
        }
    }

    override fun getName(): String = NAME

    companion object {
        const val NAME = "ImageSegmentation"

        init {
            // Load the OpenCV native library once, when the class is first used.
            if (!OpenCVLoader.initLocal()) {
                Log.d("rn_executorch", "OpenCV not loaded")
            } else {
                Log.d("rn_executorch", "OpenCV loaded")
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class RnExecutorchPackage : TurboReactPackage() {
OCR(reactContext)
} else if (name == VerticalOCR.NAME) {
VerticalOCR(reactContext)
} else if (name == ImageSegmentation.NAME) {
ImageSegmentation(reactContext)
} else {
null
}
Expand Down Expand Up @@ -115,6 +117,13 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true,
)

moduleInfos[ImageSegmentation.NAME] = ReactModuleInfo(
ImageSegmentation.NAME, ImageSegmentation.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.swmansion.rnexecutorch
import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.StyleTransferModel
import com.swmansion.rnexecutorch.models.styletransfer.StyleTransferModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.models.classification
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ImageProcessor
import com.swmansion.rnexecutorch.utils.softmax
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.swmansion.rnexecutorch.models.imagesegmentation

/**
 * Label names used by the image segmentation model.
 *
 * The position of a name in this array is the label index: the segmentation
 * post-processing looks names up by that index, so the order must not change.
 */
val deeplabv3_resnet50_labels: Array<String> =
    arrayOf(
        "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", "BOTTLE", "BUS",
        "CAR", "CAT", "CHAIR", "COW", "DININGTABLE", "DOG", "HORSE",
        "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", "TVMONITOR",
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package com.swmansion.rnexecutorch.models.imagesegmentation

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.WritableMap
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import com.swmansion.rnexecutorch.utils.softmax
import org.opencv.core.CvType
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.EValue

/**
 * Runs an image segmentation model and converts its raw output into per-label
 * score maps plus a per-pixel arg-max map.
 *
 * Input is a [Triple] of (image, labels of interest, resize flag); output is a
 * [WritableMap] keyed by label name with an extra `"ARGMAX"` entry.
 */
class ImageSegmentationModel(
    reactApplicationContext: ReactApplicationContext,
) : BaseModel<Triple<Mat, ReadableArray, Boolean>, WritableMap>(reactApplicationContext) {
    // Size of the most recently preprocessed image; set by preprocess().
    private lateinit var originalSize: Size

    /**
     * Returns the spatial size the model expects, read from the last two
     * dimensions of input 0 (last = width, second to last = height).
     */
    private fun getModelImageSize(): Size {
        val inputShape = module.getInputShape(0)
        val width = inputShape[inputShape.lastIndex]
        val height = inputShape[inputShape.lastIndex - 1]

        // OpenCV's Size constructor takes (width, height), in that order.
        return Size(width.toDouble(), height.toDouble())
    }

    /**
     * Remembers the size of [input], resizes it in place to the model's input
     * size and converts it to an [EValue].
     */
    fun preprocess(input: Mat): EValue {
        originalSize = input.size()
        Imgproc.resize(input, input, getModelImageSize())
        return ImageProcessor.matToEValue(input, module.getInputShape(0))
    }

    /**
     * Rescales one label's score plane from [modelSize] to [originalSize].
     * The temporary matrices are released before returning.
     */
    private fun resizeToOriginal(
        pixelBuffer: FloatArray,
        modelSize: Size,
    ): FloatArray {
        val mat = Mat(modelSize, CvType.CV_32F)
        val resizedMat = Mat()
        try {
            mat.put(0, 0, pixelBuffer)
            Imgproc.resize(mat, resizedMat, originalSize)
            val resizedBuffer = FloatArray((originalSize.height * originalSize.width).toInt())
            resizedMat.get(0, 0, resizedBuffer)
            return resizedBuffer
        } finally {
            // Free native memory eagerly instead of waiting for garbage collection.
            mat.release()
            resizedMat.release()
        }
    }

    /**
     * Splits the flat model output into one score buffer per label.
     *
     * [result] is treated as [numLabels] consecutive planes of
     * `modelHeight * modelWidth` floats. When [resize] is true each plane is
     * rescaled to the original image size.
     */
    private fun extractResults(
        result: FloatArray,
        numLabels: Int,
        resize: Boolean,
    ): List<FloatArray> {
        val modelSize = getModelImageSize()
        val numModelPixels = (modelSize.height * modelSize.width).toInt()

        val extractedLabelScores = ArrayList<FloatArray>(numLabels)

        for (label in 0..<numLabels) {
            // Calls to OpenCV via JNI are very slow so we do as much as we can
            // with pure Kotlin
            val start = label * numModelPixels
            val pixelBuffer = result.copyOfRange(start, start + numModelPixels)

            extractedLabelScores.add(
                if (resize) resizeToOriginal(pixelBuffer, modelSize) else pixelBuffer,
            )
        }
        return extractedLabelScores
    }

    /**
     * Applies softmax across labels for every pixel, writing the normalized
     * scores back into [labelScores], and returns the index of the
     * highest-scoring label per pixel (the first one on ties).
     */
    private fun adjustScoresPerPixel(
        labelScores: List<FloatArray>,
        numLabels: Int,
        outputSize: Size,
    ): IntArray {
        val numPixels = (outputSize.height * outputSize.width).toInt()
        val argMax = IntArray(numPixels)
        for (pixel in 0..<numPixels) {
            val scores = Array(labelScores.size) { labelScores[it][pixel] }

            // Pick the winner from the raw scores before handing them to softmax.
            argMax[pixel] = scores.indices.maxBy { scores[it] }

            val adjustedScores = softmax(scores)
            for (label in 0..<numLabels) {
                labelScores[label][pixel] = adjustedScores[label]
            }
        }

        return argMax
    }

    /**
     * Converts raw model [output] into the result map.
     *
     * The map holds one float array per label named in [classesOfInterest] and
     * an `"ARGMAX"` int array with the winning label index for each pixel.
     * Arrays cover the original image size when [resize] is true, otherwise the
     * model's input size.
     *
     * @throws IllegalArgumentException if the output length does not equal
     *   `numLabels * modelHeight * modelWidth`
     */
    fun postprocess(
        output: Array<EValue>,
        classesOfInterest: ReadableArray,
        resize: Boolean,
    ): WritableMap {
        val outputData = output[0].toTensor().dataAsFloatArray
        val modelSize = getModelImageSize()
        val numLabels = deeplabv3_resnet50_labels.size

        require(outputData.size == (numLabels * modelSize.height * modelSize.width).toInt()) {
            "Model generated unexpected output size."
        }

        val outputSize = if (resize) originalSize else modelSize

        val extractedResults = extractResults(outputData, numLabels, resize)

        val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize)

        // Collect the requested labels into a set for constant-time lookup.
        val labelSet = mutableSetOf<String>()
        for (i in 0..<classesOfInterest.size()) {
            labelSet.add(classesOfInterest.getString(i))
        }

        val res = Arguments.createMap()

        // Only labels the caller asked for are returned.
        deeplabv3_resnet50_labels.forEachIndexed { label, name ->
            if (name in labelSet) {
                res.putArray(
                    name,
                    ArrayUtils.createReadableArrayFromFloatArray(extractedResults[label]),
                )
            }
        }

        res.putArray(
            "ARGMAX",
            ArrayUtils.createReadableArrayFromIntArray(argMax),
        )

        return res
    }

    /** Runs preprocessing, inference and postprocessing for one [input]. */
    override fun runModel(input: Triple<Mat, ReadableArray, Boolean>): WritableMap {
        val (image, classesOfInterest, resize) = input
        val modelInput = preprocess(image)
        val modelOutput = forward(modelInput)
        return postprocess(modelOutput, classesOfInterest, resize)
    }
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package com.swmansion.rnexecutorch.models
package com.swmansion.rnexecutorch.models.styletransfer

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.EValue
import com.swmansion.rnexecutorch.models.BaseModel

class StyleTransferModel(
reactApplicationContext: ReactApplicationContext,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.swmansion.rnexecutorch.models.classification
package com.swmansion.rnexecutorch.utils

fun softmax(x: Array<Float>): Array<Float> {
val max = x.maxOrNull()!!
Expand Down
21 changes: 16 additions & 5 deletions src/modules/computer_vision/ImageSegmentationModule.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
import { BaseModule } from '../BaseModule';
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
import { getError } from '../../Error';
import { DeeplabLabel } from '../../types/image_segmentation';

export class ImageSegmentationModule extends BaseModule {
static module = new _ImageSegmentationModule();

static async forward(
input: string,
classesOfInterest: string[],
resize: boolean
classesOfInterest?: DeeplabLabel[],
resize?: boolean
) {
try {
return await (this.module.forward(
const stringDict = await (this.module.forward(
input,
classesOfInterest,
resize
(classesOfInterest || []).map((label) => DeeplabLabel[label]),
resize || false
) as ReturnType<_ImageSegmentationModule['forward']>);

let enumDict: { [key in DeeplabLabel]?: number[] } = {};

for (const key in stringDict) {
if (key in DeeplabLabel) {
const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
enumDict[enumKey] = stringDict[key];
}
}
return enumDict;
} catch (e) {
throw new Error(getError(e));
}
Expand Down