From 0b86fd1029a7f06fbd3f3d09b07af66c3dceea44 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Feb 2026 11:48:16 +0100 Subject: [PATCH 01/71] fix: correct frame data extraction --- .../object_detection/ObjectDetection.cpp | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 1dad1a61a5..2293a4abc8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -79,6 +79,46 @@ cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { return rgb; } +cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { + // Get target size from model input shape + const std::vector tensorDims = getAllInputShapes()[0]; + cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], + tensorDims[tensorDims.size() - 2]); + + cv::Mat rgb; + + // Convert RGBA/BGRA to RGB if needed (for VisionCamera frames) + if (frame.channels() == 4) { +// Platform-specific color conversion: +// iOS uses BGRA format, Android uses RGBA format +#ifdef __APPLE__ + // iOS: BGRA → RGB + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + // Android: RGBA → RGB + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + } else if (frame.channels() == 3) { + // Already RGB + rgb = frame; + } else { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unsupported frame format: %d channels", frame.channels()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + // Only resize if dimensions don't match + if (rgb.size() != tensorSize) { + cv::Mat resized; + cv::resize(rgb, resized, tensorSize); + return resized; + } + + return rgb; +} 
+ std::vector ObjectDetection::postprocess(const std::vector &tensors, cv::Size originalSize, double detectionThreshold) { From 40c314c8b7b878c4ebd9373980c3f65a45cf8664 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Feb 2026 10:37:11 +0100 Subject: [PATCH 02/71] feat: unify frame extraction and preprocessing --- .../models/classification/Classification.cpp | 2 +- .../src/modules/BaseModule.ts | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 0fba071087..b9fad1b88b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -73,4 +73,4 @@ Classification::postprocess(const Tensor &tensor) { return probs; } -} // namespace rnexecutorch::models::classification +} // namespace rnexecutorch::models::classification \ No newline at end of file diff --git a/packages/react-native-executorch/src/modules/BaseModule.ts b/packages/react-native-executorch/src/modules/BaseModule.ts index c844cf358b..d4c3b699b1 100644 --- a/packages/react-native-executorch/src/modules/BaseModule.ts +++ b/packages/react-native-executorch/src/modules/BaseModule.ts @@ -88,4 +88,20 @@ export abstract class BaseModule { this.nativeModule.unload(); } } + + /** + * Bind JSI methods to this instance for worklet compatibility. + * + * This makes native JSI functions accessible from worklet threads, + * which is essential for VisionCamera frame processing. 
+ * + * @internal + */ + protected bindJSIMethods() { + if (this.nativeModule && this.nativeModule.generateFromFrame) { + // Bind the native JSI method directly to this instance + // This makes it worklet-compatible since JSI functions work across threads + this.generateFromFrame = this.nativeModule.generateFromFrame; + } + } } From 797c87c85b8693165ee27c675f9d5bfcf96c6cf8 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Feb 2026 11:22:16 +0100 Subject: [PATCH 03/71] feat: remove unused bindJSIMethods --- .../src/modules/BaseModule.ts | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/packages/react-native-executorch/src/modules/BaseModule.ts b/packages/react-native-executorch/src/modules/BaseModule.ts index d4c3b699b1..c844cf358b 100644 --- a/packages/react-native-executorch/src/modules/BaseModule.ts +++ b/packages/react-native-executorch/src/modules/BaseModule.ts @@ -88,20 +88,4 @@ export abstract class BaseModule { this.nativeModule.unload(); } } - - /** - * Bind JSI methods to this instance for worklet compatibility. - * - * This makes native JSI functions accessible from worklet threads, - * which is essential for VisionCamera frame processing. 
- * - * @internal - */ - protected bindJSIMethods() { - if (this.nativeModule && this.nativeModule.generateFromFrame) { - // Bind the native JSI method directly to this instance - // This makes it worklet-compatible since JSI functions work across threads - this.generateFromFrame = this.nativeModule.generateFromFrame; - } - } } From 30025316341467ce3f50dff4e42cb962d5f79b9d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 17 Feb 2026 13:05:14 +0100 Subject: [PATCH 04/71] feat: initial version of vision model API --- .../app/object_detection/index.tsx | 163 ++++++++++++++++- .../host_objects/ModelHostObject.h | 12 ++ .../metaprogramming/TypeConcepts.h | 9 +- .../models/embeddings/image/ImageEmbeddings.h | 2 +- .../BaseSemanticSegmentation.h | 2 +- .../ImageSegmentation.cpp | 170 ++++++++++++++++++ .../models/style_transfer/StyleTransfer.h | 2 +- .../src/hooks/useModule.ts | 1 + .../src/types/objectDetection.ts | 14 ++ 9 files changed, 363 insertions(+), 12 deletions(-) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 2f8fa6d58e..0dda13e9d9 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -1,16 +1,66 @@ import Spinner from '../../components/Spinner'; -import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; import { Detection, useObjectDetection, RF_DETR_NANO, } from 'react-native-executorch'; -import { View, StyleSheet, Image } from 'react-native'; +import { View, StyleSheet, Image, TouchableOpacity, Text } from 'react-native'; import ImageWithBboxes from '../../components/ImageWithBboxes'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from 
'../../ScreenWrapper'; +import ColorPalette from '../../colors'; +import { Images } from 'react-native-nitro-image'; + +// Helper function to convert image URI to raw pixel data using NitroImage +async function imageUriToPixelData( + uri: string, + targetWidth: number, + targetHeight: number +): Promise<{ + data: ArrayBuffer; + width: number; + height: number; + channels: number; +}> { + try { + // Load image and resize to target dimensions + const image = await Images.loadFromFileAsync(uri); + const resized = image.resize(targetWidth, targetHeight); + + // Get pixel data as ArrayBuffer (RGBA format) + const pixelData = resized.toRawPixelData(); + const buffer = + pixelData instanceof ArrayBuffer ? pixelData : pixelData.buffer; + + // Calculate actual buffer dimensions (accounts for device pixel ratio) + const bufferSize = buffer?.byteLength || 0; + const totalPixels = bufferSize / 4; // RGBA = 4 bytes per pixel + const aspectRatio = targetWidth / targetHeight; + const actualHeight = Math.sqrt(totalPixels / aspectRatio); + const actualWidth = totalPixels / actualHeight; + + console.log('Requested:', targetWidth, 'x', targetHeight); + console.log('Buffer size:', bufferSize); + console.log( + 'Actual dimensions:', + Math.round(actualWidth), + 'x', + Math.round(actualHeight) + ); + + return { + data: buffer, + width: Math.round(actualWidth), + height: Math.round(actualHeight), + channels: 4, // RGBA + }; + } catch (error) { + console.error('Error loading image with NitroImage:', error); + throw error; + } +} export default function ObjectDetectionScreen() { const [imageUri, setImageUri] = useState(''); @@ -45,7 +95,36 @@ export default function ObjectDetectionScreen() { const output = await rfDetr.forward(imageUri); setResults(output); } catch (e) { - console.error(e); + console.error('Error in runForward:', e); + } + } + }; + + const runForwardPixels = async () => { + if (imageUri && imageDimensions) { + try { + console.log('Converting image to pixel data...'); + // 
Resize to 640x640 to avoid memory issues + const intermediateSize = 640; + const pixelData = await imageUriToPixelData( + imageUri, + intermediateSize, + intermediateSize + ); + + console.log('Running forward with pixel data...', { + width: pixelData.width, + height: pixelData.height, + channels: pixelData.channels, + dataSize: pixelData.data.byteLength, + }); + + // Run inference using unified forward() API + const output = await ssdLite.forward(pixelData, 0.5); + console.log('Pixel data result:', output.length, 'detections'); + setResults(output); + } catch (e) { + console.error('Error in runForwardPixels:', e); } } }; @@ -81,10 +160,41 @@ export default function ObjectDetectionScreen() { )} - + + {/* Custom bottom bar with two buttons */} + + + handleCameraPress(false)}> + 📷 Gallery + + + + + + Run (String) + + + + Run (Pixels) + + + ); } @@ -129,4 +239,43 @@ const styles = StyleSheet.create({ width: '100%', height: '100%', }, + bottomContainer: { + width: '100%', + gap: 15, + alignItems: 'center', + padding: 16, + flex: 1, + }, + bottomIconsContainer: { + flexDirection: 'row', + justifyContent: 'center', + width: '100%', + }, + iconText: { + fontSize: 16, + color: ColorPalette.primary, + }, + buttonsRow: { + flexDirection: 'row', + width: '100%', + gap: 10, + }, + button: { + height: 50, + justifyContent: 'center', + alignItems: 'center', + backgroundColor: ColorPalette.primary, + color: '#fff', + borderRadius: 8, + }, + halfButton: { + flex: 1, + }, + buttonDisabled: { + opacity: 0.5, + }, + buttonText: { + color: '#fff', + fontSize: 16, + }, }); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index c13b8991dc..e3c61afce9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -187,6 +187,10 @@ 
template class ModelHostObject : public JsiHostObject { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::stream>, "stream")); + } + + // Register generateFromFrame for all VisionModel subclasses + if constexpr (meta::DerivedFromOrSameAs) { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, "streamStop")); @@ -222,6 +226,14 @@ template class ModelHostObject : public JsiHostObject { promiseHostFunction<&Model::generateFromPixels>, "generateFromPixels")); } + + // Register generateFromPixels for models that support it + if constexpr (meta::HasGenerateFromPixels) { + addFunctions( + JSI_EXPORT_FUNCTION(ModelHostObject, + visionHostFunction<&Model::generateFromPixels>, + "generateFromPixels")); + } } // A generic host function that runs synchronously, works analogously to the diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index 2d7612f250..fdf8c9dba7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -12,8 +12,13 @@ template concept SameAs = std::is_same_v; template -concept HasGenerate = requires(T t) { - { &T::generate }; +concept HasGenerateFromString = requires(T t) { + { &T::generateFromString }; +}; + +template +concept HasGenerateFromPixels = requires(T t) { + { &T::generateFromPixels }; }; template diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 7e114e939d..9a1d6429bd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -27,4 +27,4 @@ class ImageEmbeddings final : public BaseEmbeddings { REGISTER_CONSTRUCTOR(models::embeddings::ImageEmbeddings, std::string, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index d39a7e5d4a..8ba422afba 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -59,4 +59,4 @@ REGISTER_CONSTRUCTOR(models::semantic_segmentation::BaseSemanticSegmentation, std::string, std::vector, std::vector, std::vector, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp new file mode 100644 index 0000000000..08f2a4683a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp @@ -0,0 +1,170 @@ +#include "ImageSegmentation.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch::models::image_segmentation { + +ImageSegmentation::ImageSegmentation( + const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + auto inputShapes = getAllInputShapes(); + if (inputShapes.size() == 0) { + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, + "Model seems to not take any 
input tensors."); + } + std::vector modelInputShape = inputShapes[0]; + if (modelInputShape.size() < 2) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unexpected model input size, expected at least 2 dimentions " + "but got: %zu.", + modelInputShape.size()); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, + errorMessage); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); + numModelPixels = modelImageSize.area(); +} + +std::shared_ptr ImageSegmentation::generate( + std::string imageSource, + std::set> classesOfInterest, bool resize) { + auto [inputTensor, originalSize] = + image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); + + auto forwardResult = BaseModel::forward(inputTensor); + if (!forwardResult.ok()) { + throw RnExecutorchError(forwardResult.error(), + "The model's forward function did not succeed. " + "Ensure the model input is correct."); + } + + return postprocess(forwardResult->at(0).toTensor(), originalSize, + classesOfInterest, resize); +} + +std::shared_ptr ImageSegmentation::postprocess( + const Tensor &tensor, cv::Size originalSize, + std::set> classesOfInterest, bool resize) { + + auto dataPtr = static_cast(tensor.const_data_ptr()); + auto resultData = std::span(dataPtr, tensor.numel()); + + // We copy the ET-owned data to jsi array buffers that can be directly + // returned to JS + std::vector> resultClasses; + resultClasses.reserve(numClasses); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + auto classBuffer = std::make_shared( + &resultData[cl * numModelPixels], numModelPixels * sizeof(float)); + resultClasses.push_back(classBuffer); + } + + // Apply softmax per each pixel across all classes + for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { + std::vector classValues(numClasses); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + classValues[cl] = + 
reinterpret_cast(resultClasses[cl]->data())[pixel]; + } + numerical::softmax(classValues); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + reinterpret_cast(resultClasses[cl]->data())[pixel] = + classValues[cl]; + } + } + + // Calculate the maximum class for each pixel + auto argmax = + std::make_shared(numModelPixels * sizeof(int32_t)); + for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { + float max = reinterpret_cast(resultClasses[0]->data())[pixel]; + int maxInd = 0; + for (int cl = 1; cl < numClasses; ++cl) { + if (reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { + maxInd = cl; + max = reinterpret_cast(resultClasses[cl]->data())[pixel]; + } + } + reinterpret_cast(argmax->data())[pixel] = maxInd; + } + + auto buffersToReturn = std::make_shared>>(); + for (std::size_t cl = 0; cl < numClasses; ++cl) { + if (classesOfInterest.contains(constants::kDeeplabV3Resnet50Labels[cl])) { + (*buffersToReturn)[constants::kDeeplabV3Resnet50Labels[cl]] = + resultClasses[cl]; + } + } + + // Resize selected classes and argmax + if (resize) { + cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data()); + cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, + cv::InterpolationFlags::INTER_NEAREST); + argmax = std::make_shared( + argmaxMat.data, originalSize.area() * sizeof(int32_t)); + + for (auto &[label, arrayBuffer] : *buffersToReturn) { + cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data()); + cv::resize(classMat, classMat, originalSize); + arrayBuffer = std::make_shared( + classMat.data, originalSize.area() * sizeof(float)); + } + } + return populateDictionary(argmax, buffersToReturn); +} + +std::shared_ptr ImageSegmentation::populateDictionary( + std::shared_ptr argmax, + std::shared_ptr>> + classesToOutput) { + // Synchronize the invoked thread to return when the dict is constructed + auto promisePtr = std::make_shared>(); + std::future doneFuture = promisePtr->get_future(); + + std::shared_ptr dictPtr = nullptr; + 
callInvoker->invokeAsync( + [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { + dictPtr = std::make_shared(runtime); + auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); + + auto int32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = + int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) + .getObject(runtime); + dictPtr->setProperty(runtime, "ARGMAX", int32Array); + + for (auto &[classLabel, owningBuffer] : *classesToOutput) { + auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); + + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) + .getObject(runtime); + + dictPtr->setProperty( + runtime, jsi::String::createFromAscii(runtime, classLabel.data()), + float32Array); + } + promisePtr->set_value(); + }); + + doneFuture.wait(); + return dictPtr; +} + +} // namespace rnexecutorch::models::image_segmentation \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 73744c4d82..8eed3c888d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -33,4 +33,4 @@ class StyleTransfer : public BaseModel { REGISTER_CONSTRUCTOR(models::style_transfer::StyleTransfer, std::string, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index cc1fc1ef2e..a564492e6b 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ 
b/packages/react-native-executorch/src/hooks/useModule.ts @@ -8,6 +8,7 @@ interface Module { load: (...args: any[]) => Promise; forward: (...args: any[]) => Promise; delete: () => void; + nativeModule?: any; // JSI host object with native methods } interface ModuleConstructor { diff --git a/packages/react-native-executorch/src/types/objectDetection.ts b/packages/react-native-executorch/src/types/objectDetection.ts index aa25e9c412..d8afc3354b 100644 --- a/packages/react-native-executorch/src/types/objectDetection.ts +++ b/packages/react-native-executorch/src/types/objectDetection.ts @@ -110,6 +110,20 @@ export interface ObjectDetectionType { * @param detectionThreshold - An optional number between 0 and 1 representing the minimum confidence score. Default is 0.7. * @returns A Promise that resolves to an array of `Detection` objects. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. + * + * @example + * ```typescript + * // String path + * const detections1 = await model.forward('file:///path/to/image.jpg'); + * + * // Pixel data + * const detections2 = await model.forward({ + * data: pixelBuffer, + * width: 640, + * height: 480, + * channels: 3 + * }); + * ``` */ forward: ( input: string | PixelData, From cee17f78512f874fd2c8a61620bf67b06e5f2cc7 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 17 Feb 2026 17:51:10 +0100 Subject: [PATCH 05/71] refactor: errors, logs, unnecessary comments, use existing TensorPtr --- .../app/object_detection/index.tsx | 61 ++++++++++--------- .../host_objects/JsiConversions.h | 19 ++++++ .../host_objects/ModelHostObject.h | 2 - .../src/hooks/useModule.ts | 1 - 4 files changed, 52 insertions(+), 31 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 0dda13e9d9..fb18a81282 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ 
b/apps/computer-vision/app/object_detection/index.tsx @@ -13,7 +13,26 @@ import ScreenWrapper from '../../ScreenWrapper'; import ColorPalette from '../../colors'; import { Images } from 'react-native-nitro-image'; -// Helper function to convert image URI to raw pixel data using NitroImage +// Helper function to convert BGRA to RGB +function convertBGRAtoRGB( + buffer: ArrayBuffer, + width: number, + height: number +): ArrayBuffer { + const source = new Uint8Array(buffer); + const rgb = new Uint8Array(width * height * 3); + + for (let i = 0; i < width * height; i++) { + // BGRA format: [B, G, R, A] → RGB: [R, G, B] + rgb[i * 3 + 0] = source[i * 4 + 2]; // R + rgb[i * 3 + 1] = source[i * 4 + 1]; // G + rgb[i * 3 + 2] = source[i * 4 + 0]; // B + } + + return rgb.buffer; +} + +// Helper function to convert image URI to raw RGB pixel data async function imageUriToPixelData( uri: string, targetWidth: number, @@ -29,32 +48,19 @@ async function imageUriToPixelData( const image = await Images.loadFromFileAsync(uri); const resized = image.resize(targetWidth, targetHeight); - // Get pixel data as ArrayBuffer (RGBA format) - const pixelData = resized.toRawPixelData(); + // Get pixel data as ArrayBuffer (BGRA format from NitroImage) + const rawPixelData = resized.toRawPixelData(); const buffer = - pixelData instanceof ArrayBuffer ? pixelData : pixelData.buffer; - - // Calculate actual buffer dimensions (accounts for device pixel ratio) - const bufferSize = buffer?.byteLength || 0; - const totalPixels = bufferSize / 4; // RGBA = 4 bytes per pixel - const aspectRatio = targetWidth / targetHeight; - const actualHeight = Math.sqrt(totalPixels / aspectRatio); - const actualWidth = totalPixels / actualHeight; + rawPixelData instanceof ArrayBuffer ? 
rawPixelData : rawPixelData.buffer; - console.log('Requested:', targetWidth, 'x', targetHeight); - console.log('Buffer size:', bufferSize); - console.log( - 'Actual dimensions:', - Math.round(actualWidth), - 'x', - Math.round(actualHeight) - ); + // Convert BGRA to RGB as required by the native API + const rgbBuffer = convertBGRAtoRGB(buffer, targetWidth, targetHeight); return { - data: buffer, - width: Math.round(actualWidth), - height: Math.round(actualHeight), - channels: 4, // RGBA + data: rgbBuffer, + width: targetWidth, + height: targetHeight, + channels: 3, // RGB }; } catch (error) { console.error('Error loading image with NitroImage:', error); @@ -104,12 +110,11 @@ export default function ObjectDetectionScreen() { if (imageUri && imageDimensions) { try { console.log('Converting image to pixel data...'); - // Resize to 640x640 to avoid memory issues - const intermediateSize = 640; + // Use original dimensions - let the model resize internally const pixelData = await imageUriToPixelData( imageUri, - intermediateSize, - intermediateSize + imageDimensions.width, + imageDimensions.height ); console.log('Running forward with pixel data...', { @@ -120,7 +125,7 @@ export default function ObjectDetectionScreen() { }); // Run inference using unified forward() API - const output = await ssdLite.forward(pixelData, 0.5); + const output = await ssdLite.forward(pixelData, 0.3); console.log('Pixel data result:', output.length, 'detections'); setResults(output); } catch (e) { diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 96e3168ee7..8936711477 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -368,6 +368,25 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, 
bigInt}; } +inline jsi::Value getJsiValue(const std::vector &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + // JS numbers are doubles. Large uint64s > 2^53 will lose precision. + array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); + } + return {runtime, array}; +} + +inline jsi::Value getJsiValue(const std::vector &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); + } + return {runtime, array}; +} + inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { return {runtime, val}; } diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index e3c61afce9..35789847dc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -189,7 +189,6 @@ template class ModelHostObject : public JsiHostObject { "stream")); } - // Register generateFromFrame for all VisionModel subclasses if constexpr (meta::DerivedFromOrSameAs) { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, @@ -227,7 +226,6 @@ template class ModelHostObject : public JsiHostObject { "generateFromPixels")); } - // Register generateFromPixels for models that support it if constexpr (meta::HasGenerateFromPixels) { addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index a564492e6b..cc1fc1ef2e 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ b/packages/react-native-executorch/src/hooks/useModule.ts @@ -8,7 +8,6 @@ interface Module { load: 
(...args: any[]) => Promise; forward: (...args: any[]) => Promise; delete: () => void; - nativeModule?: any; // JSI host object with native methods } interface ModuleConstructor { From b0e61f35376e6898049652ec9e3e3ef104d372db Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 18 Feb 2026 12:49:15 +0100 Subject: [PATCH 06/71] feat: use TensorPtrish type for Pixel data input --- .../app/object_detection/index.tsx | 117 ++++++------------ .../src/types/objectDetection.ts | 7 +- 2 files changed, 41 insertions(+), 83 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index fb18a81282..e601e9cb11 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -11,62 +11,6 @@ import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; import ColorPalette from '../../colors'; -import { Images } from 'react-native-nitro-image'; - -// Helper function to convert BGRA to RGB -function convertBGRAtoRGB( - buffer: ArrayBuffer, - width: number, - height: number -): ArrayBuffer { - const source = new Uint8Array(buffer); - const rgb = new Uint8Array(width * height * 3); - - for (let i = 0; i < width * height; i++) { - // BGRA format: [B, G, R, A] → RGB: [R, G, B] - rgb[i * 3 + 0] = source[i * 4 + 2]; // R - rgb[i * 3 + 1] = source[i * 4 + 1]; // G - rgb[i * 3 + 2] = source[i * 4 + 0]; // B - } - - return rgb.buffer; -} - -// Helper function to convert image URI to raw RGB pixel data -async function imageUriToPixelData( - uri: string, - targetWidth: number, - targetHeight: number -): Promise<{ - data: ArrayBuffer; - width: number; - height: number; - channels: number; -}> { - try { - // Load image and resize to target dimensions - const image = await Images.loadFromFileAsync(uri); - const resized = image.resize(targetWidth, 
targetHeight); - - // Get pixel data as ArrayBuffer (BGRA format from NitroImage) - const rawPixelData = resized.toRawPixelData(); - const buffer = - rawPixelData instanceof ArrayBuffer ? rawPixelData : rawPixelData.buffer; - - // Convert BGRA to RGB as required by the native API - const rgbBuffer = convertBGRAtoRGB(buffer, targetWidth, targetHeight); - - return { - data: rgbBuffer, - width: targetWidth, - height: targetHeight, - channels: 3, // RGB - }; - } catch (error) { - console.error('Error loading image with NitroImage:', error); - throw error; - } -} export default function ObjectDetectionScreen() { const [imageUri, setImageUri] = useState(''); @@ -107,30 +51,45 @@ export default function ObjectDetectionScreen() { }; const runForwardPixels = async () => { - if (imageUri && imageDimensions) { - try { - console.log('Converting image to pixel data...'); - // Use original dimensions - let the model resize internally - const pixelData = await imageUriToPixelData( - imageUri, - imageDimensions.width, - imageDimensions.height - ); - - console.log('Running forward with pixel data...', { - width: pixelData.width, - height: pixelData.height, - channels: pixelData.channels, - dataSize: pixelData.data.byteLength, - }); - - // Run inference using unified forward() API - const output = await ssdLite.forward(pixelData, 0.3); - console.log('Pixel data result:', output.length, 'detections'); - setResults(output); - } catch (e) { - console.error('Error in runForwardPixels:', e); + try { + console.log('Testing with hardcoded pixel data...'); + + // Create a simple 320x320 test image (all zeros - black image) + // In a real scenario, you would load actual image pixel data here + const width = 320; + const height = 320; + const channels = 3; // RGB + + // Create a black image (you can replace this with actual pixel data) + const rgbData = new Uint8Array(width * height * channels); + + // Optionally, add some test pattern (e.g., white square in center) + for (let y = 100; y < 
220; y++) { + for (let x = 100; x < 220; x++) { + const idx = (y * width + x) * 3; + rgbData[idx + 0] = 255; // R + rgbData[idx + 1] = 255; // G + rgbData[idx + 2] = 255; // B + } } + + const pixelData: PixelData = { + dataPtr: rgbData, + sizes: [height, width, channels], + scalarType: ScalarType.BYTE, + }; + + console.log('Running forward with hardcoded pixel data...', { + sizes: pixelData.sizes, + dataSize: pixelData.dataPtr.byteLength, + }); + + // Run inference using unified forward() API + const output = await ssdLite.forward(pixelData, 0.3); + console.log('Pixel data result:', output.length, 'detections'); + setResults(output); + } catch (e) { + console.error('Error in runForwardPixels:', e); } }; diff --git a/packages/react-native-executorch/src/types/objectDetection.ts b/packages/react-native-executorch/src/types/objectDetection.ts index d8afc3354b..38dc4bd12d 100644 --- a/packages/react-native-executorch/src/types/objectDetection.ts +++ b/packages/react-native-executorch/src/types/objectDetection.ts @@ -118,10 +118,9 @@ export interface ObjectDetectionType { * * // Pixel data * const detections2 = await model.forward({ - * data: pixelBuffer, - * width: 640, - * height: 480, - * channels: 3 + * dataPtr: new Uint8Array(rgbPixels), + * sizes: [480, 640, 3], + * scalarType: ScalarType.BYTE * }); * ``` */ From 4beb708c5c4f7d34af0e6cc617c903f736d79aa6 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 18 Feb 2026 13:03:22 +0100 Subject: [PATCH 07/71] refactor: add or remove empty lines --- .../rnexecutorch/models/classification/Classification.cpp | 2 +- .../rnexecutorch/models/embeddings/image/ImageEmbeddings.h | 2 +- .../models/semantic_segmentation/BaseSemanticSegmentation.h | 2 +- .../models/semantic_segmentation/ImageSegmentation.cpp | 2 +- .../common/rnexecutorch/models/style_transfer/StyleTransfer.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index b9fad1b88b..0fba071087 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -73,4 +73,4 @@ Classification::postprocess(const Tensor &tensor) { return probs; } -} // namespace rnexecutorch::models::classification \ No newline at end of file +} // namespace rnexecutorch::models::classification diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 9a1d6429bd..7e114e939d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -27,4 +27,4 @@ class ImageEmbeddings final : public BaseEmbeddings { REGISTER_CONSTRUCTOR(models::embeddings::ImageEmbeddings, std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index 8ba422afba..d39a7e5d4a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -59,4 +59,4 @@ REGISTER_CONSTRUCTOR(models::semantic_segmentation::BaseSemanticSegmentation, std::string, std::vector, std::vector, std::vector, 
std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp index 08f2a4683a..a2c1ae865b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp @@ -167,4 +167,4 @@ std::shared_ptr ImageSegmentation::populateDictionary( return dictPtr; } -} // namespace rnexecutorch::models::image_segmentation \ No newline at end of file +} // namespace rnexecutorch::models::image_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 8eed3c888d..73744c4d82 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -33,4 +33,4 @@ class StyleTransfer : public BaseModel { REGISTER_CONSTRUCTOR(models::style_transfer::StyleTransfer, std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch From 899959d591dc14f42b653fb2a355b6c36660fa29 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 19 Feb 2026 22:34:20 +0100 Subject: [PATCH 08/71] fix: errors after rebase --- .../host_objects/JsiConversions.h | 10 -- .../host_objects/ModelHostObject.h | 2 +- .../metaprogramming/FunctionHelpers.h | 31 +++- yarn.lock | 136 ++++++++++++++++++ 4 files changed, 167 insertions(+), 12 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h 
b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 8936711477..77f1c51adb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -368,16 +368,6 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, bigInt}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - // JS numbers are doubles. Large uint64s > 2^53 will lose precision. - array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); - } - return {runtime, array}; -} - inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 35789847dc..f80c719bf1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -189,7 +189,7 @@ template class ModelHostObject : public JsiHostObject { "stream")); } - if constexpr (meta::DerivedFromOrSameAs) { + if constexpr (meta::HasGenerateFromFrame) { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, "streamStop")); diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h index fde81e046d..ccce1cb5fd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h @@ -10,6 +10,32 
@@ namespace rnexecutorch::meta { using namespace facebook; +// ========================================================================= +// 1. Function Traits (Extracts Arity, Return Type, Args) +// ========================================================================= + +template struct FunctionTraits; + +// Specialization for Member Functions +template +struct FunctionTraits { + static constexpr std::size_t arity = sizeof...(Args); + using return_type = R; + using args_tuple = std::tuple; +}; + +// Specialization for const Member Functions +template +struct FunctionTraits { + static constexpr std::size_t arity = sizeof...(Args); + using return_type = R; + using args_tuple = std::tuple; +}; + +// ========================================================================= +// 2. Argument Counting Helpers +// ========================================================================= + template constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) { return sizeof...(Types); @@ -20,6 +46,10 @@ constexpr std::size_t getArgumentCount(R (Model::*f)(Types...) const) { return sizeof...(Types); } +// ========================================================================= +// 3. JSI -> Tuple Conversion Logic +// ========================================================================= + template std::tuple fillTupleFromArgs(std::index_sequence, const jsi::Value *args, @@ -33,7 +63,6 @@ std::tuple fillTupleFromArgs(std::index_sequence, * arguments for method supplied with a pointer. The types in the tuple are * inferred from the method pointer. 
*/ - template std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...), const jsi::Value *args, diff --git a/yarn.lock b/yarn.lock index 12cb5c31b7..e6ddb22294 100644 --- a/yarn.lock +++ b/yarn.lock @@ -110,6 +110,19 @@ __metadata: languageName: node linkType: hard +"@babel/generator@npm:^7.29.0": + version: 7.29.1 + resolution: "@babel/generator@npm:7.29.1" + dependencies: + "@babel/parser": "npm:^7.29.0" + "@babel/types": "npm:^7.29.0" + "@jridgewell/gen-mapping": "npm:^0.3.12" + "@jridgewell/trace-mapping": "npm:^0.3.28" + jsesc: "npm:^3.0.2" + checksum: 10/61fe4ddd6e817aa312a14963ccdbb5c9a8c57e8b97b98d19a8a99ccab2215fda1a5f52bc8dd8d2e3c064497ddeb3ab8ceb55c76fa0f58f8169c34679d2256fe0 + languageName: node + linkType: hard + "@babel/helper-annotate-as-pure@npm:^7.27.1, @babel/helper-annotate-as-pure@npm:^7.27.3": version: 7.27.3 resolution: "@babel/helper-annotate-as-pure@npm:7.27.3" @@ -242,6 +255,13 @@ __metadata: languageName: node linkType: hard +"@babel/helper-plugin-utils@npm:^7.28.6": + version: 7.28.6 + resolution: "@babel/helper-plugin-utils@npm:7.28.6" + checksum: 10/21c853bbc13dbdddf03309c9a0477270124ad48989e1ad6524b83e83a77524b333f92edd2caae645c5a7ecf264ec6d04a9ebe15aeb54c7f33c037b71ec521e4a + languageName: node + linkType: hard + "@babel/helper-remap-async-to-generator@npm:^7.18.9, @babel/helper-remap-async-to-generator@npm:^7.27.1": version: 7.27.1 resolution: "@babel/helper-remap-async-to-generator@npm:7.27.1" @@ -268,6 +288,19 @@ __metadata: languageName: node linkType: hard +"@babel/helper-replace-supers@npm:^7.28.6": + version: 7.28.6 + resolution: "@babel/helper-replace-supers@npm:7.28.6" + dependencies: + "@babel/helper-member-expression-to-functions": "npm:^7.28.5" + "@babel/helper-optimise-call-expression": "npm:^7.27.1" + "@babel/traverse": "npm:^7.28.6" + peerDependencies: + "@babel/core": ^7.0.0 + checksum: 10/ad2724713a4d983208f509e9607e8f950855f11bd97518a700057eb8bec69d687a8f90dc2da0c3c47281d2e3b79cf1d14ecf1fe3e1ee0a8e90b61aee6759c9a7 + 
languageName: node + linkType: hard + "@babel/helper-skip-transparent-expression-wrappers@npm:^7.20.0, @babel/helper-skip-transparent-expression-wrappers@npm:^7.27.1": version: 7.27.1 resolution: "@babel/helper-skip-transparent-expression-wrappers@npm:7.27.1" @@ -343,6 +376,17 @@ __metadata: languageName: node linkType: hard +"@babel/parser@npm:^7.28.6, @babel/parser@npm:^7.29.0": + version: 7.29.0 + resolution: "@babel/parser@npm:7.29.0" + dependencies: + "@babel/types": "npm:^7.29.0" + bin: + parser: ./bin/babel-parser.js + checksum: 10/b1576dca41074997a33ee740d87b330ae2e647f4b7da9e8d2abd3772b18385d303b0cee962b9b88425e0f30d58358dbb8d63792c1a2d005c823d335f6a029747 + languageName: node + linkType: hard + "@babel/plugin-bugfix-firefox-class-in-computed-class-key@npm:^7.28.5": version: 7.28.5 resolution: "@babel/plugin-bugfix-firefox-class-in-computed-class-key@npm:7.28.5" @@ -767,6 +811,17 @@ __metadata: languageName: node linkType: hard +"@babel/plugin-syntax-typescript@npm:^7.28.6": + version: 7.28.6 + resolution: "@babel/plugin-syntax-typescript@npm:7.28.6" + dependencies: + "@babel/helper-plugin-utils": "npm:^7.28.6" + peerDependencies: + "@babel/core": ^7.0.0-0 + checksum: 10/5c55f9c63bd36cf3d7e8db892294c8f85000f9c1526c3a1cc310d47d1e174f5c6f6605e5cc902c4636d885faba7a9f3d5e5edc6b35e4f3b1fd4c2d58d0304fa5 + languageName: node + linkType: hard + "@babel/plugin-syntax-unicode-sets-regex@npm:^7.18.6": version: 7.18.6 resolution: "@babel/plugin-syntax-unicode-sets-regex@npm:7.18.6" @@ -1509,6 +1564,21 @@ __metadata: languageName: node linkType: hard +"@babel/plugin-transform-typescript@npm:^7.27.1": + version: 7.28.6 + resolution: "@babel/plugin-transform-typescript@npm:7.28.6" + dependencies: + "@babel/helper-annotate-as-pure": "npm:^7.27.3" + "@babel/helper-create-class-features-plugin": "npm:^7.28.6" + "@babel/helper-plugin-utils": "npm:^7.28.6" + "@babel/helper-skip-transparent-expression-wrappers": "npm:^7.27.1" + "@babel/plugin-syntax-typescript": "npm:^7.28.6" 
+ peerDependencies: + "@babel/core": ^7.0.0-0 + checksum: 10/a0bccc531fa8710a45b0b593140273741e0e4a0721b1ef6ef9dfefae0bbe61528440d65aab7936929551fd76793272257d74f60cf66891352f793294930a4b67 + languageName: node + linkType: hard + "@babel/plugin-transform-unicode-escapes@npm:^7.27.1": version: 7.27.1 resolution: "@babel/plugin-transform-unicode-escapes@npm:7.27.1" @@ -1738,6 +1808,16 @@ __metadata: languageName: node linkType: hard +"@babel/types@npm:^7.28.6, @babel/types@npm:^7.29.0": + version: 7.29.0 + resolution: "@babel/types@npm:7.29.0" + dependencies: + "@babel/helper-string-parser": "npm:^7.27.1" + "@babel/helper-validator-identifier": "npm:^7.28.5" + checksum: 10/bfc2b211210f3894dcd7e6a33b2d1c32c93495dc1e36b547376aa33441abe551ab4bc1640d4154ee2acd8e46d3bbc925c7224caae02fcaf0e6a771e97fccc661 + languageName: node + linkType: hard + "@bcoe/v8-coverage@npm:^0.2.3": version: 0.2.3 resolution: "@bcoe/v8-coverage@npm:0.2.3" @@ -5910,6 +5990,18 @@ __metadata: languageName: node linkType: hard +"ajv@npm:^8.11.0": + version: 8.18.0 + resolution: "ajv@npm:8.18.0" + dependencies: + fast-deep-equal: "npm:^3.1.3" + fast-uri: "npm:^3.0.1" + json-schema-traverse: "npm:^1.0.0" + require-from-string: "npm:^2.0.2" + checksum: 10/bfed9de827a2b27c6d4084324eda76a4e32bdde27410b3e9b81d06e6f8f5c78370fc6b93fe1d869f1939ff1d7c4ae8896960995acb8425e3e9288c8884247c48 + languageName: node + linkType: hard + "anser@npm:^1.4.9": version: 1.4.10 resolution: "anser@npm:1.4.10" @@ -14496,6 +14588,27 @@ __metadata: languageName: node linkType: hard +"react-native-nitro-image@npm:0.10.2": + version: 0.10.2 + resolution: "react-native-nitro-image@npm:0.10.2" + peerDependencies: + react: "*" + react-native: "*" + react-native-nitro-modules: "*" + checksum: 10/3be75e93da369adfe00441dae78171572dec38d3d7e75e5d4cb302b81479be9686c8d8dc0ea4b331514b8725099bf3eb069ab9933f7029627d12a72d71766cb4 + languageName: node + linkType: hard + +"react-native-nitro-modules@npm:0.33.4": + version: 0.33.4 + resolution: 
"react-native-nitro-modules@npm:0.33.4" + peerDependencies: + react: "*" + react-native: "*" + checksum: 10/a737ff6b142c55821688612305245fd10a7cff36f0ee66cad0956c6815a60cdd4ba64cdfba6137a6dbfe815645763ce5d406cf488876edd47dab7f8d0031e01a + languageName: node + linkType: hard + "react-native-reanimated@npm:~4.1.1": version: 4.1.6 resolution: "react-native-reanimated@npm:4.1.6" @@ -14624,6 +14737,29 @@ __metadata: languageName: node linkType: hard +"react-native-worklets@npm:^0.7.2": + version: 0.7.4 + resolution: "react-native-worklets@npm:0.7.4" + dependencies: + "@babel/plugin-transform-arrow-functions": "npm:7.27.1" + "@babel/plugin-transform-class-properties": "npm:7.27.1" + "@babel/plugin-transform-classes": "npm:7.28.4" + "@babel/plugin-transform-nullish-coalescing-operator": "npm:7.27.1" + "@babel/plugin-transform-optional-chaining": "npm:7.27.1" + "@babel/plugin-transform-shorthand-properties": "npm:7.27.1" + "@babel/plugin-transform-template-literals": "npm:7.27.1" + "@babel/plugin-transform-unicode-regex": "npm:7.27.1" + "@babel/preset-typescript": "npm:7.27.1" + convert-source-map: "npm:2.0.0" + semver: "npm:7.7.3" + peerDependencies: + "@babel/core": "*" + react: "*" + react-native: "*" + checksum: 10/922b209940e298d21313d22f8a6eb87ad603442850c7ff8bc9cfef694cb211d7ec9903e24ee20b6bcf6164f8e7c165b65307dcca3d67465fdffda1c45fe05d1d + languageName: node + linkType: hard + "react-native@npm:0.81.5": version: 0.81.5 resolution: "react-native@npm:0.81.5" From 5b01a96147423167833ca40360daa4ce6182811f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 23 Feb 2026 11:46:18 +0100 Subject: [PATCH 09/71] refactor: changes suggested in review --- .../models/object_detection/ObjectDetection.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 
2293a4abc8..f926f49b78 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -80,26 +80,19 @@ cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { } cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { - // Get target size from model input shape const std::vector tensorDims = getAllInputShapes()[0]; cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], tensorDims[tensorDims.size() - 2]); cv::Mat rgb; - // Convert RGBA/BGRA to RGB if needed (for VisionCamera frames) if (frame.channels() == 4) { -// Platform-specific color conversion: -// iOS uses BGRA format, Android uses RGBA format #ifdef __APPLE__ - // iOS: BGRA → RGB cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); #else - // Android: RGBA → RGB cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); #endif } else if (frame.channels() == 3) { - // Already RGB rgb = frame; } else { char errorMessage[100]; From 1368255618a2724a24a8bb1fc854f13b9d4bfb85 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 23 Feb 2026 13:00:39 +0100 Subject: [PATCH 10/71] fix: not existing error type, add comments to JSI code --- .../host_objects/ModelHostObject.h | 7 +++++ .../metaprogramming/FunctionHelpers.h | 31 +------------------ 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index f80c719bf1..4a92a415c1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -189,6 +189,13 @@ template class ModelHostObject : public JsiHostObject { "stream")); } + if constexpr (meta::HasGenerateFromString) { + addFunctions( + 
JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generateFromString>, + "generateFromString")); + } + if constexpr (meta::HasGenerateFromFrame) { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h index ccce1cb5fd..fde81e046d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h @@ -10,32 +10,6 @@ namespace rnexecutorch::meta { using namespace facebook; -// ========================================================================= -// 1. Function Traits (Extracts Arity, Return Type, Args) -// ========================================================================= - -template struct FunctionTraits; - -// Specialization for Member Functions -template -struct FunctionTraits { - static constexpr std::size_t arity = sizeof...(Args); - using return_type = R; - using args_tuple = std::tuple; -}; - -// Specialization for const Member Functions -template -struct FunctionTraits { - static constexpr std::size_t arity = sizeof...(Args); - using return_type = R; - using args_tuple = std::tuple; -}; - -// ========================================================================= -// 2. Argument Counting Helpers -// ========================================================================= - template constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) { return sizeof...(Types); @@ -46,10 +20,6 @@ constexpr std::size_t getArgumentCount(R (Model::*f)(Types...) const) { return sizeof...(Types); } -// ========================================================================= -// 3. 
JSI -> Tuple Conversion Logic -// ========================================================================= - template std::tuple fillTupleFromArgs(std::index_sequence, const jsi::Value *args, @@ -63,6 +33,7 @@ std::tuple fillTupleFromArgs(std::index_sequence, * arguments for method supplied with a pointer. The types in the tuple are * inferred from the method pointer. */ + template std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...), const jsi::Value *args, From ee76a4443fa33202687f5acaf542154163d5dc56 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 23 Feb 2026 18:50:24 +0100 Subject: [PATCH 11/71] feat: add tests for generateFromPixels method --- .../app/object_detection/index.tsx | 84 ++----------------- .../computer_vision/ObjectDetectionModule.ts | 7 ++ 2 files changed, 14 insertions(+), 77 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index e601e9cb11..a5e36c344a 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -1,16 +1,16 @@ import Spinner from '../../components/Spinner'; +import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; import { Detection, useObjectDetection, RF_DETR_NANO, } from 'react-native-executorch'; -import { View, StyleSheet, Image, TouchableOpacity, Text } from 'react-native'; +import { View, StyleSheet, Image } from 'react-native'; import ImageWithBboxes from '../../components/ImageWithBboxes'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; -import ColorPalette from '../../colors'; export default function ObjectDetectionScreen() { const [imageUri, setImageUri] = useState(''); @@ -45,7 +45,7 @@ export default function ObjectDetectionScreen() { const output = await rfDetr.forward(imageUri); 
setResults(output); } catch (e) { - console.error('Error in runForward:', e); + console.error(e); } } }; @@ -124,41 +124,10 @@ export default function ObjectDetectionScreen() { )} - - {/* Custom bottom bar with two buttons */} - - - handleCameraPress(false)}> - 📷 Gallery - - - - - - Run (String) - - - - Run (Pixels) - - - + ); } @@ -203,43 +172,4 @@ const styles = StyleSheet.create({ width: '100%', height: '100%', }, - bottomContainer: { - width: '100%', - gap: 15, - alignItems: 'center', - padding: 16, - flex: 1, - }, - bottomIconsContainer: { - flexDirection: 'row', - justifyContent: 'center', - width: '100%', - }, - iconText: { - fontSize: 16, - color: ColorPalette.primary, - }, - buttonsRow: { - flexDirection: 'row', - width: '100%', - gap: 10, - }, - button: { - height: 50, - justifyContent: 'center', - alignItems: 'center', - backgroundColor: ColorPalette.primary, - color: '#fff', - borderRadius: 8, - }, - halfButton: { - flex: 1, - }, - buttonDisabled: { - opacity: 0.5, - }, - buttonText: { - color: '#fff', - fontSize: 16, - }, }); diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index c24bbd1369..bbb990f7b8 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -169,4 +169,11 @@ export class ObjectDetectionModule< nativeModule ); } + + async forward( + input: string | PixelData, + detectionThreshold: number = 0.5 + ): Promise { + return super.forward(input, detectionThreshold); + } } From f369b9a40715a50ce2bb1b91487f78c07dd9eb8f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 23 Feb 2026 19:32:05 +0100 Subject: [PATCH 12/71] feat: add example screen with vision camera to computer vision app --- ...ative-vision-camera@npm-5.0.0-beta.1.patch | 713 ++++++++++++++++++ 
yarn.lock | 16 +- 2 files changed, 721 insertions(+), 8 deletions(-) create mode 100644 .yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch diff --git a/.yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch b/.yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch new file mode 100644 index 0000000000..73f999e9a6 --- /dev/null +++ b/.yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch @@ -0,0 +1,713 @@ +diff --git a/lib/expo-plugin/withVisionCamera.js b/lib/expo-plugin/withVisionCamera.js +index 32418a9..f7a8c5c 100644 +--- a/lib/expo-plugin/withVisionCamera.js ++++ b/lib/expo-plugin/withVisionCamera.js +@@ -1,4 +1,4 @@ +-import { AndroidConfig, withPlugins, } from '@expo/config-plugins'; ++const { AndroidConfig, withPlugins } = require('@expo/config-plugins'); + const CAMERA_USAGE = 'Allow $(PRODUCT_NAME) to access your camera'; + const MICROPHONE_USAGE = 'Allow $(PRODUCT_NAME) to access your microphone'; + const withVisionCamera = (config, props = {}) => { +@@ -30,4 +30,4 @@ const withVisionCamera = (config, props = {}) => { + [AndroidConfig.Permissions.withPermissions, androidPermissions], + ]); + }; +-export default withVisionCamera; ++module.exports = withVisionCamera; +diff --git a/cpp/Frame Processors/HybridWorkletQueueFactory.cpp b/cpp/Frame Processors/HybridWorkletQueueFactory.cpp +new file mode 100644 +index 0000000..5da4ef9 +--- /dev/null ++++ b/cpp/Frame Processors/HybridWorkletQueueFactory.cpp +@@ -0,0 +1,50 @@ ++/// ++/// HybridWorkletQueueFactory.cpp ++/// VisionCamera ++/// Copyright © 2025 Marc Rousavy @ Margelo ++/// ++ ++#include "HybridWorkletQueueFactory.hpp" ++ ++#include "JSIConverter+AsyncQueue.hpp" ++#include "NativeThreadAsyncQueue.hpp" ++#include "NativeThreadDispatcher.hpp" ++#include ++#include ++ ++namespace margelo::nitro::camera { ++ ++HybridWorkletQueueFactory::HybridWorkletQueueFactory() : HybridObject(TAG) {} ++ ++void HybridWorkletQueueFactory::loadHybridMethods() { ++ 
HybridWorkletQueueFactorySpec::loadHybridMethods(); ++ registerHybrids(this, [](Prototype& prototype) { ++ prototype.registerRawHybridMethod("installDispatcher", 1, &HybridWorkletQueueFactory::installDispatcher); ++ }); ++} ++ ++std::shared_ptr HybridWorkletQueueFactory::wrapThreadInQueue(const std::shared_ptr& thread) { ++ return std::make_shared(thread); ++} ++ ++double HybridWorkletQueueFactory::getCurrentThreadMarker() { ++ static std::atomic_size_t threadCounter{1}; ++ static thread_local size_t thisThreadId{0}; ++ if (thisThreadId == 0) { ++ thisThreadId = threadCounter.fetch_add(1); ++ } ++ return static_cast(thisThreadId); ++} ++ ++jsi::Value HybridWorkletQueueFactory::installDispatcher(jsi::Runtime& runtime, const jsi::Value&, const jsi::Value* args, size_t count) { ++ if (count != 1) ++ throw std::runtime_error("installDispatcher(..) must be called with exactly 1 argument!"); ++ auto thread = JSIConverter>::fromJSI(runtime, args[0]); ++ ++ auto dispatcher = std::make_shared(thread); ++ Dispatcher::installRuntimeGlobalDispatcher(runtime, dispatcher); ++ ++ return jsi::Value::undefined(); ++} ++ ++} // namespace margelo::nitro::camera +diff --git a/android/CMakeLists.txt b/android/CMakeLists.txt +index 0000000..1111111 100644 +--- a/android/CMakeLists.txt ++++ b/android/CMakeLists.txt +@@ -20,6 +20,7 @@ + "src/main/cpp" + "../cpp" + "../cpp/Frame Processors" ++ "../nitrogen/generated/shared/c++" + ) + + find_library(LOG_LIB log) +diff --git a/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.cpp b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.cpp +new file mode 100644 +index 0000000..5da4ef9 +--- /dev/null ++++ b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.cpp +@@ -0,0 +1,50 @@ ++/// ++/// HybridWorkletQueueFactory.cpp ++/// VisionCamera ++/// Copyright © 2025 Marc Rousavy @ Margelo ++/// ++ ++#include "HybridWorkletQueueFactory.hpp" ++ ++#include "JSIConverter+AsyncQueue.hpp" ++#include "NativeThreadAsyncQueue.hpp" ++#include 
"NativeThreadDispatcher.hpp" ++#include ++#include ++ ++namespace margelo::nitro::camera { ++ ++HybridWorkletQueueFactory::HybridWorkletQueueFactory() : HybridObject(TAG) {} ++ ++void HybridWorkletQueueFactory::loadHybridMethods() { ++ HybridWorkletQueueFactorySpec::loadHybridMethods(); ++ registerHybrids(this, [](Prototype& prototype) { ++ prototype.registerRawHybridMethod("installDispatcher", 1, &HybridWorkletQueueFactory::installDispatcher); ++ }); ++} ++ ++std::shared_ptr HybridWorkletQueueFactory::wrapThreadInQueue(const std::shared_ptr& thread) { ++ return std::make_shared(thread); ++} ++ ++double HybridWorkletQueueFactory::getCurrentThreadMarker() { ++ static std::atomic_size_t threadCounter{1}; ++ static thread_local size_t thisThreadId{0}; ++ if (thisThreadId == 0) { ++ thisThreadId = threadCounter.fetch_add(1); ++ } ++ return static_cast(thisThreadId); ++} ++ ++jsi::Value HybridWorkletQueueFactory::installDispatcher(jsi::Runtime& runtime, const jsi::Value&, const jsi::Value* args, size_t count) { ++ if (count != 1) ++ throw std::runtime_error("installDispatcher(..) 
must be called with exactly 1 argument!"); ++ auto thread = JSIConverter>::fromJSI(runtime, args[0]); ++ ++ auto dispatcher = std::make_shared(thread); ++ Dispatcher::installRuntimeGlobalDispatcher(runtime, dispatcher); ++ ++ return jsi::Value::undefined(); ++} ++ ++} // namespace margelo::nitro::camera +diff --git a/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.hpp b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.hpp +new file mode 100644 +index 0000000..daa16d2 +--- /dev/null ++++ b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.hpp +@@ -0,0 +1,29 @@ ++/// ++/// HybridWorkletQueueFactory.hpp ++/// VisionCamera ++/// Copyright © 2025 Marc Rousavy @ Margelo ++/// ++ ++#pragma once ++ ++#include "HybridWorkletQueueFactorySpec.hpp" ++#include "JSIConverter+AsyncQueue.hpp" ++#include ++#include ++ ++namespace margelo::nitro::camera { ++ ++class HybridWorkletQueueFactory : public HybridWorkletQueueFactorySpec { ++public: ++ HybridWorkletQueueFactory(); ++ ++public: ++ std::shared_ptr wrapThreadInQueue(const std::shared_ptr& thread) override; ++ double getCurrentThreadMarker() override; ++ ++ jsi::Value installDispatcher(jsi::Runtime& runtime, const jsi::Value&, const jsi::Value* args, size_t count); ++ ++ void loadHybridMethods() override; ++}; ++ ++} // namespace margelo::nitro::camera +diff --git a/nitrogen/generated/shared/c++/JSIConverter+AsyncQueue.hpp b/nitrogen/generated/shared/c++/JSIConverter+AsyncQueue.hpp +new file mode 100644 +index 0000000..5b93f2d +--- /dev/null ++++ b/nitrogen/generated/shared/c++/JSIConverter+AsyncQueue.hpp +@@ -0,0 +1,24 @@ ++/// ++/// JSIConverter+AsyncQueue.swift ++/// VisionCamera ++/// Copyright © 2025 Marc Rousavy @ Margelo ++/// ++ ++#pragma once ++ ++#include ++#include ++#if __has_include() ++#include ++#elif __has_include() ++#include ++#else ++#error react-native-worklets Prefab not found! 
++#endif ++ ++namespace margelo::nitro { ++ ++// JSIConverter> is implemented ++// in JSIConverter> ++ ++} +diff --git a/nitrogen/generated/shared/c++/NativeThreadAsyncQueue.hpp b/nitrogen/generated/shared/c++/NativeThreadAsyncQueue.hpp +new file mode 100644 +index 0000000..d5a0958 +--- /dev/null ++++ b/nitrogen/generated/shared/c++/NativeThreadAsyncQueue.hpp +@@ -0,0 +1,34 @@ ++/// ++/// NativeThreadAsyncQueue.hpp ++/// VisionCamera ++/// Copyright © 2025 Marc Rousavy @ Margelo ++/// ++ ++#pragma once ++ ++#include "HybridNativeThreadSpec.hpp" ++#include "JSIConverter+AsyncQueue.hpp" ++#include ++ ++namespace margelo::nitro::camera { ++ ++/** ++ * An implementation of `worklets::AsyncQueue` that uses a `NativeThread` to run its jobs. ++ * ++ * The `NativeThread` (`HybridNativeThreadSpec`) is a platform-implemented object, ++ * e.g. using `DispatchQueue` on iOS. ++ */ ++class NativeThreadAsyncQueue : public worklets::AsyncQueue { ++public: ++ NativeThreadAsyncQueue(std::shared_ptr thread) : _thread(std::move(thread)) {} ++ ++ void push(std::function&& job) override { ++ auto jobCopy = job; ++ _thread->runOnThread(jobCopy); ++ } ++ ++private: ++ std::shared_ptr _thread; ++}; ++ ++} // namespace margelo::nitro::camera +diff --git a/nitrogen/generated/shared/c++/NativeThreadDispatcher.hpp b/nitrogen/generated/shared/c++/NativeThreadDispatcher.hpp +new file mode 100644 +index 0000000..758d2f2 +--- /dev/null ++++ b/nitrogen/generated/shared/c++/NativeThreadDispatcher.hpp +@@ -0,0 +1,36 @@ ++/// ++/// NativeThreadDispatcher.hpp ++/// VisionCamera ++/// Copyright © 2025 Marc Rousavy @ Margelo ++/// ++ ++#pragma once ++ ++#include "HybridNativeThreadSpec.hpp" ++#include "JSIConverter+AsyncQueue.hpp" ++#include ++ ++namespace margelo::nitro::camera { ++ ++/** ++ * An implementation of `nitro::Dispatcher` that uses a `NativeThread` to run its jobs. ++ * ++ * The `NativeThread` (`HybridNativeThreadSpec`) is a platform-implemented object, ++ * e.g. 
using `DispatchQueue` on iOS. ++ */ ++class NativeThreadDispatcher : public nitro::Dispatcher { ++public: ++ NativeThreadDispatcher(std::shared_ptr thread) : _thread(std::move(thread)) {} ++ ++ void runSync(std::function&&) override { ++ throw std::runtime_error("runSync(...) is not implemented for NativeThreadDispatcher!"); ++ } ++ void runAsync(std::function&& function) override { ++ _thread->runOnThread(function); ++ } ++ ++private: ++ std::shared_ptr _thread; ++}; ++ ++} // namespace margelo::nitro::camera +diff --git a/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/BoundingBox.kt b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/BoundingBox.kt +new file mode 100644 +index 0000000..aaaaaaa +--- /dev/null ++++ b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/BoundingBox.kt +@@ -0,0 +1,47 @@ ++/// ++/// BoundingBox.kt ++/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. ++/// https://github.com/mrousavy/nitro ++/// Copyright © Marc Rousavy @ Margelo ++/// ++ ++package com.margelo.nitro.camera ++ ++import androidx.annotation.Keep ++import com.facebook.proguard.annotations.DoNotStrip ++ ++ ++/** ++ * Represents the JavaScript object/struct "BoundingBox". 
++ */ ++@DoNotStrip ++@Keep ++data class BoundingBox( ++ @DoNotStrip ++ @Keep ++ val x: Double, ++ @DoNotStrip ++ @Keep ++ val y: Double, ++ @DoNotStrip ++ @Keep ++ val width: Double, ++ @DoNotStrip ++ @Keep ++ val height: Double ++) { ++ /* primary constructor */ ++ ++ companion object { ++ /** ++ * Constructor called from C++ ++ */ ++ @DoNotStrip ++ @Keep ++ @Suppress("unused") ++ @JvmStatic ++ private fun fromCpp(x: Double, y: Double, width: Double, height: Double): BoundingBox { ++ return BoundingBox(x, y, width, height) ++ } ++ } ++} +diff --git a/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/HybridScannedObjectSpec.kt b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/HybridScannedObjectSpec.kt +new file mode 100644 +index 0000000..bbbbbbb +--- /dev/null ++++ b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/HybridScannedObjectSpec.kt +@@ -0,0 +1,60 @@ ++/// ++/// HybridScannedObjectSpec.kt ++/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. ++/// https://github.com/mrousavy/nitro ++/// Copyright © Marc Rousavy @ Margelo ++/// ++ ++package com.margelo.nitro.camera ++ ++import androidx.annotation.Keep ++import com.facebook.jni.HybridData ++import com.facebook.proguard.annotations.DoNotStrip ++import com.margelo.nitro.core.HybridObject ++ ++/** ++ * A Kotlin class representing the ScannedObject HybridObject. ++ * Implement this abstract class to create Kotlin-based instances of ScannedObject. 
++ */ ++@DoNotStrip ++@Keep ++@Suppress( ++ "KotlinJniMissingFunction", "unused", ++ "RedundantSuppression", "RedundantUnitReturnType", "SimpleRedundantLet", ++ "LocalVariableName", "PropertyName", "PrivatePropertyName", "FunctionName" ++) ++abstract class HybridScannedObjectSpec: HybridObject() { ++ @DoNotStrip ++ private var mHybridData: HybridData = initHybrid() ++ ++ init { ++ super.updateNative(mHybridData) ++ } ++ ++ override fun updateNative(hybridData: HybridData) { ++ mHybridData = hybridData ++ super.updateNative(hybridData) ++ } ++ ++ // Default implementation of `HybridObject.toString()` ++ override fun toString(): String { ++ return "[HybridObject ScannedObject]" ++ } ++ ++ // Properties ++ @get:DoNotStrip ++ @get:Keep ++ abstract val type: ScannedObjectType ++ ++ @get:DoNotStrip ++ @get:Keep ++ abstract val boundingBox: BoundingBox ++ ++ // Methods ++ ++ private external fun initHybrid(): HybridData ++ ++ companion object { ++ protected const val TAG = "HybridScannedObjectSpec" ++ } ++} +diff --git a/nitrogen/generated/android/c++/JBoundingBox.hpp b/nitrogen/generated/android/c++/JBoundingBox.hpp +new file mode 100644 +index 0000000..ccccccc +--- /dev/null ++++ b/nitrogen/generated/android/c++/JBoundingBox.hpp +@@ -0,0 +1,69 @@ ++/// ++/// JBoundingBox.hpp ++/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. ++/// https://github.com/mrousavy/nitro ++/// Copyright © Marc Rousavy @ Margelo ++/// ++ ++#pragma once ++ ++#include ++#include "BoundingBox.hpp" ++ ++ ++ ++namespace margelo::nitro::camera { ++ ++ using namespace facebook; ++ ++ /** ++ * The C++ JNI bridge between the C++ struct "BoundingBox" and the the Kotlin data class "BoundingBox". ++ */ ++ struct JBoundingBox final: public jni::JavaClass { ++ public: ++ static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/camera/BoundingBox;"; ++ ++ public: ++ /** ++ * Convert this Java/Kotlin-based struct to the C++ struct BoundingBox by copying all values to C++. 
++ */ ++ [[maybe_unused]] ++ [[nodiscard]] ++ BoundingBox toCpp() const { ++ static const auto clazz = javaClassStatic(); ++ static const auto fieldX = clazz->getField("x"); ++ double x = this->getFieldValue(fieldX); ++ static const auto fieldY = clazz->getField("y"); ++ double y = this->getFieldValue(fieldY); ++ static const auto fieldWidth = clazz->getField("width"); ++ double width = this->getFieldValue(fieldWidth); ++ static const auto fieldHeight = clazz->getField("height"); ++ double height = this->getFieldValue(fieldHeight); ++ return BoundingBox( ++ x, ++ y, ++ width, ++ height ++ ); ++ } ++ ++ public: ++ /** ++ * Create a Java/Kotlin-based struct by copying all values from the given C++ struct to Java. ++ */ ++ [[maybe_unused]] ++ static jni::local_ref fromCpp(const BoundingBox& value) { ++ using JSignature = JBoundingBox(double, double, double, double); ++ static const auto clazz = javaClassStatic(); ++ static const auto create = clazz->getStaticMethod("fromCpp"); ++ return create( ++ clazz, ++ value.x, ++ value.y, ++ value.width, ++ value.height ++ ); ++ } ++ }; ++ ++} // namespace margelo::nitro::camera +diff --git a/nitrogen/generated/android/c++/JHybridScannedObjectSpec.hpp b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.hpp +new file mode 100644 +index 0000000..ddddddd +--- /dev/null ++++ b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.hpp +@@ -0,0 +1,63 @@ ++/// ++/// JHybridScannedObjectSpec.hpp ++/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. 
++/// https://github.com/mrousavy/nitro ++/// Copyright © Marc Rousavy @ Margelo ++/// ++ ++#pragma once ++ ++#include ++#include ++#include "HybridScannedObjectSpec.hpp" ++ ++ ++ ++ ++namespace margelo::nitro::camera { ++ ++ using namespace facebook; ++ ++ class JHybridScannedObjectSpec: public jni::HybridClass, ++ public virtual HybridScannedObjectSpec { ++ public: ++ static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/camera/HybridScannedObjectSpec;"; ++ static jni::local_ref initHybrid(jni::alias_ref jThis); ++ static void registerNatives(); ++ ++ protected: ++ // C++ constructor (called from Java via `initHybrid()`) ++ explicit JHybridScannedObjectSpec(jni::alias_ref jThis) : ++ HybridObject(HybridScannedObjectSpec::TAG), ++ HybridBase(jThis), ++ _javaPart(jni::make_global(jThis)) {} ++ ++ public: ++ ~JHybridScannedObjectSpec() override { ++ // Hermes GC can destroy JS objects on a non-JNI Thread. ++ jni::ThreadScope::WithClassLoader([&] { _javaPart.reset(); }); ++ } ++ ++ public: ++ size_t getExternalMemorySize() noexcept override; ++ bool equals(const std::shared_ptr& other) override; ++ void dispose() noexcept override; ++ std::string toString() override; ++ ++ public: ++ inline const jni::global_ref& getJavaPart() const noexcept { ++ return _javaPart; ++ } ++ ++ public: ++ // Properties ++ ScannedObjectType getType() override; ++ BoundingBox getBoundingBox() override; ++ ++ private: ++ friend HybridBase; ++ using HybridBase::HybridBase; ++ jni::global_ref _javaPart; ++ }; ++ ++} // namespace margelo::nitro::camera +diff --git a/nitrogen/generated/android/VisionCamera+autolinking.cmake b/nitrogen/generated/android/VisionCamera+autolinking.cmake +index 0000000..1111111 100644 +--- a/nitrogen/generated/android/VisionCamera+autolinking.cmake ++++ b/nitrogen/generated/android/VisionCamera+autolinking.cmake +@@ -112,3 +112,4 @@ + ../nitrogen/generated/android/c++/JHybridPreviewViewSpec.cpp + 
../nitrogen/generated/android/c++/views/JHybridPreviewViewStateUpdater.cpp ++ ../nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp + ) +diff --git a/nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp +new file mode 100644 +index 0000000..eeeeeee +--- /dev/null ++++ b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp +@@ -0,0 +1,69 @@ ++/// ++/// JHybridScannedObjectSpec.cpp ++/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. ++/// https://github.com/mrousavy/nitro ++/// Copyright © Marc Rousavy @ Margelo ++/// ++ ++#include "JHybridScannedObjectSpec.hpp" ++ ++// Forward declaration of `ScannedObjectType` to properly resolve imports. ++namespace margelo::nitro::camera { enum class ScannedObjectType; } ++// Forward declaration of `BoundingBox` to properly resolve imports. ++namespace margelo::nitro::camera { struct BoundingBox; } ++ ++#include "ScannedObjectType.hpp" ++#include "JScannedObjectType.hpp" ++#include "BoundingBox.hpp" ++#include "JBoundingBox.hpp" ++ ++namespace margelo::nitro::camera { ++ ++ jni::local_ref JHybridScannedObjectSpec::initHybrid(jni::alias_ref jThis) { ++ return makeCxxInstance(jThis); ++ } ++ ++ void JHybridScannedObjectSpec::registerNatives() { ++ registerHybrid({ ++ makeNativeMethod("initHybrid", JHybridScannedObjectSpec::initHybrid), ++ }); ++ } ++ ++ size_t JHybridScannedObjectSpec::getExternalMemorySize() noexcept { ++ static const auto method = javaClassStatic()->getMethod("getMemorySize"); ++ return method(_javaPart); ++ } ++ ++ bool JHybridScannedObjectSpec::equals(const std::shared_ptr& other) { ++ if (auto otherCast = std::dynamic_pointer_cast(other)) { ++ return _javaPart == otherCast->_javaPart; ++ } ++ return false; ++ } ++ ++ void JHybridScannedObjectSpec::dispose() noexcept { ++ static const auto method = javaClassStatic()->getMethod("dispose"); ++ method(_javaPart); ++ } ++ ++ std::string 
JHybridScannedObjectSpec::toString() { ++ static const auto method = javaClassStatic()->getMethod("toString"); ++ auto javaString = method(_javaPart); ++ return javaString->toStdString(); ++ } ++ ++ // Properties ++ ScannedObjectType JHybridScannedObjectSpec::getType() { ++ static const auto method = javaClassStatic()->getMethod()>("getType"); ++ auto __result = method(_javaPart); ++ return __result->toCpp(); ++ } ++ BoundingBox JHybridScannedObjectSpec::getBoundingBox() { ++ static const auto method = javaClassStatic()->getMethod()>("getBoundingBox"); ++ auto __result = method(_javaPart); ++ return __result->toCpp(); ++ } ++ ++ // Methods ++ ++} // namespace margelo::nitro::camera +diff --git a/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt b/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt +index aaaaaaa..bbbbbbb 100644 +--- a/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt ++++ b/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt +@@ -55,6 +55,6 @@ + when (event) { + is VideoRecordEvent.Start -> { +- promise.resolve() ++ promise.resolve(Unit) + didResolve = true + } + +@@ -98,27 +98,48 @@ + override fun stopRecording(): Promise { +- return Promise.parallel(executor) { +- val recording = recording ?: throw Error("Not currently recording!") +- recording.stop() +- this.isPaused = false +- this.recording = null +- this.recordedDuration = 0.0 +- this.recordedFileSize = 0.0 +- } ++ val promise = Promise() ++ executor.execute { ++ try { ++ val recording = recording ?: throw Error("Not currently recording!") ++ recording.stop() ++ this.isPaused = false ++ this.recording = null ++ this.recordedDuration = 0.0 ++ this.recordedFileSize = 0.0 ++ promise.resolve(Unit) ++ } catch (e: Throwable) { ++ promise.reject(e) ++ } ++ } ++ return promise + } + + override fun pauseRecording(): Promise { +- return 
Promise.parallel(executor) { +- val recording = recording ?: throw Error("Not currently recording!") +- recording.pause() +- this.isPaused = true +- } ++ val promise = Promise() ++ executor.execute { ++ try { ++ val recording = recording ?: throw Error("Not currently recording!") ++ recording.pause() ++ this.isPaused = true ++ promise.resolve(Unit) ++ } catch (e: Throwable) { ++ promise.reject(e) ++ } ++ } ++ return promise + } + + override fun resumeRecording(): Promise { +- return Promise.parallel(executor) { +- val recording = recording ?: throw Error("Not currently recording!") +- recording.resume() +- this.isPaused = false +- } ++ val promise = Promise() ++ executor.execute { ++ try { ++ val recording = recording ?: throw Error("Not currently recording!") ++ recording.resume() ++ this.isPaused = false ++ promise.resolve(Unit) ++ } catch (e: Throwable) { ++ promise.reject(e) ++ } ++ } ++ return promise + } + } diff --git a/yarn.lock b/yarn.lock index e6ddb22294..b76e118881 100644 --- a/yarn.lock +++ b/yarn.lock @@ -14588,24 +14588,24 @@ __metadata: languageName: node linkType: hard -"react-native-nitro-image@npm:0.10.2": - version: 0.10.2 - resolution: "react-native-nitro-image@npm:0.10.2" +"react-native-nitro-image@npm:^0.12.0": + version: 0.12.0 + resolution: "react-native-nitro-image@npm:0.12.0" peerDependencies: react: "*" react-native: "*" react-native-nitro-modules: "*" - checksum: 10/3be75e93da369adfe00441dae78171572dec38d3d7e75e5d4cb302b81479be9686c8d8dc0ea4b331514b8725099bf3eb069ab9933f7029627d12a72d71766cb4 + checksum: 10/03f165381c35e060d4d05eae3ce029b32a4009482f327e9526840f306181ca87a862b335e12667c55d4ee9f2069542ca93dd112feb7f1822bf7d2ddc38fe58f0 languageName: node linkType: hard -"react-native-nitro-modules@npm:0.33.4": - version: 0.33.4 - resolution: "react-native-nitro-modules@npm:0.33.4" +"react-native-nitro-modules@npm:^0.33.9": + version: 0.33.9 + resolution: "react-native-nitro-modules@npm:0.33.9" peerDependencies: react: "*" react-native: 
"*" - checksum: 10/a737ff6b142c55821688612305245fd10a7cff36f0ee66cad0956c6815a60cdd4ba64cdfba6137a6dbfe815645763ce5d406cf488876edd47dab7f8d0031e01a + checksum: 10/4ebf4db46d1e4987a0e52054724081aa9712bcd1d505a6dbdd47aebc6afe72a7abaa0e947651d9f3cc594e4eb3dba47fc6f59db27c5a5ed383946e40d96543a0 languageName: node linkType: hard From 9abdb7bdf8b59ee662c1c2e87dd35bc78b688b50 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 24 Feb 2026 09:02:14 +0100 Subject: [PATCH 13/71] feat: suggested changes / improve comments --- .../common/rnexecutorch/host_objects/JsiConversions.h | 9 --------- .../common/rnexecutorch/metaprogramming/TypeConcepts.h | 5 +++++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 77f1c51adb..96e3168ee7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -368,15 +368,6 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, bigInt}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); - } - return {runtime, array}; -} - inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { return {runtime, val}; } diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index fdf8c9dba7..216e2bae39 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -11,6 
+11,11 @@ concept DerivedFromOrSameAs = std::is_base_of_v; template concept SameAs = std::is_same_v; +template +concept HasGenerate = requires(T t) { + { &T::generate }; +}; + template concept HasGenerateFromString = requires(T t) { { &T::generateFromString }; From c27d745440789dd4b6307964bd5b917f531c16a0 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 25 Feb 2026 14:16:42 +0100 Subject: [PATCH 14/71] fix(android): object detection not working on android --- ...ative-vision-camera@npm-5.0.0-beta.1.patch | 713 ------------------ 1 file changed, 713 deletions(-) delete mode 100644 .yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch diff --git a/.yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch b/.yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch deleted file mode 100644 index 73f999e9a6..0000000000 --- a/.yarn/patches/react-native-vision-camera@npm-5.0.0-beta.1.patch +++ /dev/null @@ -1,713 +0,0 @@ -diff --git a/lib/expo-plugin/withVisionCamera.js b/lib/expo-plugin/withVisionCamera.js -index 32418a9..f7a8c5c 100644 ---- a/lib/expo-plugin/withVisionCamera.js -+++ b/lib/expo-plugin/withVisionCamera.js -@@ -1,4 +1,4 @@ --import { AndroidConfig, withPlugins, } from '@expo/config-plugins'; -+const { AndroidConfig, withPlugins } = require('@expo/config-plugins'); - const CAMERA_USAGE = 'Allow $(PRODUCT_NAME) to access your camera'; - const MICROPHONE_USAGE = 'Allow $(PRODUCT_NAME) to access your microphone'; - const withVisionCamera = (config, props = {}) => { -@@ -30,4 +30,4 @@ const withVisionCamera = (config, props = {}) => { - [AndroidConfig.Permissions.withPermissions, androidPermissions], - ]); - }; --export default withVisionCamera; -+module.exports = withVisionCamera; -diff --git a/cpp/Frame Processors/HybridWorkletQueueFactory.cpp b/cpp/Frame Processors/HybridWorkletQueueFactory.cpp -new file mode 100644 -index 0000000..5da4ef9 ---- /dev/null -+++ b/cpp/Frame Processors/HybridWorkletQueueFactory.cpp -@@ -0,0 
+1,50 @@ -+/// -+/// HybridWorkletQueueFactory.cpp -+/// VisionCamera -+/// Copyright © 2025 Marc Rousavy @ Margelo -+/// -+ -+#include "HybridWorkletQueueFactory.hpp" -+ -+#include "JSIConverter+AsyncQueue.hpp" -+#include "NativeThreadAsyncQueue.hpp" -+#include "NativeThreadDispatcher.hpp" -+#include -+#include -+ -+namespace margelo::nitro::camera { -+ -+HybridWorkletQueueFactory::HybridWorkletQueueFactory() : HybridObject(TAG) {} -+ -+void HybridWorkletQueueFactory::loadHybridMethods() { -+ HybridWorkletQueueFactorySpec::loadHybridMethods(); -+ registerHybrids(this, [](Prototype& prototype) { -+ prototype.registerRawHybridMethod("installDispatcher", 1, &HybridWorkletQueueFactory::installDispatcher); -+ }); -+} -+ -+std::shared_ptr HybridWorkletQueueFactory::wrapThreadInQueue(const std::shared_ptr& thread) { -+ return std::make_shared(thread); -+} -+ -+double HybridWorkletQueueFactory::getCurrentThreadMarker() { -+ static std::atomic_size_t threadCounter{1}; -+ static thread_local size_t thisThreadId{0}; -+ if (thisThreadId == 0) { -+ thisThreadId = threadCounter.fetch_add(1); -+ } -+ return static_cast(thisThreadId); -+} -+ -+jsi::Value HybridWorkletQueueFactory::installDispatcher(jsi::Runtime& runtime, const jsi::Value&, const jsi::Value* args, size_t count) { -+ if (count != 1) -+ throw std::runtime_error("installDispatcher(..) 
must be called with exactly 1 argument!"); -+ auto thread = JSIConverter>::fromJSI(runtime, args[0]); -+ -+ auto dispatcher = std::make_shared(thread); -+ Dispatcher::installRuntimeGlobalDispatcher(runtime, dispatcher); -+ -+ return jsi::Value::undefined(); -+} -+ -+} // namespace margelo::nitro::camera -diff --git a/android/CMakeLists.txt b/android/CMakeLists.txt -index 0000000..1111111 100644 ---- a/android/CMakeLists.txt -+++ b/android/CMakeLists.txt -@@ -20,6 +20,7 @@ - "src/main/cpp" - "../cpp" - "../cpp/Frame Processors" -+ "../nitrogen/generated/shared/c++" - ) - - find_library(LOG_LIB log) -diff --git a/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.cpp b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.cpp -new file mode 100644 -index 0000000..5da4ef9 ---- /dev/null -+++ b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.cpp -@@ -0,0 +1,50 @@ -+/// -+/// HybridWorkletQueueFactory.cpp -+/// VisionCamera -+/// Copyright © 2025 Marc Rousavy @ Margelo -+/// -+ -+#include "HybridWorkletQueueFactory.hpp" -+ -+#include "JSIConverter+AsyncQueue.hpp" -+#include "NativeThreadAsyncQueue.hpp" -+#include "NativeThreadDispatcher.hpp" -+#include -+#include -+ -+namespace margelo::nitro::camera { -+ -+HybridWorkletQueueFactory::HybridWorkletQueueFactory() : HybridObject(TAG) {} -+ -+void HybridWorkletQueueFactory::loadHybridMethods() { -+ HybridWorkletQueueFactorySpec::loadHybridMethods(); -+ registerHybrids(this, [](Prototype& prototype) { -+ prototype.registerRawHybridMethod("installDispatcher", 1, &HybridWorkletQueueFactory::installDispatcher); -+ }); -+} -+ -+std::shared_ptr HybridWorkletQueueFactory::wrapThreadInQueue(const std::shared_ptr& thread) { -+ return std::make_shared(thread); -+} -+ -+double HybridWorkletQueueFactory::getCurrentThreadMarker() { -+ static std::atomic_size_t threadCounter{1}; -+ static thread_local size_t thisThreadId{0}; -+ if (thisThreadId == 0) { -+ thisThreadId = threadCounter.fetch_add(1); -+ } -+ return 
static_cast(thisThreadId); -+} -+ -+jsi::Value HybridWorkletQueueFactory::installDispatcher(jsi::Runtime& runtime, const jsi::Value&, const jsi::Value* args, size_t count) { -+ if (count != 1) -+ throw std::runtime_error("installDispatcher(..) must be called with exactly 1 argument!"); -+ auto thread = JSIConverter>::fromJSI(runtime, args[0]); -+ -+ auto dispatcher = std::make_shared(thread); -+ Dispatcher::installRuntimeGlobalDispatcher(runtime, dispatcher); -+ -+ return jsi::Value::undefined(); -+} -+ -+} // namespace margelo::nitro::camera -diff --git a/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.hpp b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.hpp -new file mode 100644 -index 0000000..daa16d2 ---- /dev/null -+++ b/nitrogen/generated/shared/c++/HybridWorkletQueueFactory.hpp -@@ -0,0 +1,29 @@ -+/// -+/// HybridWorkletQueueFactory.hpp -+/// VisionCamera -+/// Copyright © 2025 Marc Rousavy @ Margelo -+/// -+ -+#pragma once -+ -+#include "HybridWorkletQueueFactorySpec.hpp" -+#include "JSIConverter+AsyncQueue.hpp" -+#include -+#include -+ -+namespace margelo::nitro::camera { -+ -+class HybridWorkletQueueFactory : public HybridWorkletQueueFactorySpec { -+public: -+ HybridWorkletQueueFactory(); -+ -+public: -+ std::shared_ptr wrapThreadInQueue(const std::shared_ptr& thread) override; -+ double getCurrentThreadMarker() override; -+ -+ jsi::Value installDispatcher(jsi::Runtime& runtime, const jsi::Value&, const jsi::Value* args, size_t count); -+ -+ void loadHybridMethods() override; -+}; -+ -+} // namespace margelo::nitro::camera -diff --git a/nitrogen/generated/shared/c++/JSIConverter+AsyncQueue.hpp b/nitrogen/generated/shared/c++/JSIConverter+AsyncQueue.hpp -new file mode 100644 -index 0000000..5b93f2d ---- /dev/null -+++ b/nitrogen/generated/shared/c++/JSIConverter+AsyncQueue.hpp -@@ -0,0 +1,24 @@ -+/// -+/// JSIConverter+AsyncQueue.swift -+/// VisionCamera -+/// Copyright © 2025 Marc Rousavy @ Margelo -+/// -+ -+#pragma once -+ -+#include 
-+#include -+#if __has_include() -+#include -+#elif __has_include() -+#include -+#else -+#error react-native-worklets Prefab not found! -+#endif -+ -+namespace margelo::nitro { -+ -+// JSIConverter> is implemented -+// in JSIConverter> -+ -+} -diff --git a/nitrogen/generated/shared/c++/NativeThreadAsyncQueue.hpp b/nitrogen/generated/shared/c++/NativeThreadAsyncQueue.hpp -new file mode 100644 -index 0000000..d5a0958 ---- /dev/null -+++ b/nitrogen/generated/shared/c++/NativeThreadAsyncQueue.hpp -@@ -0,0 +1,34 @@ -+/// -+/// NativeThreadAsyncQueue.hpp -+/// VisionCamera -+/// Copyright © 2025 Marc Rousavy @ Margelo -+/// -+ -+#pragma once -+ -+#include "HybridNativeThreadSpec.hpp" -+#include "JSIConverter+AsyncQueue.hpp" -+#include -+ -+namespace margelo::nitro::camera { -+ -+/** -+ * An implementation of `worklets::AsyncQueue` that uses a `NativeThread` to run its jobs. -+ * -+ * The `NativeThread` (`HybridNativeThreadSpec`) is a platform-implemented object, -+ * e.g. using `DispatchQueue` on iOS. 
-+ */ -+class NativeThreadAsyncQueue : public worklets::AsyncQueue { -+public: -+ NativeThreadAsyncQueue(std::shared_ptr thread) : _thread(std::move(thread)) {} -+ -+ void push(std::function&& job) override { -+ auto jobCopy = job; -+ _thread->runOnThread(jobCopy); -+ } -+ -+private: -+ std::shared_ptr _thread; -+}; -+ -+} // namespace margelo::nitro::camera -diff --git a/nitrogen/generated/shared/c++/NativeThreadDispatcher.hpp b/nitrogen/generated/shared/c++/NativeThreadDispatcher.hpp -new file mode 100644 -index 0000000..758d2f2 ---- /dev/null -+++ b/nitrogen/generated/shared/c++/NativeThreadDispatcher.hpp -@@ -0,0 +1,36 @@ -+/// -+/// NativeThreadDispatcher.hpp -+/// VisionCamera -+/// Copyright © 2025 Marc Rousavy @ Margelo -+/// -+ -+#pragma once -+ -+#include "HybridNativeThreadSpec.hpp" -+#include "JSIConverter+AsyncQueue.hpp" -+#include -+ -+namespace margelo::nitro::camera { -+ -+/** -+ * An implementation of `nitro::Dispatcher` that uses a `NativeThread` to run its jobs. -+ * -+ * The `NativeThread` (`HybridNativeThreadSpec`) is a platform-implemented object, -+ * e.g. using `DispatchQueue` on iOS. -+ */ -+class NativeThreadDispatcher : public nitro::Dispatcher { -+public: -+ NativeThreadDispatcher(std::shared_ptr thread) : _thread(std::move(thread)) {} -+ -+ void runSync(std::function&&) override { -+ throw std::runtime_error("runSync(...) 
is not implemented for NativeThreadDispatcher!"); -+ } -+ void runAsync(std::function&& function) override { -+ _thread->runOnThread(function); -+ } -+ -+private: -+ std::shared_ptr _thread; -+}; -+ -+} // namespace margelo::nitro::camera -diff --git a/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/BoundingBox.kt b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/BoundingBox.kt -new file mode 100644 -index 0000000..aaaaaaa ---- /dev/null -+++ b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/BoundingBox.kt -@@ -0,0 +1,47 @@ -+/// -+/// BoundingBox.kt -+/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. -+/// https://github.com/mrousavy/nitro -+/// Copyright © Marc Rousavy @ Margelo -+/// -+ -+package com.margelo.nitro.camera -+ -+import androidx.annotation.Keep -+import com.facebook.proguard.annotations.DoNotStrip -+ -+ -+/** -+ * Represents the JavaScript object/struct "BoundingBox". -+ */ -+@DoNotStrip -+@Keep -+data class BoundingBox( -+ @DoNotStrip -+ @Keep -+ val x: Double, -+ @DoNotStrip -+ @Keep -+ val y: Double, -+ @DoNotStrip -+ @Keep -+ val width: Double, -+ @DoNotStrip -+ @Keep -+ val height: Double -+) { -+ /* primary constructor */ -+ -+ companion object { -+ /** -+ * Constructor called from C++ -+ */ -+ @DoNotStrip -+ @Keep -+ @Suppress("unused") -+ @JvmStatic -+ private fun fromCpp(x: Double, y: Double, width: Double, height: Double): BoundingBox { -+ return BoundingBox(x, y, width, height) -+ } -+ } -+} -diff --git a/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/HybridScannedObjectSpec.kt b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/HybridScannedObjectSpec.kt -new file mode 100644 -index 0000000..bbbbbbb ---- /dev/null -+++ b/nitrogen/generated/android/kotlin/com/margelo/nitro/camera/HybridScannedObjectSpec.kt -@@ -0,0 +1,60 @@ -+/// -+/// HybridScannedObjectSpec.kt -+/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. 
-+/// https://github.com/mrousavy/nitro -+/// Copyright © Marc Rousavy @ Margelo -+/// -+ -+package com.margelo.nitro.camera -+ -+import androidx.annotation.Keep -+import com.facebook.jni.HybridData -+import com.facebook.proguard.annotations.DoNotStrip -+import com.margelo.nitro.core.HybridObject -+ -+/** -+ * A Kotlin class representing the ScannedObject HybridObject. -+ * Implement this abstract class to create Kotlin-based instances of ScannedObject. -+ */ -+@DoNotStrip -+@Keep -+@Suppress( -+ "KotlinJniMissingFunction", "unused", -+ "RedundantSuppression", "RedundantUnitReturnType", "SimpleRedundantLet", -+ "LocalVariableName", "PropertyName", "PrivatePropertyName", "FunctionName" -+) -+abstract class HybridScannedObjectSpec: HybridObject() { -+ @DoNotStrip -+ private var mHybridData: HybridData = initHybrid() -+ -+ init { -+ super.updateNative(mHybridData) -+ } -+ -+ override fun updateNative(hybridData: HybridData) { -+ mHybridData = hybridData -+ super.updateNative(hybridData) -+ } -+ -+ // Default implementation of `HybridObject.toString()` -+ override fun toString(): String { -+ return "[HybridObject ScannedObject]" -+ } -+ -+ // Properties -+ @get:DoNotStrip -+ @get:Keep -+ abstract val type: ScannedObjectType -+ -+ @get:DoNotStrip -+ @get:Keep -+ abstract val boundingBox: BoundingBox -+ -+ // Methods -+ -+ private external fun initHybrid(): HybridData -+ -+ companion object { -+ protected const val TAG = "HybridScannedObjectSpec" -+ } -+} -diff --git a/nitrogen/generated/android/c++/JBoundingBox.hpp b/nitrogen/generated/android/c++/JBoundingBox.hpp -new file mode 100644 -index 0000000..ccccccc ---- /dev/null -+++ b/nitrogen/generated/android/c++/JBoundingBox.hpp -@@ -0,0 +1,69 @@ -+/// -+/// JBoundingBox.hpp -+/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. 
-+/// https://github.com/mrousavy/nitro -+/// Copyright © Marc Rousavy @ Margelo -+/// -+ -+#pragma once -+ -+#include -+#include "BoundingBox.hpp" -+ -+ -+ -+namespace margelo::nitro::camera { -+ -+ using namespace facebook; -+ -+ /** -+ * The C++ JNI bridge between the C++ struct "BoundingBox" and the the Kotlin data class "BoundingBox". -+ */ -+ struct JBoundingBox final: public jni::JavaClass { -+ public: -+ static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/camera/BoundingBox;"; -+ -+ public: -+ /** -+ * Convert this Java/Kotlin-based struct to the C++ struct BoundingBox by copying all values to C++. -+ */ -+ [[maybe_unused]] -+ [[nodiscard]] -+ BoundingBox toCpp() const { -+ static const auto clazz = javaClassStatic(); -+ static const auto fieldX = clazz->getField("x"); -+ double x = this->getFieldValue(fieldX); -+ static const auto fieldY = clazz->getField("y"); -+ double y = this->getFieldValue(fieldY); -+ static const auto fieldWidth = clazz->getField("width"); -+ double width = this->getFieldValue(fieldWidth); -+ static const auto fieldHeight = clazz->getField("height"); -+ double height = this->getFieldValue(fieldHeight); -+ return BoundingBox( -+ x, -+ y, -+ width, -+ height -+ ); -+ } -+ -+ public: -+ /** -+ * Create a Java/Kotlin-based struct by copying all values from the given C++ struct to Java. 
-+ */ -+ [[maybe_unused]] -+ static jni::local_ref fromCpp(const BoundingBox& value) { -+ using JSignature = JBoundingBox(double, double, double, double); -+ static const auto clazz = javaClassStatic(); -+ static const auto create = clazz->getStaticMethod("fromCpp"); -+ return create( -+ clazz, -+ value.x, -+ value.y, -+ value.width, -+ value.height -+ ); -+ } -+ }; -+ -+} // namespace margelo::nitro::camera -diff --git a/nitrogen/generated/android/c++/JHybridScannedObjectSpec.hpp b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.hpp -new file mode 100644 -index 0000000..ddddddd ---- /dev/null -+++ b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.hpp -@@ -0,0 +1,63 @@ -+/// -+/// JHybridScannedObjectSpec.hpp -+/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. -+/// https://github.com/mrousavy/nitro -+/// Copyright © Marc Rousavy @ Margelo -+/// -+ -+#pragma once -+ -+#include -+#include -+#include "HybridScannedObjectSpec.hpp" -+ -+ -+ -+ -+namespace margelo::nitro::camera { -+ -+ using namespace facebook; -+ -+ class JHybridScannedObjectSpec: public jni::HybridClass, -+ public virtual HybridScannedObjectSpec { -+ public: -+ static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/camera/HybridScannedObjectSpec;"; -+ static jni::local_ref initHybrid(jni::alias_ref jThis); -+ static void registerNatives(); -+ -+ protected: -+ // C++ constructor (called from Java via `initHybrid()`) -+ explicit JHybridScannedObjectSpec(jni::alias_ref jThis) : -+ HybridObject(HybridScannedObjectSpec::TAG), -+ HybridBase(jThis), -+ _javaPart(jni::make_global(jThis)) {} -+ -+ public: -+ ~JHybridScannedObjectSpec() override { -+ // Hermes GC can destroy JS objects on a non-JNI Thread. 
-+ jni::ThreadScope::WithClassLoader([&] { _javaPart.reset(); }); -+ } -+ -+ public: -+ size_t getExternalMemorySize() noexcept override; -+ bool equals(const std::shared_ptr& other) override; -+ void dispose() noexcept override; -+ std::string toString() override; -+ -+ public: -+ inline const jni::global_ref& getJavaPart() const noexcept { -+ return _javaPart; -+ } -+ -+ public: -+ // Properties -+ ScannedObjectType getType() override; -+ BoundingBox getBoundingBox() override; -+ -+ private: -+ friend HybridBase; -+ using HybridBase::HybridBase; -+ jni::global_ref _javaPart; -+ }; -+ -+} // namespace margelo::nitro::camera -diff --git a/nitrogen/generated/android/VisionCamera+autolinking.cmake b/nitrogen/generated/android/VisionCamera+autolinking.cmake -index 0000000..1111111 100644 ---- a/nitrogen/generated/android/VisionCamera+autolinking.cmake -+++ b/nitrogen/generated/android/VisionCamera+autolinking.cmake -@@ -112,3 +112,4 @@ - ../nitrogen/generated/android/c++/JHybridPreviewViewSpec.cpp - ../nitrogen/generated/android/c++/views/JHybridPreviewViewStateUpdater.cpp -+ ../nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp - ) -diff --git a/nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp -new file mode 100644 -index 0000000..eeeeeee ---- /dev/null -+++ b/nitrogen/generated/android/c++/JHybridScannedObjectSpec.cpp -@@ -0,0 +1,69 @@ -+/// -+/// JHybridScannedObjectSpec.cpp -+/// This file was generated by nitrogen. DO NOT MODIFY THIS FILE. -+/// https://github.com/mrousavy/nitro -+/// Copyright © Marc Rousavy @ Margelo -+/// -+ -+#include "JHybridScannedObjectSpec.hpp" -+ -+// Forward declaration of `ScannedObjectType` to properly resolve imports. -+namespace margelo::nitro::camera { enum class ScannedObjectType; } -+// Forward declaration of `BoundingBox` to properly resolve imports. 
-+namespace margelo::nitro::camera { struct BoundingBox; } -+ -+#include "ScannedObjectType.hpp" -+#include "JScannedObjectType.hpp" -+#include "BoundingBox.hpp" -+#include "JBoundingBox.hpp" -+ -+namespace margelo::nitro::camera { -+ -+ jni::local_ref JHybridScannedObjectSpec::initHybrid(jni::alias_ref jThis) { -+ return makeCxxInstance(jThis); -+ } -+ -+ void JHybridScannedObjectSpec::registerNatives() { -+ registerHybrid({ -+ makeNativeMethod("initHybrid", JHybridScannedObjectSpec::initHybrid), -+ }); -+ } -+ -+ size_t JHybridScannedObjectSpec::getExternalMemorySize() noexcept { -+ static const auto method = javaClassStatic()->getMethod("getMemorySize"); -+ return method(_javaPart); -+ } -+ -+ bool JHybridScannedObjectSpec::equals(const std::shared_ptr& other) { -+ if (auto otherCast = std::dynamic_pointer_cast(other)) { -+ return _javaPart == otherCast->_javaPart; -+ } -+ return false; -+ } -+ -+ void JHybridScannedObjectSpec::dispose() noexcept { -+ static const auto method = javaClassStatic()->getMethod("dispose"); -+ method(_javaPart); -+ } -+ -+ std::string JHybridScannedObjectSpec::toString() { -+ static const auto method = javaClassStatic()->getMethod("toString"); -+ auto javaString = method(_javaPart); -+ return javaString->toStdString(); -+ } -+ -+ // Properties -+ ScannedObjectType JHybridScannedObjectSpec::getType() { -+ static const auto method = javaClassStatic()->getMethod()>("getType"); -+ auto __result = method(_javaPart); -+ return __result->toCpp(); -+ } -+ BoundingBox JHybridScannedObjectSpec::getBoundingBox() { -+ static const auto method = javaClassStatic()->getMethod()>("getBoundingBox"); -+ auto __result = method(_javaPart); -+ return __result->toCpp(); -+ } -+ -+ // Methods -+ -+} // namespace margelo::nitro::camera -diff --git a/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt b/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt -index aaaaaaa..bbbbbbb 100644 
---- a/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt -+++ b/android/src/main/java/com/margelo/nitro/camera/hybrids/recording/HybridVideoRecorder.kt -@@ -55,6 +55,6 @@ - when (event) { - is VideoRecordEvent.Start -> { -- promise.resolve() -+ promise.resolve(Unit) - didResolve = true - } - -@@ -98,27 +98,48 @@ - override fun stopRecording(): Promise { -- return Promise.parallel(executor) { -- val recording = recording ?: throw Error("Not currently recording!") -- recording.stop() -- this.isPaused = false -- this.recording = null -- this.recordedDuration = 0.0 -- this.recordedFileSize = 0.0 -- } -+ val promise = Promise() -+ executor.execute { -+ try { -+ val recording = recording ?: throw Error("Not currently recording!") -+ recording.stop() -+ this.isPaused = false -+ this.recording = null -+ this.recordedDuration = 0.0 -+ this.recordedFileSize = 0.0 -+ promise.resolve(Unit) -+ } catch (e: Throwable) { -+ promise.reject(e) -+ } -+ } -+ return promise - } - - override fun pauseRecording(): Promise { -- return Promise.parallel(executor) { -- val recording = recording ?: throw Error("Not currently recording!") -- recording.pause() -- this.isPaused = true -- } -+ val promise = Promise() -+ executor.execute { -+ try { -+ val recording = recording ?: throw Error("Not currently recording!") -+ recording.pause() -+ this.isPaused = true -+ promise.resolve(Unit) -+ } catch (e: Throwable) { -+ promise.reject(e) -+ } -+ } -+ return promise - } - - override fun resumeRecording(): Promise { -- return Promise.parallel(executor) { -- val recording = recording ?: throw Error("Not currently recording!") -- recording.resume() -- this.isPaused = false -- } -+ val promise = Promise() -+ executor.execute { -+ try { -+ val recording = recording ?: throw Error("Not currently recording!") -+ recording.resume() -+ this.isPaused = false -+ promise.resolve(Unit) -+ } catch (e: Throwable) { -+ promise.reject(e) -+ } -+ } -+ return promise - } - } From 
c1941d138cc15e7e780e724192e2b6775584264c Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 25 Feb 2026 14:19:49 +0100 Subject: [PATCH 15/71] chore: remove unused ImageSegmentation.cpp --- .../ImageSegmentation.cpp | 170 ------------------ 1 file changed, 170 deletions(-) delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp deleted file mode 100644 index a2c1ae865b..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/ImageSegmentation.cpp +++ /dev/null @@ -1,170 +0,0 @@ -#include "ImageSegmentation.h" - -#include - -#include -#include -#include -#include -#include -#include - -namespace rnexecutorch::models::image_segmentation { - -ImageSegmentation::ImageSegmentation( - const std::string &modelSource, - std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { - auto inputShapes = getAllInputShapes(); - if (inputShapes.size() == 0) { - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Model seems to not take any input tensors."); - } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " - "but got: %zu.", - modelInputShape.size()); - throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); - } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); - numModelPixels = modelImageSize.area(); -} - -std::shared_ptr ImageSegmentation::generate( - std::string imageSource, - std::set> classesOfInterest, bool resize) { - auto 
[inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); - - auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { - throw RnExecutorchError(forwardResult.error(), - "The model's forward function did not succeed. " - "Ensure the model input is correct."); - } - - return postprocess(forwardResult->at(0).toTensor(), originalSize, - classesOfInterest, resize); -} - -std::shared_ptr ImageSegmentation::postprocess( - const Tensor &tensor, cv::Size originalSize, - std::set> classesOfInterest, bool resize) { - - auto dataPtr = static_cast(tensor.const_data_ptr()); - auto resultData = std::span(dataPtr, tensor.numel()); - - // We copy the ET-owned data to jsi array buffers that can be directly - // returned to JS - std::vector> resultClasses; - resultClasses.reserve(numClasses); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - auto classBuffer = std::make_shared( - &resultData[cl * numModelPixels], numModelPixels * sizeof(float)); - resultClasses.push_back(classBuffer); - } - - // Apply softmax per each pixel across all classes - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - std::vector classValues(numClasses); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - classValues[cl] = - reinterpret_cast(resultClasses[cl]->data())[pixel]; - } - numerical::softmax(classValues); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - reinterpret_cast(resultClasses[cl]->data())[pixel] = - classValues[cl]; - } - } - - // Calculate the maximum class for each pixel - auto argmax = - std::make_shared(numModelPixels * sizeof(int32_t)); - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - float max = reinterpret_cast(resultClasses[0]->data())[pixel]; - int maxInd = 0; - for (int cl = 1; cl < numClasses; ++cl) { - if (reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { - maxInd = cl; - max = reinterpret_cast(resultClasses[cl]->data())[pixel]; - } - } - 
reinterpret_cast(argmax->data())[pixel] = maxInd; - } - - auto buffersToReturn = std::make_shared>>(); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - if (classesOfInterest.contains(constants::kDeeplabV3Resnet50Labels[cl])) { - (*buffersToReturn)[constants::kDeeplabV3Resnet50Labels[cl]] = - resultClasses[cl]; - } - } - - // Resize selected classes and argmax - if (resize) { - cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data()); - cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, - cv::InterpolationFlags::INTER_NEAREST); - argmax = std::make_shared( - argmaxMat.data, originalSize.area() * sizeof(int32_t)); - - for (auto &[label, arrayBuffer] : *buffersToReturn) { - cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data()); - cv::resize(classMat, classMat, originalSize); - arrayBuffer = std::make_shared( - classMat.data, originalSize.area() * sizeof(float)); - } - } - return populateDictionary(argmax, buffersToReturn); -} - -std::shared_ptr ImageSegmentation::populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput) { - // Synchronize the invoked thread to return when the dict is constructed - auto promisePtr = std::make_shared>(); - std::future doneFuture = promisePtr->get_future(); - - std::shared_ptr dictPtr = nullptr; - callInvoker->invokeAsync( - [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { - dictPtr = std::make_shared(runtime); - auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); - - auto int32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Int32Array"); - auto int32Array = - int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) - .getObject(runtime); - dictPtr->setProperty(runtime, "ARGMAX", int32Array); - - for (auto &[classLabel, owningBuffer] : *classesToOutput) { - auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); - - auto float32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Float32Array"); - auto float32Array = - 
float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) - .getObject(runtime); - - dictPtr->setProperty( - runtime, jsi::String::createFromAscii(runtime, classLabel.data()), - float32Array); - } - promisePtr->set_value(); - }); - - doneFuture.wait(); - return dictPtr; -} - -} // namespace rnexecutorch::models::image_segmentation From 23404e7275f8f119a32a6deb00c3c25a132c4c63 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 25 Feb 2026 14:38:39 +0100 Subject: [PATCH 16/71] docs: add correct api references --- .../classes/ImageSegmentationModule.md | 356 ++++++++++++++++++ 1 file changed, 356 insertions(+) create mode 100644 docs/docs/06-api-reference/classes/ImageSegmentationModule.md diff --git a/docs/docs/06-api-reference/classes/ImageSegmentationModule.md b/docs/docs/06-api-reference/classes/ImageSegmentationModule.md new file mode 100644 index 0000000000..6b41289069 --- /dev/null +++ b/docs/docs/06-api-reference/classes/ImageSegmentationModule.md @@ -0,0 +1,356 @@ +# Class: ImageSegmentationModule\ + +Defined in: [modules/computer_vision/ImageSegmentationModule.ts:60](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L60) + +Generic image segmentation module with type-safe label maps. +Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, +or a custom label enum for custom configs. + +## Extends + +- `BaseModule` + +## Type Parameters + +### T + +`T` _extends_ [`SegmentationModelName`](../type-aliases/SegmentationModelName.md) \| [`LabelEnum`](../type-aliases/LabelEnum.md) + +Either a built-in model name (`'deeplab-v3'`, `'selfie-segmentation'`) +or a custom [LabelEnum](../type-aliases/LabelEnum.md) label map. 
+ +## Properties + +### generateFromFrame() + +> **generateFromFrame**: (`frameData`, ...`args`) => `any` + +Defined in: [modules/BaseModule.ts:56](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L56) + +Process a camera frame directly for real-time inference. + +This method is bound to a native JSI function after calling `load()`, +making it worklet-compatible and safe to call from VisionCamera's +frame processor thread. + +**Performance characteristics:** + +- **Zero-copy path**: When using `frame.getNativeBuffer()` from VisionCamera v5, + frame data is accessed directly without copying (fastest, recommended). +- **Copy path**: When using `frame.toArrayBuffer()`, pixel data is copied + from native to JS, then accessed from native code (slower, fallback). + +**Usage with VisionCamera:** + +```typescript +const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + onFrame(frame) { + 'worklet'; + // Zero-copy approach (recommended) + const nativeBuffer = frame.getNativeBuffer(); + const result = model.generateFromFrame( + { + nativeBuffer: nativeBuffer.pointer, + width: frame.width, + height: frame.height, + }, + ...args + ); + nativeBuffer.release(); + frame.dispose(); + }, +}); +``` + +#### Parameters + +##### frameData + +[`Frame`](../interfaces/Frame.md) + +Frame data object with either nativeBuffer (zero-copy) or data (ArrayBuffer) + +##### args + +...`any`[] + +Additional model-specific arguments (e.g., threshold, options) + +#### Returns + +`any` + +Model-specific output (e.g., detections, classifications, embeddings) + +#### See + +[Frame](../interfaces/Frame.md) for frame data format details + +#### Inherited from + +`BaseModule.generateFromFrame` + +--- + +### nativeModule + +> **nativeModule**: `any` = `null` + +Defined in: 
[modules/BaseModule.ts:17](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L17) + +**`Internal`** + +Native module instance (JSI Host Object) + +#### Inherited from + +`BaseModule.nativeModule` + +## Methods + +### delete() + +> **delete**(): `void` + +Defined in: [modules/BaseModule.ts:100](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L100) + +Unloads the model from memory and releases native resources. + +Always call this method when you're done with a model to prevent memory leaks. + +#### Returns + +`void` + +#### Inherited from + +`BaseModule.delete` + +--- + +### forward() + +> **forward**\<`K`\>(`imageSource`, `classesOfInterest`, `resizeToInput`): `Promise`\<`Record`\<`"ARGMAX"`, `Int32Array`\<`ArrayBufferLike`\>\> & `Record`\<`K`, `Float32Array`\<`ArrayBufferLike`\>\>\> + +Defined in: [modules/computer_vision/ImageSegmentationModule.ts:176](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L176) + +Executes the model's forward pass to perform semantic segmentation on the provided image. + +#### Type Parameters + +##### K + +`K` _extends_ `string` \| `number` \| `symbol` + +#### Parameters + +##### imageSource + +`string` + +A string representing the image source (e.g., a file path, URI, or Base64-encoded string). + +##### classesOfInterest + +`K`[] = `[]` + +An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. + +##### resizeToInput + +`boolean` = `true` + +Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. 
+ +#### Returns + +`Promise`\<`Record`\<`"ARGMAX"`, `Int32Array`\<`ArrayBufferLike`\>\> & `Record`\<`K`, `Float32Array`\<`ArrayBufferLike`\>\>\> + +A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. + +#### Throws + +If the model is not loaded. + +--- + +### forwardET() + +> `protected` **forwardET**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\> + +Defined in: [modules/BaseModule.ts:80](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L80) + +**`Internal`** + +Runs the model's forward method with the given input tensors. +It returns the output tensors that mimic the structure of output from ExecuTorch. + +#### Parameters + +##### inputTensor + +[`TensorPtr`](../interfaces/TensorPtr.md)[] + +Array of input tensors. + +#### Returns + +`Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\> + +Array of output tensors. + +#### Inherited from + +`BaseModule.forwardET` + +--- + +### getInputShape() + +> **getInputShape**(`methodName`, `index`): `Promise`\<`number`[]\> + +Defined in: [modules/BaseModule.ts:91](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L91) + +Gets the input shape for a given method and index. + +#### Parameters + +##### methodName + +`string` + +method name + +##### index + +`number` + +index of the argument which shape is requested + +#### Returns + +`Promise`\<`number`[]\> + +The input shape as an array of numbers. 
+ +#### Inherited from + +`BaseModule.getInputShape` + +--- + +### load() + +> **load**(): `Promise`\<`void`\> + +Defined in: [modules/computer_vision/ImageSegmentationModule.ts:76](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L76) + +Load the model and prepare it for inference. + +#### Returns + +`Promise`\<`void`\> + +#### Overrides + +`BaseModule.load` + +--- + +### fromCustomConfig() + +> `static` **fromCustomConfig**\<`L`\>(`modelSource`, `config`, `onDownloadProgress`): `Promise`\<`ImageSegmentationModule`\<`L`\>\> + +Defined in: [modules/computer_vision/ImageSegmentationModule.ts:142](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L142) + +Creates a segmentation instance with a user-provided label map and custom config. +Use this when working with a custom-exported segmentation model that is not one of the built-in models. + +#### Type Parameters + +##### L + +`L` _extends_ `Readonly`\<`Record`\<`string`, `string` \| `number`\>\> + +#### Parameters + +##### modelSource + +[`ResourceSource`](../type-aliases/ResourceSource.md) + +A fetchable resource pointing to the model binary. + +##### config + +[`SegmentationConfig`](../type-aliases/SegmentationConfig.md)\<`L`\> + +A [SegmentationConfig](../type-aliases/SegmentationConfig.md) object with the label map and optional preprocessing parameters. + +##### onDownloadProgress + +(`progress`) => `void` + +Optional callback to monitor download progress, receiving a value between 0 and 1. + +#### Returns + +`Promise`\<`ImageSegmentationModule`\<`L`\>\> + +A Promise resolving to an `ImageSegmentationModule` instance typed to the provided label map. 
+ +#### Example + +```ts +const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; +const segmentation = await ImageSegmentationModule.fromCustomConfig( + 'https://example.com/custom_model.pte', + { labelMap: MyLabels } +); +``` + +--- + +### fromModelName() + +> `static` **fromModelName**\<`C`\>(`config`, `onDownloadProgress`): `Promise`\<`ImageSegmentationModule`\<[`ModelNameOf`](../type-aliases/ModelNameOf.md)\<`C`\>\>\> + +Defined in: [modules/computer_vision/ImageSegmentationModule.ts:95](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L95) + +Creates a segmentation instance for a built-in model. +The config object is discriminated by `modelName` — each model can require different fields. + +#### Type Parameters + +##### C + +`C` _extends_ [`ModelSources`](../type-aliases/ModelSources.md) + +#### Parameters + +##### config + +`C` + +A [ModelSources](../type-aliases/ModelSources.md) object specifying which model to load and where to fetch it from. + +##### onDownloadProgress + +(`progress`) => `void` + +Optional callback to monitor download progress, receiving a value between 0 and 1. + +#### Returns + +`Promise`\<`ImageSegmentationModule`\<[`ModelNameOf`](../type-aliases/ModelNameOf.md)\<`C`\>\>\> + +A Promise resolving to an `ImageSegmentationModule` instance typed to the chosen model's label map. 
+ +#### Example + +```ts +const segmentation = await ImageSegmentationModule.fromModelName({ + modelName: 'deeplab-v3', + modelSource: 'https://example.com/deeplab.pte', +}); +``` From ec901f67ea85016525e8dde7e3d30eabfcdc3f88 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Feb 2026 14:24:02 +0100 Subject: [PATCH 17/71] feat: frame extractor for zero-copy approach --- .../common/rnexecutorch/utils/FrameExtractor.cpp | 2 +- .../common/rnexecutorch/utils/FrameExtractor.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp index baae35dc35..c62d1b21c9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp @@ -111,4 +111,4 @@ cv::Mat extractFromNativeBuffer(uint64_t bufferPtr) { #endif } -} // namespace rnexecutorch::utils +} // namespace rnexecutorch::utils \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.h index f5d7c2094d..dda4ff9568 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.h @@ -22,4 +22,4 @@ namespace rnexecutorch::utils { */ cv::Mat extractFromNativeBuffer(uint64_t bufferPtr); -} // namespace rnexecutorch::utils +} // namespace rnexecutorch::utils \ No newline at end of file From ff4bcb358475fabac839789c6a6ee316f3f9d033 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Feb 2026 10:37:11 +0100 Subject: [PATCH 18/71] feat: unify frame extraction and preprocessing --- .../common/rnexecutorch/models/VisionModel.cpp | 2 +- .../common/rnexecutorch/models/VisionModel.h | 2 +- 
.../rnexecutorch/models/classification/Classification.cpp | 2 +- .../common/rnexecutorch/utils/FrameProcessor.cpp | 2 +- .../common/rnexecutorch/utils/FrameProcessor.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index b88310e124..c0ce049f28 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -50,4 +50,4 @@ cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { return image; } -} // namespace rnexecutorch::models +} // namespace rnexecutorch::models \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 4828f26578..875d633a87 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -151,4 +151,4 @@ class VisionModel : public BaseModel { REGISTER_CONSTRUCTOR(models::VisionModel, std::string, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 0fba071087..b9fad1b88b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -73,4 +73,4 @@ Classification::postprocess(const Tensor &tensor) { return probs; } -} // namespace rnexecutorch::models::classification +} // namespace 
rnexecutorch::models::classification \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp index 30238ad5c4..1d03b97ba4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp @@ -25,4 +25,4 @@ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData) { return extractFromNativeBuffer(bufferPtr); } -} // namespace rnexecutorch::utils +} // namespace rnexecutorch::utils \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h index 403f4bde91..6bbb3390df 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h @@ -24,4 +24,4 @@ using namespace facebook; */ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData); -} // namespace rnexecutorch::utils +} // namespace rnexecutorch::utils \ No newline at end of file From ffb8ae0f62266fb7f404f4d79196eafb439b483d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 17 Feb 2026 13:05:14 +0100 Subject: [PATCH 19/71] feat: initial version of vision model API --- .../app/object_detection/index.tsx | 163 +++++++++++++++++- .../host_objects/ModelHostObject.h | 2 +- .../metaprogramming/TypeConcepts.h | 9 +- .../models/embeddings/image/ImageEmbeddings.h | 2 +- .../BaseSemanticSegmentation.h | 2 +- .../models/style_transfer/StyleTransfer.h | 2 +- .../computer_vision/ObjectDetectionModule.ts | 7 - 7 files changed, 167 insertions(+), 20 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 
a5e36c344a..521d537fa2 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -1,16 +1,66 @@ import Spinner from '../../components/Spinner'; -import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; import { Detection, useObjectDetection, RF_DETR_NANO, } from 'react-native-executorch'; -import { View, StyleSheet, Image } from 'react-native'; +import { View, StyleSheet, Image, TouchableOpacity, Text } from 'react-native'; import ImageWithBboxes from '../../components/ImageWithBboxes'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; +import ColorPalette from '../../colors'; +import { Images } from 'react-native-nitro-image'; + +// Helper function to convert image URI to raw pixel data using NitroImage +async function imageUriToPixelData( + uri: string, + targetWidth: number, + targetHeight: number +): Promise<{ + data: ArrayBuffer; + width: number; + height: number; + channels: number; +}> { + try { + // Load image and resize to target dimensions + const image = await Images.loadFromFileAsync(uri); + const resized = image.resize(targetWidth, targetHeight); + + // Get pixel data as ArrayBuffer (RGBA format) + const pixelData = resized.toRawPixelData(); + const buffer = + pixelData instanceof ArrayBuffer ? 
pixelData : pixelData.buffer; + + // Calculate actual buffer dimensions (accounts for device pixel ratio) + const bufferSize = buffer?.byteLength || 0; + const totalPixels = bufferSize / 4; // RGBA = 4 bytes per pixel + const aspectRatio = targetWidth / targetHeight; + const actualHeight = Math.sqrt(totalPixels / aspectRatio); + const actualWidth = totalPixels / actualHeight; + + console.log('Requested:', targetWidth, 'x', targetHeight); + console.log('Buffer size:', bufferSize); + console.log( + 'Actual dimensions:', + Math.round(actualWidth), + 'x', + Math.round(actualHeight) + ); + + return { + data: buffer, + width: Math.round(actualWidth), + height: Math.round(actualHeight), + channels: 4, // RGBA + }; + } catch (error) { + console.error('Error loading image with NitroImage:', error); + throw error; + } +} export default function ObjectDetectionScreen() { const [imageUri, setImageUri] = useState(''); @@ -45,7 +95,36 @@ export default function ObjectDetectionScreen() { const output = await rfDetr.forward(imageUri); setResults(output); } catch (e) { - console.error(e); + console.error('Error in runForward:', e); + } + } + }; + + const runForwardPixels = async () => { + if (imageUri && imageDimensions) { + try { + console.log('Converting image to pixel data...'); + // Resize to 640x640 to avoid memory issues + const intermediateSize = 640; + const pixelData = await imageUriToPixelData( + imageUri, + intermediateSize, + intermediateSize + ); + + console.log('Running forward with pixel data...', { + width: pixelData.width, + height: pixelData.height, + channels: pixelData.channels, + dataSize: pixelData.data.byteLength, + }); + + // Run inference using unified forward() API + const output = await ssdLite.forward(pixelData, 0.5); + console.log('Pixel data result:', output.length, 'detections'); + setResults(output); + } catch (e) { + console.error('Error in runForwardPixels:', e); } } }; @@ -124,10 +203,41 @@ export default function ObjectDetectionScreen() { )} - + 
+ {/* Custom bottom bar with two buttons */} + + + handleCameraPress(false)}> + 📷 Gallery + + + + + + Run (String) + + + + Run (Pixels) + + + ); } @@ -172,4 +282,43 @@ const styles = StyleSheet.create({ width: '100%', height: '100%', }, + bottomContainer: { + width: '100%', + gap: 15, + alignItems: 'center', + padding: 16, + flex: 1, + }, + bottomIconsContainer: { + flexDirection: 'row', + justifyContent: 'center', + width: '100%', + }, + iconText: { + fontSize: 16, + color: ColorPalette.primary, + }, + buttonsRow: { + flexDirection: 'row', + width: '100%', + gap: 10, + }, + button: { + height: 50, + justifyContent: 'center', + alignItems: 'center', + backgroundColor: ColorPalette.primary, + color: '#fff', + borderRadius: 8, + }, + halfButton: { + flex: 1, + }, + buttonDisabled: { + opacity: 0.5, + }, + buttonText: { + color: '#fff', + fontSize: 16, + }, }); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 4a92a415c1..7a432a50f6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -483,4 +483,4 @@ template class ModelHostObject : public JsiHostObject { std::shared_ptr callInvoker; }; -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index 216e2bae39..97f88b6725 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -12,8 +12,13 @@ template concept SameAs = std::is_same_v; template -concept HasGenerate = requires(T t) { - { &T::generate }; 
+concept HasGenerateFromString = requires(T t) { + { &T::generateFromString }; +}; + +template +concept HasGenerateFromPixels = requires(T t) { + { &T::generateFromPixels }; }; template diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 7e114e939d..9a1d6429bd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -27,4 +27,4 @@ class ImageEmbeddings final : public BaseEmbeddings { REGISTER_CONSTRUCTOR(models::embeddings::ImageEmbeddings, std::string, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index d39a7e5d4a..8ba422afba 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -59,4 +59,4 @@ REGISTER_CONSTRUCTOR(models::semantic_segmentation::BaseSemanticSegmentation, std::string, std::vector, std::vector, std::vector, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 73744c4d82..8eed3c888d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -33,4 +33,4 @@ class StyleTransfer : public BaseModel { REGISTER_CONSTRUCTOR(models::style_transfer::StyleTransfer, std::string, std::shared_ptr); -} // namespace rnexecutorch +} // namespace rnexecutorch \ No newline at end of file diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index bbb990f7b8..c24bbd1369 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -169,11 +169,4 @@ export class ObjectDetectionModule< nativeModule ); } - - async forward( - input: string | PixelData, - detectionThreshold: number = 0.5 - ): Promise { - return super.forward(input, detectionThreshold); - } } From 9936b89cdbd84932d536c7341ed4dde1d0b5be0f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 17 Feb 2026 17:51:10 +0100 Subject: [PATCH 20/71] refactor: errors, logs, unnecessary comments, use existing TensorPtr --- .../app/object_detection/index.tsx | 61 ++++++++++--------- .../host_objects/JsiConversions.h | 19 ++++++ 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 521d537fa2..5dad96e0bd 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -13,7 +13,26 @@ import ScreenWrapper from '../../ScreenWrapper'; import ColorPalette from '../../colors'; import { Images } from 'react-native-nitro-image'; -// Helper function to convert image URI to raw pixel data using NitroImage +// Helper function to convert BGRA to RGB +function convertBGRAtoRGB( + buffer: ArrayBuffer, + width: number, + height: number +): 
ArrayBuffer { + const source = new Uint8Array(buffer); + const rgb = new Uint8Array(width * height * 3); + + for (let i = 0; i < width * height; i++) { + // BGRA format: [B, G, R, A] → RGB: [R, G, B] + rgb[i * 3 + 0] = source[i * 4 + 2]; // R + rgb[i * 3 + 1] = source[i * 4 + 1]; // G + rgb[i * 3 + 2] = source[i * 4 + 0]; // B + } + + return rgb.buffer; +} + +// Helper function to convert image URI to raw RGB pixel data async function imageUriToPixelData( uri: string, targetWidth: number, @@ -29,32 +48,19 @@ async function imageUriToPixelData( const image = await Images.loadFromFileAsync(uri); const resized = image.resize(targetWidth, targetHeight); - // Get pixel data as ArrayBuffer (RGBA format) - const pixelData = resized.toRawPixelData(); + // Get pixel data as ArrayBuffer (BGRA format from NitroImage) + const rawPixelData = resized.toRawPixelData(); const buffer = - pixelData instanceof ArrayBuffer ? pixelData : pixelData.buffer; - - // Calculate actual buffer dimensions (accounts for device pixel ratio) - const bufferSize = buffer?.byteLength || 0; - const totalPixels = bufferSize / 4; // RGBA = 4 bytes per pixel - const aspectRatio = targetWidth / targetHeight; - const actualHeight = Math.sqrt(totalPixels / aspectRatio); - const actualWidth = totalPixels / actualHeight; + rawPixelData instanceof ArrayBuffer ? 
rawPixelData : rawPixelData.buffer; - console.log('Requested:', targetWidth, 'x', targetHeight); - console.log('Buffer size:', bufferSize); - console.log( - 'Actual dimensions:', - Math.round(actualWidth), - 'x', - Math.round(actualHeight) - ); + // Convert BGRA to RGB as required by the native API + const rgbBuffer = convertBGRAtoRGB(buffer, targetWidth, targetHeight); return { - data: buffer, - width: Math.round(actualWidth), - height: Math.round(actualHeight), - channels: 4, // RGBA + data: rgbBuffer, + width: targetWidth, + height: targetHeight, + channels: 3, // RGB }; } catch (error) { console.error('Error loading image with NitroImage:', error); @@ -104,12 +110,11 @@ export default function ObjectDetectionScreen() { if (imageUri && imageDimensions) { try { console.log('Converting image to pixel data...'); - // Resize to 640x640 to avoid memory issues - const intermediateSize = 640; + // Use original dimensions - let the model resize internally const pixelData = await imageUriToPixelData( imageUri, - intermediateSize, - intermediateSize + imageDimensions.width, + imageDimensions.height ); console.log('Running forward with pixel data...', { @@ -120,7 +125,7 @@ export default function ObjectDetectionScreen() { }); // Run inference using unified forward() API - const output = await ssdLite.forward(pixelData, 0.5); + const output = await ssdLite.forward(pixelData, 0.3); console.log('Pixel data result:', output.length, 'detections'); setResults(output); } catch (e) { diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 96e3168ee7..8936711477 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -368,6 +368,25 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, 
bigInt}; } +inline jsi::Value getJsiValue(const std::vector &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + // JS numbers are doubles. Large uint64s > 2^53 will lose precision. + array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); + } + return {runtime, array}; +} + +inline jsi::Value getJsiValue(const std::vector &vec, + jsi::Runtime &runtime) { + jsi::Array array(runtime, vec.size()); + for (size_t i = 0; i < vec.size(); i++) { + array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); + } + return {runtime, array}; +} + inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { return {runtime, val}; } From d0be5fc036fa1e16d848c1e2223d4fd0f4c9c1fc Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 18 Feb 2026 13:03:22 +0100 Subject: [PATCH 21/71] refactor: add or remove empty lines --- .../rnexecutorch/models/classification/Classification.cpp | 2 +- .../rnexecutorch/models/embeddings/image/ImageEmbeddings.h | 2 +- .../models/semantic_segmentation/BaseSemanticSegmentation.h | 2 +- .../common/rnexecutorch/models/style_transfer/StyleTransfer.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index b9fad1b88b..0fba071087 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -73,4 +73,4 @@ Classification::postprocess(const Tensor &tensor) { return probs; } -} // namespace rnexecutorch::models::classification \ No newline at end of file +} // namespace rnexecutorch::models::classification diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h 
b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 9a1d6429bd..7e114e939d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -27,4 +27,4 @@ class ImageEmbeddings final : public BaseEmbeddings { REGISTER_CONSTRUCTOR(models::embeddings::ImageEmbeddings, std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index 8ba422afba..d39a7e5d4a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -59,4 +59,4 @@ REGISTER_CONSTRUCTOR(models::semantic_segmentation::BaseSemanticSegmentation, std::string, std::vector, std::vector, std::vector, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 8eed3c888d..73744c4d82 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -33,4 +33,4 @@ class StyleTransfer : public BaseModel { REGISTER_CONSTRUCTOR(models::style_transfer::StyleTransfer, std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch 
From ccd8ff09db3c49e713241fcda9dc339c9e3ae5cb Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 19 Feb 2026 22:34:20 +0100 Subject: [PATCH 22/71] fix: errors after rebase --- .../common/rnexecutorch/host_objects/JsiConversions.h | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 8936711477..77f1c51adb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -368,16 +368,6 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, bigInt}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - // JS numbers are doubles. Large uint64s > 2^53 will lose precision. 
- array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); - } - return {runtime, array}; -} - inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); From 1699ae57ed80cc66f809a59f326b188af43a5407 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 24 Feb 2026 09:02:14 +0100 Subject: [PATCH 23/71] feat: suggested changes / improve comments --- .../common/rnexecutorch/host_objects/JsiConversions.h | 9 --------- .../common/rnexecutorch/metaprogramming/TypeConcepts.h | 5 +++++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 77f1c51adb..96e3168ee7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -368,15 +368,6 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, bigInt}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); - } - return {runtime, array}; -} - inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { return {runtime, val}; } diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index 97f88b6725..5cf0c79e14 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -11,6 +11,11 @@ concept DerivedFromOrSameAs = std::is_base_of_v; template concept SameAs = std::is_same_v; 
+template +concept HasGenerate = requires(T t) { + { &T::generate }; +}; + template concept HasGenerateFromString = requires(T t) { { &T::generateFromString }; From f556ac01a6b6be8c8b57b113dc966ba1d1ff8446 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 25 Feb 2026 13:40:50 +0100 Subject: [PATCH 24/71] feat: make all cv models compatible with Vision Camera --- apps/computer-vision/app/_layout.tsx | 40 + .../app/classification_live/index.tsx | 255 ++++++ .../app/image_segmentation_live/index.tsx | 292 +++++++ .../app/object_detection_live/index.tsx | 98 ++- apps/computer-vision/app/ocr_live/index.tsx | 329 ++++++++ .../app/style_transfer/index.tsx | 73 +- .../app/style_transfer_live/index.tsx | 274 ++++++ .../app/vision_camera_live/index.tsx | 798 ++++++++++++++++++ .../host_objects/JsiConversions.h | 50 ++ .../rnexecutorch/models/VisionModel.cpp | 11 +- .../common/rnexecutorch/models/VisionModel.h | 14 + .../models/classification/Classification.cpp | 66 +- .../models/classification/Classification.h | 21 +- .../embeddings/image/ImageEmbeddings.cpp | 72 +- .../models/embeddings/image/ImageEmbeddings.h | 22 +- .../object_detection/ObjectDetection.cpp | 5 +- .../common/rnexecutorch/models/ocr/OCR.cpp | 66 +- .../common/rnexecutorch/models/ocr/OCR.h | 11 +- .../BaseSemanticSegmentation.cpp | 12 +- .../BaseSemanticSegmentation.h | 22 +- .../models/semantic_segmentation/Types.h | 17 + .../models/style_transfer/StyleTransfer.cpp | 94 ++- .../models/style_transfer/StyleTransfer.h | 27 +- .../models/style_transfer/Types.h | 14 + .../models/vertical_ocr/VerticalOCR.cpp | 70 +- .../models/vertical_ocr/VerticalOCR.h | 11 +- .../tests/integration/ClassificationTest.cpp | 16 +- .../tests/integration/ImageEmbeddingsTest.cpp | 16 +- .../tests/integration/OCRTest.cpp | 16 +- .../tests/integration/StyleTransferTest.cpp | 43 +- .../tests/integration/VerticalOCRTest.cpp | 41 +- .../src/controllers/BaseOCRController.ts | 57 +- .../computer_vision/useImageSegmentation.ts | 
131 +++ .../src/hooks/computer_vision/useOCR.ts | 21 +- .../hooks/computer_vision/useVerticalOCR.ts | 22 +- .../src/hooks/useModule.ts | 2 + .../computer_vision/ClassificationModule.ts | 25 +- .../computer_vision/ImageEmbeddingsModule.ts | 23 +- .../SemanticSegmentationModule.ts | 93 +- .../computer_vision/StyleTransferModule.ts | 22 +- .../src/types/classification.ts | 45 +- .../src/types/imageEmbeddings.ts | 27 +- .../react-native-executorch/src/types/ocr.ts | 32 +- .../src/types/semanticSegmentation.ts | 35 +- .../src/types/styleTransfer.ts | 29 +- 45 files changed, 3225 insertions(+), 235 deletions(-) create mode 100644 apps/computer-vision/app/classification_live/index.tsx create mode 100644 apps/computer-vision/app/image_segmentation_live/index.tsx create mode 100644 apps/computer-vision/app/ocr_live/index.tsx create mode 100644 apps/computer-vision/app/style_transfer_live/index.tsx create mode 100644 apps/computer-vision/app/vision_camera_live/index.tsx create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h create mode 100644 packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 4ce2f3e5c2..3c7fa38ba2 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -91,6 +91,46 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} /> + + + + + { + setGlobalGenerating(isGenerating); + }, [isGenerating, setGlobalGenerating]); + + const [topLabel, setTopLabel] = useState(''); + const [topScore, setTopScore] = useState(0); + const [fps, setFps] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = 
devices.find((d) => d.position === 'back') ?? devices[0]; + + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + const updateStats = useCallback( + (result: { label: string; score: number }) => { + setTopLabel(result.label); + setTopScore(result.score); + const now = Date.now(); + const timeDiff = now - lastFrameTimeRef.current; + if (timeDiff > 0) { + setFps(Math.round(1000 / timeDiff)); + } + lastFrameTimeRef.current = now; + }, + [] + ); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + onFrame(frame) { + 'worklet'; + if (!runOnFrame) { + frame.dispose(); + return; + } + try { + const result = runOnFrame(frame); + if (result) { + // find the top-1 entry + let bestLabel = ''; + let bestScore = -1; + const entries = Object.entries(result); + for (let i = 0; i < entries.length; i++) { + const [label, score] = entries[i]; + if ((score as number) > bestScore) { + bestScore = score as number; + bestLabel = label; + } + } + scheduleOnRN(updateStats, { label: bestLabel, score: bestScore }); + } + } catch { + // ignore frame errors + } finally { + frame.dispose(); + } + }, + }); + + if (!isReady) { + return ( + + ); + } + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + return ( + + + + + + + + + + {topLabel || '—'} + + + {topLabel ? 
(topScore * 100).toFixed(1) + '%' : ''} + + + + + {fps} + fps + + + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: 'black', + }, + centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { + color: 'white', + fontSize: 18, + }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { + color: 'white', + fontSize: 15, + fontWeight: '600', + letterSpacing: 0.3, + }, + bottomBarWrapper: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + paddingHorizontal: 16, + }, + bottomBar: { + flexDirection: 'row', + alignItems: 'center', + backgroundColor: 'rgba(0, 0, 0, 0.55)', + borderRadius: 24, + paddingHorizontal: 28, + paddingVertical: 10, + gap: 24, + maxWidth: '100%', + }, + labelContainer: { + flex: 1, + alignItems: 'flex-start', + }, + labelText: { + color: 'white', + fontSize: 16, + fontWeight: '700', + }, + scoreText: { + color: 'rgba(255,255,255,0.7)', + fontSize: 13, + fontWeight: '500', + }, + statItem: { + alignItems: 'center', + }, + statValue: { + color: 'white', + fontSize: 22, + fontWeight: '700', + letterSpacing: -0.5, + }, + statLabel: { + color: 'rgba(255,255,255,0.55)', + fontSize: 11, + fontWeight: '500', + textTransform: 'uppercase', + letterSpacing: 0.8, + }, + statDivider: { + width: 1, + height: 32, + backgroundColor: 'rgba(255,255,255,0.2)', + }, +}); diff --git a/apps/computer-vision/app/image_segmentation_live/index.tsx b/apps/computer-vision/app/image_segmentation_live/index.tsx new file mode 100644 index 0000000000..f665c63c59 --- /dev/null +++ b/apps/computer-vision/app/image_segmentation_live/index.tsx @@ -0,0 +1,292 @@ +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from 'react'; +import { + StatusBar, + StyleSheet, + Text, + TouchableOpacity, + 
useWindowDimensions, + View, +} from 'react-native'; +import { useSafeAreaInsets } from 'react-native-safe-area-context'; + +import { + Camera, + getCameraFormat, + Templates, + useCameraDevices, + useCameraPermission, + useFrameOutput, +} from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + DEEPLAB_V3_RESNET50, + useImageSegmentation, +} from 'react-native-executorch'; +import { + Canvas, + Image as SkiaImage, + Skia, + AlphaType, + ColorType, + SkImage, +} from '@shopify/react-native-skia'; +import { GeneratingContext } from '../../context'; +import Spinner from '../../components/Spinner'; +import ColorPalette from '../../colors'; + +// RGBA colors for each DeepLab V3 class (alpha = 180 for semi-transparency) +const CLASS_COLORS: number[][] = [ + [0, 0, 0, 0], // 0 background — transparent + [51, 255, 87, 180], // 1 aeroplane + [51, 87, 255, 180], // 2 bicycle + [255, 51, 246, 180], // 3 bird + [51, 255, 246, 180], // 4 boat + [243, 255, 51, 180], // 5 bottle + [141, 51, 255, 180], // 6 bus + [255, 131, 51, 180], // 7 car + [51, 255, 131, 180], // 8 cat + [131, 51, 255, 180], // 9 chair + [255, 255, 51, 180], // 10 cow + [51, 255, 255, 180], // 11 diningtable + [255, 51, 143, 180], // 12 dog + [127, 51, 255, 180], // 13 horse + [51, 255, 175, 180], // 14 motorbike + [255, 175, 51, 180], // 15 person + [179, 255, 51, 180], // 16 pottedplant + [255, 87, 51, 180], // 17 sheep + [255, 51, 162, 180], // 18 sofa + [51, 162, 255, 180], // 19 train + [162, 51, 255, 180], // 20 tvmonitor +]; + +export default function ImageSegmentationLiveScreen() { + const insets = useSafeAreaInsets(); + const { width: screenWidth, height: screenHeight } = useWindowDimensions(); + + const { isReady, isGenerating, downloadProgress, runOnFrame } = + useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); + const { setGlobalGenerating } = useContext(GeneratingContext); + + useEffect(() => { + setGlobalGenerating(isGenerating); + }, 
[isGenerating, setGlobalGenerating]); + + const [maskImage, setMaskImage] = useState(null); + const [fps, setFps] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back') ?? devices[0]; + + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + const updateMask = useCallback((img: SkImage) => { + setMaskImage(img); + const now = Date.now(); + const timeDiff = now - lastFrameTimeRef.current; + if (timeDiff > 0) { + setFps(Math.round(1000 / timeDiff)); + } + lastFrameTimeRef.current = now; + }, []); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame(frame) { + 'worklet'; + if (!runOnFrame) { + frame.dispose(); + return; + } + try { + const result = runOnFrame(frame, [], false); + if (result?.ARGMAX) { + const argmax: Int32Array = result.ARGMAX; + // Model output is always square (modelImageSize × modelImageSize). + // Derive width/height from argmax length (sqrt for square output). + const side = Math.round(Math.sqrt(argmax.length)); + const width = side; + const height = side; + + // Build RGBA pixel buffer on the worklet thread to avoid transferring + // the large Int32Array across the worklet→RN boundary via scheduleOnRN. + const pixels = new Uint8Array(width * height * 4); + for (let i = 0; i < argmax.length; i++) { + const color = CLASS_COLORS[argmax[i]] ?? 
[0, 0, 0, 0]; + pixels[i * 4] = color[0]!; + pixels[i * 4 + 1] = color[1]!; + pixels[i * 4 + 2] = color[2]!; + pixels[i * 4 + 3] = color[3]!; + } + + const skData = Skia.Data.fromBytes(pixels); + const img = Skia.Image.MakeImage( + { + width, + height, + alphaType: AlphaType.Unpremul, + colorType: ColorType.RGBA_8888, + }, + skData, + width * 4 + ); + if (img) { + scheduleOnRN(updateMask, img); + } + } + } catch (e) { + console.log('frame error:', String(e)); + } finally { + frame.dispose(); + } + }, + }); + + if (!isReady) { + return ( + + ); + } + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + return ( + + + + + + {maskImage && ( + + + + )} + + + + + {fps} + fps + + + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: 'black', + }, + centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { + color: 'white', + fontSize: 18, + }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { + color: 'white', + fontSize: 15, + fontWeight: '600', + letterSpacing: 0.3, + }, + bottomBarWrapper: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + }, + bottomBar: { + flexDirection: 'row', + alignItems: 'center', + backgroundColor: 'rgba(0, 0, 0, 0.55)', + borderRadius: 24, + paddingHorizontal: 28, + paddingVertical: 10, + gap: 24, + }, + statItem: { + alignItems: 'center', + }, + statValue: { + color: 'white', + fontSize: 22, + fontWeight: '700', + letterSpacing: -0.5, + }, + statLabel: { + color: 'rgba(255,255,255,0.55)', + fontSize: 11, + fontWeight: '500', + textTransform: 'uppercase', + letterSpacing: 0.8, + }, +}); diff --git 
a/apps/computer-vision/app/object_detection_live/index.tsx b/apps/computer-vision/app/object_detection_live/index.tsx index 3db2c53602..b4210b0541 100644 --- a/apps/computer-vision/app/object_detection_live/index.tsx +++ b/apps/computer-vision/app/object_detection_live/index.tsx @@ -35,6 +35,7 @@ import ColorPalette from '../../colors'; export default function ObjectDetectionLiveScreen() { const insets = useSafeAreaInsets(); + const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); const model = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE }); const { setGlobalGenerating } = useContext(GeneratingContext); @@ -59,15 +60,23 @@ export default function ObjectDetectionLiveScreen() { } }, [device]); - const updateStats = useCallback((results: Detection[]) => { - setDetectionCount(results.length); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, []); + const updateDetections = useCallback( + (payload: { + results: Detection[]; + imageWidth: number; + imageHeight: number; + }) => { + setDetections(payload.results); + setImageSize({ width: payload.imageWidth, height: payload.imageHeight }); + const now = Date.now(); + const timeDiff = now - lastFrameTimeRef.current; + if (timeDiff > 0) { + setFps(Math.round(1000 / timeDiff)); + } + lastFrameTimeRef.current = now; + }, + [] + ); const frameOutput = useFrameOutput({ pixelFormat: 'rgb', @@ -78,10 +87,19 @@ export default function ObjectDetectionLiveScreen() { frame.dispose(); return; } + // After 90° CW rotation, the image fed to the model has swapped dims. + const imageWidth = + frame.width > frame.height ? frame.height : frame.width; + const imageHeight = + frame.width > frame.height ? 
frame.width : frame.height; try { const result = model.runOnFrame(frame, 0.5); if (result) { - scheduleOnRN(updateStats, result); + scheduleOnRN(updateDetections, { + results: result, + imageWidth, + imageHeight, + }); } } catch { // ignore frame errors @@ -134,13 +152,51 @@ export default function ObjectDetectionLiveScreen() { format={format} /> + {/* Bounding box overlay — measured to match the exact camera preview area */} + + setCanvasSize({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + > + {(() => { + // Cover-fit: camera preview scales to fill the canvas, cropping the + // excess. Compute the same transform so bbox pixel coords map correctly. + const scale = Math.max( + canvasSize.width / imageSize.width, + canvasSize.height / imageSize.height + ); + const offsetX = (canvasSize.width - imageSize.width * scale) / 2; + const offsetY = (canvasSize.height - imageSize.height * scale) / 2; + return detections.map((det, i) => { + const left = det.bbox.x1 * scale + offsetX; + const top = det.bbox.y1 * scale + offsetY; + const width = (det.bbox.x2 - det.bbox.x1) * scale; + const height = (det.bbox.y2 - det.bbox.y1) * scale; + return ( + + + + {det.label} {(det.score * 100).toFixed(0)}% + + + + ); + }); + })()} + + - {detectionCount} + {detections.length} objects @@ -182,6 +238,26 @@ const styles = StyleSheet.create({ fontWeight: '600', letterSpacing: 0.3, }, + bbox: { + position: 'absolute', + borderWidth: 2, + borderColor: ColorPalette.primary, + borderRadius: 4, + }, + bboxLabel: { + position: 'absolute', + top: -22, + left: -2, + backgroundColor: ColorPalette.primary, + paddingHorizontal: 6, + paddingVertical: 2, + borderRadius: 4, + }, + bboxLabelText: { + color: 'white', + fontSize: 11, + fontWeight: '600', + }, bottomBarWrapper: { position: 'absolute', bottom: 0, diff --git a/apps/computer-vision/app/ocr_live/index.tsx b/apps/computer-vision/app/ocr_live/index.tsx new file mode 100644 index 0000000000..a0c93899f6 --- 
/dev/null +++ b/apps/computer-vision/app/ocr_live/index.tsx @@ -0,0 +1,329 @@ +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from 'react'; +import { + StatusBar, + StyleSheet, + Text, + TouchableOpacity, + View, +} from 'react-native'; +import { useSafeAreaInsets } from 'react-native-safe-area-context'; + +import { + Camera, + getCameraFormat, + Templates, + useCameraDevices, + useCameraPermission, + useFrameOutput, +} from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { OCR_ENGLISH, useOCR, OCRDetection } from 'react-native-executorch'; +import { + Canvas, + Path, + Skia, + Text as SkiaText, + matchFont, +} from '@shopify/react-native-skia'; +import { GeneratingContext } from '../../context'; +import Spinner from '../../components/Spinner'; +import ColorPalette from '../../colors'; + +interface FrameDetections { + detections: OCRDetection[]; + frameWidth: number; + frameHeight: number; +} + +export default function OCRLiveScreen() { + const insets = useSafeAreaInsets(); + const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); + + const { isReady, isGenerating, downloadProgress, runOnFrame } = useOCR({ + model: OCR_ENGLISH, + }); + const { setGlobalGenerating } = useContext(GeneratingContext); + + useEffect(() => { + setGlobalGenerating(isGenerating); + }, [isGenerating, setGlobalGenerating]); + + const [frameDetections, setFrameDetections] = useState({ + detections: [], + frameWidth: 1, + frameHeight: 1, + }); + const [fps, setFps] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + + const font = matchFont({ fontFamily: 'Helvetica', fontSize: 11 }); + + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back') ?? 
devices[0]; + + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + const updateDetections = useCallback((result: FrameDetections) => { + setFrameDetections(result); + const now = Date.now(); + const timeDiff = now - lastFrameTimeRef.current; + if (timeDiff > 0) { + setFps(Math.round(1000 / timeDiff)); + } + lastFrameTimeRef.current = now; + }, []); + + const frameOutput = useFrameOutput({ + dropFramesWhileBusy: true, + pixelFormat: 'rgb', + onFrame(frame) { + 'worklet'; + if (!runOnFrame) { + frame.dispose(); + return; + } + const frameWidth = frame.width; + const frameHeight = frame.height; + try { + const result = runOnFrame(frame); + if (result) { + scheduleOnRN(updateDetections, { + detections: result, + frameWidth, + frameHeight, + }); + } + } catch { + // ignore frame errors + } finally { + frame.dispose(); + } + }, + }); + + if (!isReady) { + return ( + + ); + } + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + const { detections, frameWidth, frameHeight } = frameDetections; + + // OCR runs on the raw landscape frame (no rotation applied in native). + // The camera preview displays it as portrait (90° CW rotation applied by iOS). + // After rotation the image dimensions become (frameHeight × frameWidth). + // Cover-fit scale uses post-rotation dims to match what the preview shows. + const isLandscape = frameWidth > frameHeight; + const imageW = isLandscape ? frameHeight : frameWidth; + const imageH = isLandscape ? 
frameWidth : frameHeight; + const scale = Math.max(canvasSize.width / imageW, canvasSize.height / imageH); + const offsetX = (canvasSize.width - imageW * scale) / 2; + const offsetY = (canvasSize.height - imageH * scale) / 2; + + // Map a raw landscape point to screen coords accounting for rotation + cover-fit. + function toScreenX(px: number, py: number) { + // After 90° CW: rotated_x = frameHeight - py, rotated_y = px + const rx = isLandscape ? frameHeight - py : px; + return rx * scale + offsetX; + } + function toScreenY(px: number, py: number) { + const ry = isLandscape ? px : py; + return ry * scale + offsetY; + } + + return ( + + + + + + {/* Measure the overlay area, then draw polygons inside a Canvas */} + + setCanvasSize({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + > + + {detections.map((det, i) => { + if (!det.bbox || det.bbox.length < 2) return null; + + const path = Skia.Path.Make(); + path.moveTo( + toScreenX(det.bbox[0]!.x, det.bbox[0]!.y), + toScreenY(det.bbox[0]!.x, det.bbox[0]!.y) + ); + for (let j = 1; j < det.bbox.length; j++) { + path.lineTo( + toScreenX(det.bbox[j]!.x, det.bbox[j]!.y), + toScreenY(det.bbox[j]!.x, det.bbox[j]!.y) + ); + } + path.close(); + + const labelX = toScreenX(det.bbox[0]!.x, det.bbox[0]!.y); + const labelY = Math.max( + 0, + toScreenY(det.bbox[0]!.x, det.bbox[0]!.y) - 4 + ); + + return ( + + + + {font && ( + + )} + + ); + })} + + + + + + + {detections.length} + regions + + + + {fps} + fps + + + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: 'black', + }, + centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { + color: 'white', + fontSize: 18, + }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { + color: 'white', + fontSize: 15, + fontWeight: '600', + letterSpacing: 
0.3, + }, + bottomBarWrapper: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + }, + bottomBar: { + flexDirection: 'row', + alignItems: 'center', + backgroundColor: 'rgba(0, 0, 0, 0.55)', + borderRadius: 24, + paddingHorizontal: 28, + paddingVertical: 10, + gap: 24, + }, + statItem: { + alignItems: 'center', + }, + statValue: { + color: 'white', + fontSize: 22, + fontWeight: '700', + letterSpacing: -0.5, + }, + statLabel: { + color: 'rgba(255,255,255,0.55)', + fontSize: 11, + fontWeight: '500', + textTransform: 'uppercase', + letterSpacing: 0.8, + }, + statDivider: { + width: 1, + height: 32, + backgroundColor: 'rgba(255,255,255,0.2)', + }, +}); diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index dc6a0d4963..90801cb053 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -5,6 +5,14 @@ import { useStyleTransfer, STYLE_TRANSFER_CANDY_QUANTIZED, } from 'react-native-executorch'; +import { + Canvas, + Image as SkiaImage, + Skia, + AlphaType, + ColorType, + SkImage, +} from '@shopify/react-native-skia'; import { View, StyleSheet, Image } from 'react-native'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; @@ -16,12 +24,16 @@ export default function StyleTransferScreen() { useEffect(() => { setGlobalGenerating(model.isGenerating); }, [model.isGenerating, setGlobalGenerating]); + const [imageUri, setImageUri] = useState(''); + const [styledImage, setStyledImage] = useState(null); + const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); const uri = image?.uri; if (typeof uri === 'string') { - setImageUri(uri as string); + setImageUri(uri); + setStyledImage(null); } }; @@ -29,7 +41,29 @@ export default function StyleTransferScreen() { if (imageUri) { try { const output = await 
model.forward(imageUri); - setImageUri(output); + const height = output.sizes[0]; + const width = output.sizes[1]; + // Convert RGB -> RGBA for Skia + const rgba = new Uint8Array(width * height * 4); + const rgb = output.dataPtr; + for (let i = 0; i < width * height; i++) { + rgba[i * 4] = rgb[i * 3]; + rgba[i * 4 + 1] = rgb[i * 3 + 1]; + rgba[i * 4 + 2] = rgb[i * 3 + 2]; + rgba[i * 4 + 3] = 255; + } + const skData = Skia.Data.fromBytes(rgba); + const img = Skia.Image.MakeImage( + { + width, + height, + alphaType: AlphaType.Opaque, + colorType: ColorType.RGBA_8888, + }, + skData, + width * 4 + ); + setStyledImage(img); } catch (e) { console.error(e); } @@ -48,15 +82,28 @@ export default function StyleTransferScreen() { return ( - + {styledImage ? ( + + + + ) : ( + + )} { + setGlobalGenerating(isGenerating); + }, [isGenerating, setGlobalGenerating]); + + const [styledImage, setStyledImage] = useState(null); + const [fps, setFps] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back') ?? 
devices[0]; + + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + const updateImage = useCallback((img: SkImage) => { + setStyledImage((prev) => { + prev?.dispose(); + return img; + }); + const now = Date.now(); + const timeDiff = now - lastFrameTimeRef.current; + if (timeDiff > 0) { + setFps(Math.round(1000 / timeDiff)); + } + lastFrameTimeRef.current = now; + }, []); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame(frame) { + 'worklet'; + if (!runOnFrame) { + frame.dispose(); + return; + } + try { + const result = runOnFrame(frame); + if (result?.dataPtr) { + const { dataPtr, sizes } = result; + const height = sizes[0]; + const width = sizes[1]; + // Build Skia image on the worklet thread — avoids transferring the + // large pixel buffer across the worklet→RN boundary via scheduleOnRN. + const skData = Skia.Data.fromBytes(dataPtr); + const img = Skia.Image.MakeImage( + { + width, + height, + alphaType: AlphaType.Opaque, + colorType: ColorType.RGBA_8888, + }, + skData, + width * 4 + ); + if (img) { + scheduleOnRN(updateImage, img); + } + } + } catch (e) { + console.log('frame error:', String(e)); + } finally { + frame.dispose(); + } + }, + }); + + if (!isReady) { + return ( + + ); + } + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + return ( + + + + {/* Camera always runs to keep frame processing active */} + + + {/* Styled output overlays the camera feed once available */} + {styledImage && ( + + + + )} + + + + + {fps} + fps + + + + candy + style + + + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: 'black', + }, + 
centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { + color: 'white', + fontSize: 18, + }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { + color: 'white', + fontSize: 15, + fontWeight: '600', + letterSpacing: 0.3, + }, + bottomBarWrapper: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + }, + bottomBar: { + flexDirection: 'row', + alignItems: 'center', + backgroundColor: 'rgba(0, 0, 0, 0.55)', + borderRadius: 24, + paddingHorizontal: 28, + paddingVertical: 10, + gap: 24, + }, + statItem: { + alignItems: 'center', + }, + statValue: { + color: 'white', + fontSize: 22, + fontWeight: '700', + letterSpacing: -0.5, + }, + styleLabel: { + color: 'white', + fontSize: 16, + fontWeight: '700', + }, + statLabel: { + color: 'rgba(255,255,255,0.55)', + fontSize: 11, + fontWeight: '500', + textTransform: 'uppercase', + letterSpacing: 0.8, + }, + statDivider: { + width: 1, + height: 32, + backgroundColor: 'rgba(255,255,255,0.2)', + }, +}); diff --git a/apps/computer-vision/app/vision_camera_live/index.tsx b/apps/computer-vision/app/vision_camera_live/index.tsx new file mode 100644 index 0000000000..4c7b425b18 --- /dev/null +++ b/apps/computer-vision/app/vision_camera_live/index.tsx @@ -0,0 +1,798 @@ +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from 'react'; +import { + ScrollView, + StatusBar, + StyleSheet, + Text, + TouchableOpacity, + View, +} from 'react-native'; +import { useSafeAreaInsets } from 'react-native-safe-area-context'; +import { + Camera, + Frame, + getCameraFormat, + Templates, + useCameraDevices, + useCameraPermission, + useFrameOutput, +} from 'react-native-vision-camera'; +import { createSynchronizable, runOnJS } from 'react-native-worklets'; +import { + DEEPLAB_V3_RESNET50, + Detection, + 
EFFICIENTNET_V2_S, + OCRDetection, + OCR_ENGLISH, + SSDLITE_320_MOBILENET_V3_LARGE, + STYLE_TRANSFER_RAIN_PRINCESS, + useClassification, + useImageSegmentation, + useObjectDetection, + useOCR, + useStyleTransfer, +} from 'react-native-executorch'; +import { + AlphaType, + Canvas, + ColorType, + Image as SkiaImage, + matchFont, + Path, + Skia, + SkImage, + Text as SkiaText, +} from '@shopify/react-native-skia'; +import { GeneratingContext } from '../../context'; +import Spinner from '../../components/Spinner'; +import ColorPalette from '../../colors'; + +// ─── Model IDs ─────────────────────────────────────────────────────────────── + +type ModelId = + | 'classification' + | 'object_detection' + | 'segmentation' + | 'style_transfer' + | 'ocr'; + +const MODELS: { id: ModelId; label: string }[] = [ + { id: 'classification', label: 'Classification' }, + { id: 'object_detection', label: 'Object Detection' }, + { id: 'segmentation', label: 'Segmentation' }, + { id: 'style_transfer', label: 'Style Transfer' }, + { id: 'ocr', label: 'OCR' }, +]; + +// ─── Segmentation colors ───────────────────────────────────────────────────── + +const CLASS_COLORS: number[][] = [ + [0, 0, 0, 0], + [51, 255, 87, 180], + [51, 87, 255, 180], + [255, 51, 246, 180], + [51, 255, 246, 180], + [243, 255, 51, 180], + [141, 51, 255, 180], + [255, 131, 51, 180], + [51, 255, 131, 180], + [131, 51, 255, 180], + [255, 255, 51, 180], + [51, 255, 255, 180], + [255, 51, 143, 180], + [127, 51, 255, 180], + [51, 255, 175, 180], + [255, 175, 51, 180], + [179, 255, 51, 180], + [255, 87, 51, 180], + [255, 51, 162, 180], + [51, 162, 255, 180], + [162, 51, 255, 180], +]; + +// ─── Kill switch — synchronizable boolean shared between JS and worklet thread. +// setBlocking(true) immediately stops the worklet from dispatching new work +// (both in onFrame and inside the async callback) before the old model tears down. 
+const frameKillSwitch = createSynchronizable(false); + +// ─── Screen ────────────────────────────────────────────────────────────────── + +export default function VisionCameraLiveScreen() { + const insets = useSafeAreaInsets(); + const [activeModel, setActiveModel] = useState('classification'); + const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); + const { setGlobalGenerating } = useContext(GeneratingContext); + + // ── Models (only the active model loads; others are prevented) ── + const classification = useClassification({ + model: EFFICIENTNET_V2_S, + preventLoad: activeModel !== 'classification', + }); + const objectDetection = useObjectDetection({ + model: SSDLITE_320_MOBILENET_V3_LARGE, + preventLoad: activeModel !== 'object_detection', + }); + const segmentation = useImageSegmentation({ + model: DEEPLAB_V3_RESNET50, + preventLoad: activeModel !== 'segmentation', + }); + const styleTransfer = useStyleTransfer({ + model: STYLE_TRANSFER_RAIN_PRINCESS, + preventLoad: activeModel !== 'style_transfer', + }); + const ocr = useOCR({ + model: OCR_ENGLISH, + preventLoad: activeModel !== 'ocr', + }); + + const activeIsGenerating = { + classification: classification.isGenerating, + object_detection: objectDetection.isGenerating, + segmentation: segmentation.isGenerating, + style_transfer: styleTransfer.isGenerating, + ocr: ocr.isGenerating, + }[activeModel]; + + useEffect(() => { + setGlobalGenerating(activeIsGenerating); + }, [activeIsGenerating, setGlobalGenerating]); + + // ── Camera ── + const [fps, setFps] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back') ?? 
devices[0]; + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + // ── Per-model result state ── + const [classResult, setClassResult] = useState({ label: '', score: 0 }); + const [detections, setDetections] = useState([]); + const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); + const [maskImage, setMaskImage] = useState(null); + const [styledImage, setStyledImage] = useState(null); + const [ocrData, setOcrData] = useState<{ + detections: OCRDetection[]; + frameWidth: number; + frameHeight: number; + }>({ detections: [], frameWidth: 1, frameHeight: 1 }); + + // ── Stable callbacks ── + function tick() { + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) setFps(Math.round(1000 / diff)); + lastFrameTimeRef.current = now; + } + + const updateClass = useCallback((r: { label: string; score: number }) => { + setClassResult(r); + tick(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const updateDetections = useCallback( + (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { + setDetections(p.results); + setImageSize({ width: p.imageWidth, height: p.imageHeight }); + tick(); + }, + // eslint-disable-next-line react-hooks/exhaustive-deps + [] + ); + + const updateMask = useCallback((img: SkImage) => { + setMaskImage((prev) => { + prev?.dispose(); + return img; + }); + tick(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const updateStyled = useCallback((img: SkImage) => { + setStyledImage((prev) => { + prev?.dispose(); + return img; + }); + tick(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const updateOcr = useCallback( + (d: { + detections: OCRDetection[]; + frameWidth: number; + frameHeight: number; + }) => { + setOcrData(d); + tick(); + }, + // eslint-disable-next-line 
react-hooks/exhaustive-deps + [] + ); + + // ── runOnJS-wrapped callbacks — created on the RN thread so the Babel plugin + // can serialize them into remote functions. These can then be safely called + // from any worklet runtime, including the asyncRunner's worker runtime. + const notifyClass = runOnJS(updateClass); + const notifyDetections = runOnJS(updateDetections); + const notifyMask = runOnJS(updateMask); + const notifyStyled = runOnJS(updateStyled); + const notifyOcr = runOnJS(updateOcr); + + // ── Pull the active model's runOnFrame out of the hook each render. + // These are worklet functions (not plain JS objects), so they CAN be + // captured directly in a useCallback closure — the worklets runtime + // serializes them correctly. A new closure is produced whenever the + // active runOnFrame changes, causing useFrameOutput to re-register. + const classRof = classification.runOnFrame; + const detRof = objectDetection.runOnFrame; + const segRof = segmentation.runOnFrame; + const stRof = styleTransfer.runOnFrame; + const ocrRof = ocr.runOnFrame; + + // When switching models: activate kill switch synchronously so the worklet + // thread stops calling runOnFrame before delete() fires on the old model. + // Then re-enable once the new model's preventLoad has taken effect. + useEffect(() => { + frameKillSwitch.setBlocking(true); + setMaskImage((prev) => { + prev?.dispose(); + return null; + }); + setStyledImage((prev) => { + prev?.dispose(); + return null; + }); + const id = setTimeout(() => { + frameKillSwitch.setBlocking(false); + }, 300); + return () => clearTimeout(id); + }, [activeModel]); + + // ── Single frame output. + // onFrame is re-created (and re-registered by useFrameOutput) whenever the + // active model or its runOnFrame worklet changes. The kill switch provides + // synchronous cross-thread protection during the transition window. 
+ const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + + // Kill switch is set synchronously from JS when switching models — + // guaranteed visible here before the next frame is dispatched. + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + + try { + if (activeModel === 'classification') { + if (!classRof) return; + const result = classRof(frame); + if (result) { + let bestLabel = ''; + let bestScore = -1; + const entries = Object.entries(result); + for (let i = 0; i < entries.length; i++) { + const [label, score] = entries[i]!; + if ((score as number) > bestScore) { + bestScore = score as number; + bestLabel = label; + } + } + notifyClass({ + label: bestLabel, + score: bestScore, + }); + } + } else if (activeModel === 'object_detection') { + if (!detRof) return; + const iw = frame.width > frame.height ? frame.height : frame.width; + const ih = frame.width > frame.height ? frame.width : frame.height; + const result = detRof(frame, 0.5); + if (result) { + notifyDetections({ + results: result, + imageWidth: iw, + imageHeight: ih, + }); + } + } else if (activeModel === 'segmentation') { + if (!segRof) return; + const result = segRof(frame, [], false); + if (result?.ARGMAX) { + const argmax: Int32Array = result.ARGMAX; + const side = Math.round(Math.sqrt(argmax.length)); + const pixels = new Uint8Array(side * side * 4); + for (let i = 0; i < argmax.length; i++) { + const color = CLASS_COLORS[argmax[i]!] ?? 
[0, 0, 0, 0]; + pixels[i * 4] = color[0]!; + pixels[i * 4 + 1] = color[1]!; + pixels[i * 4 + 2] = color[2]!; + pixels[i * 4 + 3] = color[3]!; + } + const skData = Skia.Data.fromBytes(pixels); + const img = Skia.Image.MakeImage( + { + width: side, + height: side, + alphaType: AlphaType.Unpremul, + colorType: ColorType.RGBA_8888, + }, + skData, + side * 4 + ); + if (img) notifyMask(img); + } + } else if (activeModel === 'style_transfer') { + if (!stRof) return; + const result = stRof(frame); + if (result?.dataPtr) { + const { dataPtr, sizes } = result; + const h = sizes[0]!; + const w = sizes[1]!; + const skData = Skia.Data.fromBytes(dataPtr); + const img = Skia.Image.MakeImage( + { + width: w, + height: h, + alphaType: AlphaType.Opaque, + colorType: ColorType.RGBA_8888, + }, + skData, + w * 4 + ); + if (img) notifyStyled(img); + } + } else if (activeModel === 'ocr') { + if (!ocrRof) return; + const fw = frame.width; + const fh = frame.height; + const result = ocrRof(frame); + if (result) { + notifyOcr({ + detections: result, + frameWidth: fw, + frameHeight: fh, + }); + } + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [ + activeModel, + classRof, + detRof, + segRof, + stRof, + ocrRof, + notifyClass, + notifyDetections, + notifyMask, + notifyStyled, + notifyOcr, + ] + ), + }); + + // ── Loading state: only care about the active model ── + const activeIsReady = { + classification: classification.isReady, + object_detection: objectDetection.isReady, + segmentation: segmentation.isReady, + style_transfer: styleTransfer.isReady, + ocr: ocr.isReady, + }[activeModel]; + + const activeDownloadProgress = { + classification: classification.downloadProgress, + object_detection: objectDetection.downloadProgress, + segmentation: segmentation.downloadProgress, + style_transfer: styleTransfer.downloadProgress, + ocr: ocr.downloadProgress, + }[activeModel]; + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + 
cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + // ── Cover-fit helpers ── + function coverFit(imgW: number, imgH: number) { + const scale = Math.max(canvasSize.width / imgW, canvasSize.height / imgH); + return { + scale, + offsetX: (canvasSize.width - imgW * scale) / 2, + offsetY: (canvasSize.height - imgH * scale) / 2, + }; + } + + // ── OCR coord transform ── + const { + detections: ocrDets, + frameWidth: ocrFW, + frameHeight: ocrFH, + } = ocrData; + const ocrIsLandscape = ocrFW > ocrFH; + const ocrImgW = ocrIsLandscape ? ocrFH : ocrFW; + const ocrImgH = ocrIsLandscape ? ocrFW : ocrFH; + const { + scale: ocrScale, + offsetX: ocrOX, + offsetY: ocrOY, + } = coverFit(ocrImgW, ocrImgH); + function ocrToX(px: number, py: number) { + return (ocrIsLandscape ? ocrFH - py : px) * ocrScale + ocrOX; + } + function ocrToY(px: number, py: number) { + return (ocrIsLandscape ? 
px : py) * ocrScale + ocrOY; + } + + // ── Object detection cover-fit ── + const { + scale: detScale, + offsetX: detOX, + offsetY: detOY, + } = coverFit(imageSize.width, imageSize.height); + + const font = matchFont({ fontFamily: 'Helvetica', fontSize: 11 }); + + return ( + + + + + + {/* ── Overlays ── */} + + setCanvasSize({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + > + {activeModel === 'segmentation' && maskImage && ( + + + + )} + + {activeModel === 'style_transfer' && styledImage && ( + + + + )} + + {activeModel === 'object_detection' && ( + <> + {detections.map((det, i) => { + const left = det.bbox.x1 * detScale + detOX; + const top = det.bbox.y1 * detScale + detOY; + const w = (det.bbox.x2 - det.bbox.x1) * detScale; + const h = (det.bbox.y2 - det.bbox.y1) * detScale; + return ( + + + + {det.label} {(det.score * 100).toFixed(0)}% + + + + ); + })} + + )} + + {activeModel === 'ocr' && ( + + {ocrDets.map((det, i) => { + if (!det.bbox || det.bbox.length < 2) return null; + const path = Skia.Path.Make(); + path.moveTo( + ocrToX(det.bbox[0]!.x, det.bbox[0]!.y), + ocrToY(det.bbox[0]!.x, det.bbox[0]!.y) + ); + for (let j = 1; j < det.bbox.length; j++) { + path.lineTo( + ocrToX(det.bbox[j]!.x, det.bbox[j]!.y), + ocrToY(det.bbox[j]!.x, det.bbox[j]!.y) + ); + } + path.close(); + const lx = ocrToX(det.bbox[0]!.x, det.bbox[0]!.y); + const ly = Math.max( + 0, + ocrToY(det.bbox[0]!.x, det.bbox[0]!.y) - 4 + ); + return ( + + + + {font && ( + + )} + + ); + })} + + )} + + + {!activeIsReady && ( + + m.id === activeModel)?.label} ${(activeDownloadProgress * 100).toFixed(0)}%`} + /> + + )} + + + + {MODELS.map((m) => ( + setActiveModel(m.id)} + > + + {m.label} + + + ))} + + + + + + {activeModel === 'classification' && ( + + + {classResult.label || '—'} + + {classResult.label ? 
( + + {(classResult.score * 100).toFixed(1)}% + + ) : null} + + )} + {activeModel === 'object_detection' && ( + + {detections.length} + objects + + )} + {activeModel === 'segmentation' && ( + + DeepLab V3 + segmentation + + )} + {activeModel === 'style_transfer' && ( + + Rain Princess + style + + )} + {activeModel === 'ocr' && ( + + {ocrDets.length} + regions + + )} + + + {fps} + fps + + + + + ); +} + +// ─── Styles ────────────────────────────────────────────────────────────────── + +const styles = StyleSheet.create({ + container: { flex: 1, backgroundColor: 'black' }, + centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { color: 'white', fontSize: 18 }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { color: 'white', fontSize: 15, fontWeight: '600' }, + loadingOverlay: { + ...StyleSheet.absoluteFillObject, + backgroundColor: 'rgba(0,0,0,0.6)', + justifyContent: 'center', + alignItems: 'center', + }, + topBarWrapper: { + position: 'absolute', + top: 0, + left: 0, + right: 0, + }, + pickerContent: { + paddingHorizontal: 12, + gap: 8, + }, + chip: { + paddingHorizontal: 16, + paddingVertical: 8, + borderRadius: 20, + backgroundColor: 'rgba(0,0,0,0.55)', + borderWidth: 1, + borderColor: 'rgba(255,255,255,0.2)', + }, + chipActive: { + backgroundColor: ColorPalette.primary, + borderColor: ColorPalette.primary, + }, + chipText: { + color: 'rgba(255,255,255,0.8)', + fontSize: 13, + fontWeight: '600', + }, + chipTextActive: { color: 'white' }, + bbox: { + position: 'absolute', + borderWidth: 2, + borderColor: ColorPalette.primary, + borderRadius: 4, + }, + bboxLabel: { + position: 'absolute', + top: -22, + left: -2, + backgroundColor: ColorPalette.primary, + paddingHorizontal: 6, + paddingVertical: 2, + borderRadius: 4, + }, + bboxLabelText: { color: 'white', fontSize: 11, fontWeight: '600' }, + 
bottomBarWrapper: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + }, + bottomBar: { + flexDirection: 'row', + alignItems: 'center', + backgroundColor: 'rgba(0,0,0,0.55)', + borderRadius: 24, + paddingHorizontal: 28, + paddingVertical: 10, + gap: 24, + }, + resultContainer: { alignItems: 'flex-start', maxWidth: 220 }, + resultText: { + color: 'white', + fontSize: 16, + fontWeight: '700', + }, + resultSub: { + color: 'rgba(255,255,255,0.6)', + fontSize: 12, + fontWeight: '500', + }, + statDivider: { + width: 1, + height: 32, + backgroundColor: 'rgba(255,255,255,0.2)', + }, + statItem: { alignItems: 'center' }, + statValue: { + color: 'white', + fontSize: 22, + fontWeight: '700', + letterSpacing: -0.5, + }, + statLabel: { + color: 'rgba(255,255,255,0.55)', + fontSize: 11, + fontWeight: '500', + textTransform: 'uppercase', + letterSpacing: 0.8, + }, +}); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 96e3168ee7..0a82dc3efe 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -19,6 +19,7 @@ #include #include #include +#include #include using namespace rnexecutorch::models::speech_to_text; @@ -557,4 +558,53 @@ inline jsi::Value getJsiValue(const TranscriptionResult &result, return obj; } +inline jsi::Value +getJsiValue(const models::style_transfer::PixelDataResult &result, + jsi::Runtime &runtime) { + jsi::Object obj(runtime); + + auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr); + auto uint8ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Uint8Array"); + auto uint8Array = + uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer).getObject(runtime); + obj.setProperty(runtime, "dataPtr", uint8Array); + + auto sizesArray = 
jsi::Array(runtime, 3); + sizesArray.setValueAtIndex(runtime, 0, jsi::Value(result.height)); + sizesArray.setValueAtIndex(runtime, 1, jsi::Value(result.width)); + sizesArray.setValueAtIndex(runtime, 2, jsi::Value(4)); + obj.setProperty(runtime, "sizes", sizesArray); + + obj.setProperty(runtime, "scalarType", jsi::Value(0)); + + return obj; +} + +inline jsi::Value +getJsiValue(const models::image_segmentation::SegmentationResult &result, + jsi::Runtime &runtime) { + jsi::Object dict(runtime); + + auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, result.argmax); + auto int32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) + .getObject(runtime); + dict.setProperty(runtime, "ARGMAX", int32Array); + + for (auto &[classLabel, owningBuffer] : *result.classBuffers) { + auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) + .getObject(runtime); + dict.setProperty(runtime, jsi::String::createFromAscii(runtime, classLabel), + float32Array); + } + + return dict; +} + } // namespace rnexecutorch::jsi_conversion diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index c0ce049f28..8f67175c41 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -11,7 +11,16 @@ using namespace facebook; cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const { auto frameObj = frameData.asObject(runtime); - return ::rnexecutorch::utils::extractFrame(runtime, frameObj); + cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, 
frameObj); + + // Camera sensors natively deliver frames in landscape orientation. + // Rotate 90° CW so all models receive upright portrait frames. + if (frame.cols > frame.rows) { + cv::Mat upright; + cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); + return upright; + } + return frame; } cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 875d633a87..38cf26dead 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -53,6 +53,20 @@ class VisionModel : public BaseModel { virtual ~VisionModel() = default; + /** + * @brief Thread-safe unload that waits for any in-flight inference to + * complete + * + * Overrides BaseModel::unload() to acquire inference_mutex_ before + * resetting the module. This prevents a crash where BaseModel::unload() + * destroys module_ while generateFromFrame() is still executing on the + * VisionCamera worklet thread. 
+ */ + void unload() noexcept { + std::scoped_lock lock(inference_mutex_); + BaseModel::unload(); + } + protected: /** * @brief Mutex to ensure thread-safe inference diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 0fba071087..2a00d5dce8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -12,7 +12,7 @@ namespace rnexecutorch::models::classification { Classification::Classification(const std::string &modelSource, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ -32,20 +32,78 @@ Classification::Classification(const std::string &modelSource, modelInputShape[modelInputShape.size() - 2]); } +cv::Mat Classification::preprocessFrame(const cv::Mat &frame) const { + cv::Mat rgb; + + if (frame.channels() == 4) { +#ifdef __APPLE__ + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + } else if (frame.channels() == 3) { + rgb = frame; + } else { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unsupported frame format: %d channels", frame.channels()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (rgb.size() != modelImageSize) { + cv::Mat resized; + cv::resize(rgb, resized, modelImageSize); + return resized; + } + + return rgb; +} + std::unordered_map -Classification::generate(std::string imageSource) { +Classification::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); + + cv::Mat preprocessed = 
preprocessFrame(image); + + const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]) - .first; + image_processing::getTensorFromMatrix(tensorDims, preprocessed); + auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { throw RnExecutorchError(forwardResult.error(), "The model's forward function did not succeed. " "Ensure the model input is correct."); } + return postprocess(forwardResult->at(0).toTensor()); } +std::unordered_map +Classification::generateFromString(std::string imageSource) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + return runInference(imageRGB); +} + +std::unordered_map +Classification::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + cv::Mat frame = extractFromFrame(runtime, frameData); + return runInference(frame); +} + +std::unordered_map +Classification::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image = extractFromPixels(pixelData); + + return runInference(image); +} + std::unordered_map Classification::postprocess(const Tensor &tensor) { std::span resultData( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h index 1465fc5f9b..473d9b4bb3 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h @@ -3,25 +3,40 @@ #include #include +#include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include namespace rnexecutorch { namespace models::classification { using executorch::aten::Tensor; using executorch::extension::TensorPtr; -class Classification : public BaseModel { 
+class Classification : public VisionModel { public: Classification(const std::string &modelSource, std::shared_ptr callInvoker); + [[nodiscard("Registered non-void function")]] std::unordered_map< std::string_view, float> - generate(std::string imageSource); + generateFromString(std::string imageSource); + + [[nodiscard("Registered non-void function")]] std::unordered_map< + std::string_view, float> + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + + [[nodiscard("Registered non-void function")]] std::unordered_map< + std::string_view, float> + generateFromPixels(JSTensorViewIn pixelData); + +protected: + cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: + std::unordered_map runInference(cv::Mat image); + std::unordered_map postprocess(const Tensor &tensor); cv::Size modelImageSize{0, 0}; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index ec3129e760..a82fffbb22 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -1,17 +1,18 @@ #include "ImageEmbeddings.h" +#include + #include #include #include #include -#include namespace rnexecutorch::models::embeddings { ImageEmbeddings::ImageEmbeddings( const std::string &modelSource, std::shared_ptr callInvoker) - : BaseEmbeddings(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputTensors = getAllInputShapes(); if (inputTensors.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ -31,10 +32,43 @@ ImageEmbeddings::ImageEmbeddings( modelInputShape[modelInputShape.size() - 2]); } +cv::Mat ImageEmbeddings::preprocessFrame(const cv::Mat &frame) const { + cv::Mat rgb; + + if (frame.channels() == 4) { 
+#ifdef __APPLE__ + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + } else if (frame.channels() == 3) { + rgb = frame; + } else { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unsupported frame format: %d channels", frame.channels()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (rgb.size() != modelImageSize) { + cv::Mat resized; + cv::resize(rgb, resized, modelImageSize); + return resized; + } + + return rgb; +} + std::shared_ptr -ImageEmbeddings::generate(std::string imageSource) { - auto [inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); +ImageEmbeddings::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); + + cv::Mat preprocessed = preprocessFrame(image); + + const std::vector tensorDims = getAllInputShapes()[0]; + auto inputTensor = + image_processing::getTensorFromMatrix(tensorDims, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); @@ -45,7 +79,33 @@ ImageEmbeddings::generate(std::string imageSource) { "is correct."); } - return BaseEmbeddings::postprocess(forwardResult); + auto forwardResultTensor = forwardResult->at(0).toTensor(); + return std::make_shared( + forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes()); +} + +std::shared_ptr +ImageEmbeddings::generateFromString(std::string imageSource) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + return runInference(imageRGB); +} + +std::shared_ptr +ImageEmbeddings::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + cv::Mat frame = extractFromFrame(runtime, frameData); + return runInference(frame); +} + +std::shared_ptr +ImageEmbeddings::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image = extractFromPixels(pixelData); + + 
return runInference(image); } } // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 7e114e939d..ec11ee5c69 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -2,25 +2,41 @@ #include #include +#include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include +#include namespace rnexecutorch { namespace models::embeddings { using executorch::extension::TensorPtr; using executorch::runtime::EValue; -class ImageEmbeddings final : public BaseEmbeddings { +class ImageEmbeddings final : public VisionModel { public: ImageEmbeddings(const std::string &modelSource, std::shared_ptr callInvoker); + [[nodiscard( "Registered non-void function")]] std::shared_ptr - generate(std::string imageSource); + generateFromString(std::string imageSource); + + [[nodiscard( + "Registered non-void function")]] std::shared_ptr + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + + [[nodiscard( + "Registered non-void function")]] std::shared_ptr + generateFromPixels(JSTensorViewIn pixelData); + +protected: + cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: + std::shared_ptr runInference(cv::Mat image); + cv::Size modelImageSize{0, 0}; }; } // namespace models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index f926f49b78..e54f3e9a4a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -5,7 +5,6 @@ #include #include #include -#include namespace rnexecutorch::models::object_detection { @@ -201,9 +200,7 @@ std::vector ObjectDetection::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold) { - auto frameObj = frameData.asObject(runtime); - cv::Mat frame = rnexecutorch::utils::extractFrame(runtime, frameObj); - + cv::Mat frame = extractFromFrame(runtime, frameData); return runInference(frame, detectionThreshold); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp index a521b4e8b0..50834a1b82 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace rnexecutorch::models::ocr { OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource, @@ -12,12 +13,8 @@ OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource, : detector(detectorSource, callInvoker), recognitionHandler(recognizerSource, symbols, callInvoker) {} -std::vector OCR::generate(std::string input) { - cv::Mat image = image_processing::readImage(input); - if (image.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, - "Failed to load image from path: " + input); - } +std::vector OCR::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); /* 1. 
Detection process returns the list of bounding boxes containing areas @@ -43,6 +40,63 @@ std::vector OCR::generate(std::string input) { return result; } +std::vector OCR::generateFromString(std::string input) { + cv::Mat image = image_processing::readImage(input); + if (image.empty()) { + throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, + "Failed to load image from path: " + input); + } + return runInference(image); +} + +std::vector +OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { + auto frameObj = frameData.asObject(runtime); + cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); + // extractFrame returns RGB; convert to BGR for consistency with readImage + cv::cvtColor(frame, frame, cv::COLOR_RGB2BGR); + return runInference(frame); +} + +std::vector +OCR::generateFromPixels(JSTensorViewIn pixelData) { + if (pixelData.sizes.size() != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: sizes must have 3 elements " + "[height, width, channels], got %zu", + pixelData.sizes.size()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + int32_t height = pixelData.sizes[0]; + int32_t width = pixelData.sizes[1]; + int32_t channels = pixelData.sizes[2]; + + if (channels != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: expected 3 channels (RGB), got %d", + channels); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (pixelData.scalarType != executorch::aten::ScalarType::Byte) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); + } + + uint8_t *dataPtr = static_cast(pixelData.dataPtr); + // Input is RGB from JS; convert to BGR for consistency with readImage + cv::Mat rgbImage(height, width, CV_8UC3, dataPtr); + cv::Mat image; + 
cv::cvtColor(rgbImage, image, cv::COLOR_RGB2BGR); + return runInference(image); +} + std::size_t OCR::getMemoryLowerBound() const noexcept { return detector.getMemoryLowerBound() + recognitionHandler.getMemoryLowerBound(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h index d84ba903f0..719cb957c4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include #include #include #include @@ -28,13 +30,20 @@ class OCR final { const std::string &recognizerSource, const std::string &symbols, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector - generate(std::string input); + generateFromString(std::string input); + [[nodiscard("Registered non-void function")]] std::vector + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + [[nodiscard("Registered non-void function")]] std::vector + generateFromPixels(JSTensorViewIn pixelData); std::size_t getMemoryLowerBound() const noexcept; void unload() noexcept; private: + std::vector runInference(cv::Mat image); + Detector detector; RecognitionHandler recognitionHandler; + mutable std::mutex inference_mutex_; }; } // namespace models::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index fc6a04ebcd..3016f8edf1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -60,9 
+60,13 @@ TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, std::shared_ptr BaseSemanticSegmentation::generate( std::string imageSource, std::set> classesOfInterest, bool resize) { + std::scoped_lock lock(inference_mutex_); - cv::Size originalSize; - auto inputTensor = preprocess(imageSource, originalSize); + cv::Mat preprocessed = preprocessFrame(image); + + const std::vector tensorDims = getAllInputShapes()[0]; + auto inputTensor = + image_processing::getTensorFromMatrix(tensorDims, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); @@ -161,8 +165,8 @@ std::shared_ptr BaseSemanticSegmentation::postprocess( } // Filter classes of interest - auto buffersToReturn = std::make_shared>>(); + auto buffersToReturn = std::make_shared< + std::unordered_map>>(); for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) { (*buffersToReturn)[allClasses[cl]] = resultClasses[cl]; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index d39a7e5d4a..bd2a6b9e84 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -8,7 +8,8 @@ #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" #include -#include +#include +#include namespace rnexecutorch { namespace models::semantic_segmentation { @@ -30,9 +31,9 @@ class BaseSemanticSegmentation : public BaseModel { std::set> classesOfInterest, bool resize); protected: - virtual TensorPtr preprocess(const std::string &imageSource, - cv::Size &originalSize); - virtual std::shared_ptr + cv::Mat preprocessFrame(const cv::Mat &frame) const override; + + 
virtual SegmentationResult postprocess(const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, @@ -44,14 +45,15 @@ class BaseSemanticSegmentation : public BaseModel { std::optional normStd_; std::vector allClasses_; - std::shared_ptr populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput); - private: void initModelImageSize(); + + SegmentationResult runInference( + cv::Mat image, cv::Size originalSize, std::vector allClasses, + std::set> classesOfInterest, bool resize); + + TensorPtr preprocessFromString(const std::string &imageSource, + cv::Size &originalSize); }; } // namespace models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h new file mode 100644 index 0000000000..b5d6f5067d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch::models::image_segmentation { + +struct SegmentationResult { + std::shared_ptr argmax; + std::shared_ptr< + std::unordered_map>> + classBuffers; +}; + +} // namespace rnexecutorch::models::image_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 3b9c0187b9..c334f5d842 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace rnexecutorch::models::style_transfer { using namespace facebook; @@ -13,7 +14,7 @@ using executorch::extension::TensorPtr; 
StyleTransfer::StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ -33,17 +34,67 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, modelInputShape[modelInputShape.size() - 2]); } -std::string StyleTransfer::postprocess(const Tensor &tensor, - cv::Size originalSize) { +cv::Mat StyleTransfer::preprocessFrame(const cv::Mat &frame) const { + cv::Mat rgb; + + if (frame.channels() == 4) { +#ifdef __APPLE__ + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + } else if (frame.channels() == 3) { + rgb = frame; + } else { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unsupported frame format: %d channels", frame.channels()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (rgb.size() != modelImageSize) { + cv::Mat resized; + cv::resize(rgb, resized, modelImageSize); + return resized; + } + + return rgb; +} + +PixelDataResult StyleTransfer::postprocess(const Tensor &tensor, + cv::Size outputSize) { + // Convert tensor output (at modelImageSize) to CV_8UC3 BGR mat cv::Mat mat = image_processing::getMatrixFromTensor(modelImageSize, tensor); - cv::resize(mat, mat, originalSize); - return image_processing::saveToTempFile(mat); + // Resize only if requested output differs from model output size + if (mat.size() != outputSize) { + cv::resize(mat, mat, outputSize); + } + + // Convert BGR -> RGBA so JS can pass the buffer directly to Skia + cv::Mat rgba; + cv::cvtColor(mat, rgba, cv::COLOR_BGR2RGBA); + + std::size_t dataSize = + static_cast(outputSize.width) * outputSize.height * 4; + auto pixelBuffer = std::make_shared(rgba.data, dataSize); + log(LOG_LEVEL::Debug, + "[StyleTransfer] 
postprocess: RGBA buffer size:", dataSize, + "w:", outputSize.width, "h:", outputSize.height); + + return PixelDataResult{pixelBuffer, outputSize.width, outputSize.height}; } -std::string StyleTransfer::generate(std::string imageSource) { - auto [inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); +PixelDataResult StyleTransfer::runInference(cv::Mat image, + cv::Size originalSize) { + std::scoped_lock lock(inference_mutex_); + + cv::Mat preprocessed = preprocessFrame(image); + + const std::vector tensorDims = getAllInputShapes()[0]; + auto inputTensor = + image_processing::getTensorFromMatrix(tensorDims, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { @@ -55,4 +106,31 @@ std::string StyleTransfer::generate(std::string imageSource) { return postprocess(forwardResult->at(0).toTensor(), originalSize); } +PixelDataResult StyleTransfer::generateFromString(std::string imageSource) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + cv::Size originalSize = imageBGR.size(); + + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + return runInference(imageRGB, originalSize); +} + +PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + // extractFromFrame rotates landscape frames 90° CW automatically. + cv::Mat frame = extractFromFrame(runtime, frameData); + + // For real-time frame processing, output at modelImageSize to avoid + // allocating large buffers (e.g. 1280x720x3 ~2.7MB) on every frame. 
+ return runInference(frame, modelImageSize); +} + +PixelDataResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image = extractFromPixels(pixelData); + cv::Size originalSize = image.size(); + + return runInference(image, originalSize); +} + } // namespace rnexecutorch::models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 73744c4d82..99f9f4b3ac 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -9,7 +9,9 @@ #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include +#include +#include namespace rnexecutorch { namespace models::style_transfer { @@ -17,15 +19,30 @@ using namespace facebook; using executorch::aten::Tensor; using executorch::extension::TensorPtr; -class StyleTransfer : public BaseModel { +class StyleTransfer : public VisionModel { public: StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] std::string - generate(std::string imageSource); + + [[nodiscard("Registered non-void function")]] PixelDataResult + generateFromString(std::string imageSource); + + [[nodiscard("Registered non-void function")]] PixelDataResult + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + + [[nodiscard("Registered non-void function")]] PixelDataResult + generateFromPixels(JSTensorViewIn pixelData); + +protected: + cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: - std::string postprocess(const Tensor &tensor, cv::Size originalSize); + // outputSize: size to resize the styled output to before returning. + // Pass modelImageSize for real-time frame processing (avoids large allocs). 
+ // Pass the source image size for generateFromString/generateFromPixels. + PixelDataResult runInference(cv::Mat image, cv::Size outputSize); + + PixelDataResult postprocess(const Tensor &tensor, cv::Size outputSize); cv::Size modelImageSize{0, 0}; }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h new file mode 100644 index 0000000000..f677183a64 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::models::style_transfer { + +struct PixelDataResult { + std::shared_ptr dataPtr; + int width; + int height; +}; + +} // namespace rnexecutorch::models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp index 0f75d20152..71ea737f8e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -1,10 +1,12 @@ #include "VerticalOCR.h" #include #include +#include #include #include #include #include +#include #include namespace rnexecutorch::models::ocr { @@ -16,12 +18,9 @@ VerticalOCR::VerticalOCR(const std::string &detectorSource, converter(symbols), independentCharacters(independentChars), callInvoker(invoker) {} -std::vector VerticalOCR::generate(std::string input) { - cv::Mat image = image_processing::readImage(input); - if (image.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, - "Failed to load image from path: " + input); - } +std::vector VerticalOCR::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); + // 1. 
Large Detector std::vector largeBoxes = detector.generate(image, constants::kLargeDetectorWidth); @@ -44,6 +43,65 @@ std::vector VerticalOCR::generate(std::string input) { return predictions; } +std::vector +VerticalOCR::generateFromString(std::string input) { + cv::Mat image = image_processing::readImage(input); + if (image.empty()) { + throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, + "Failed to load image from path: " + input); + } + return runInference(image); +} + +std::vector +VerticalOCR::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + auto frameObj = frameData.asObject(runtime); + cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); + // extractFrame returns RGB; convert to BGR for consistency with readImage + cv::cvtColor(frame, frame, cv::COLOR_RGB2BGR); + return runInference(frame); +} + +std::vector +VerticalOCR::generateFromPixels(JSTensorViewIn pixelData) { + if (pixelData.sizes.size() != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: sizes must have 3 elements " + "[height, width, channels], got %zu", + pixelData.sizes.size()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + int32_t height = pixelData.sizes[0]; + int32_t width = pixelData.sizes[1]; + int32_t channels = pixelData.sizes[2]; + + if (channels != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: expected 3 channels (RGB), got %d", + channels); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (pixelData.scalarType != executorch::aten::ScalarType::Byte) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); + } + + uint8_t *dataPtr = static_cast(pixelData.dataPtr); + // Input is RGB from JS; convert to BGR for consistency with readImage + cv::Mat 
rgbImage(height, width, CV_8UC3, dataPtr); + cv::Mat image; + cv::cvtColor(rgbImage, image, cv::COLOR_RGB2BGR); + return runInference(image); +} + std::size_t VerticalOCR::getMemoryLowerBound() const noexcept { return detector.getMemoryLowerBound() + recognizer.getMemoryLowerBound(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h index e97fb90348..4016e28138 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h @@ -1,12 +1,14 @@ #pragma once #include +#include #include #include #include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include #include #include #include @@ -48,11 +50,17 @@ class VerticalOCR final { bool indpendentCharacters, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector - generate(std::string input); + generateFromString(std::string input); + [[nodiscard("Registered non-void function")]] std::vector + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + [[nodiscard("Registered non-void function")]] std::vector + generateFromPixels(JSTensorViewIn pixelData); std::size_t getMemoryLowerBound() const noexcept; void unload() noexcept; private: + std::vector runInference(cv::Mat image); + std::pair _handleIndependentCharacters( const types::DetectorBBox &box, const cv::Mat &originalImage, const std::vector &characterBoxes, @@ -75,6 +83,7 @@ class VerticalOCR final { CTCLabelConverter converter; bool independentCharacters; std::shared_ptr callInvoker; + mutable std::mutex inference_mutex_; }; } // namespace models::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp 
b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp index 10aa663a4a..b64f167c90 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp @@ -28,7 +28,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -42,37 +42,37 @@ INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest, // ============================================================================ TEST(ClassificationGenerateTests, InvalidImagePathThrows) { Classification model(kValidClassificationModelPath, nullptr); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(ClassificationGenerateTests, EmptyImagePathThrows) { Classification model(kValidClassificationModelPath, nullptr); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(ClassificationGenerateTests, MalformedURIThrows) { Classification model(kValidClassificationModelPath, nullptr); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(ClassificationGenerateTests, ValidImageReturnsResults) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); EXPECT_FALSE(results.empty()); } TEST(ClassificationGenerateTests, ResultsHaveCorrectSize) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = 
model.generateFromString(kValidTestImagePath); auto expectedNumClasses = constants::kImagenet1kV1Labels.size(); EXPECT_EQ(results.size(), expectedNumClasses); } TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); float sum = 0.0f; for (const auto &[label, prob] : results) { @@ -85,7 +85,7 @@ TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) { TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); float maxProb = 0.0f; for (const auto &[label, prob] : results) { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp index 3a23746957..ba76939a8e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp @@ -29,7 +29,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -43,31 +43,31 @@ INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, CommonModelTest, // ============================================================================ TEST(ImageEmbeddingsGenerateTests, InvalidImagePathThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } 
TEST(ImageEmbeddingsGenerateTests, EmptyImagePathThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, MalformedURIThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, ValidImageReturnsResults) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); EXPECT_NE(result, nullptr); EXPECT_GT(result->size(), 0u); } TEST(ImageEmbeddingsGenerateTests, ResultsHaveCorrectSize) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); size_t numFloats = result->size() / sizeof(float); constexpr size_t kClipEmbeddingDimensions = 512; EXPECT_EQ(numFloats, kClipEmbeddingDimensions); @@ -77,7 +77,7 @@ TEST(ImageEmbeddingsGenerateTests, ResultsAreNormalized) { // TODO: Investigate the source of the issue; GTEST_SKIP() << "Expected to fail in emulator environments"; ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); const float *data = reinterpret_cast(result->data()); size_t numFloats = result->size() / sizeof(float); @@ -92,7 +92,7 @@ TEST(ImageEmbeddingsGenerateTests, ResultsAreNormalized) { TEST(ImageEmbeddingsGenerateTests, ResultsContainValidValues) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = 
model.generateFromString(kValidTestImagePath); const float *data = reinterpret_cast(result->data()); size_t numFloats = result->size() / sizeof(float); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp index 428fb5afb1..6f6f708be2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp @@ -41,7 +41,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -67,27 +67,27 @@ TEST(OCRCtorTests, EmptySymbolsThrows) { TEST(OCRGenerateTests, InvalidImagePathThrows) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(OCRGenerateTests, EmptyImagePathThrows) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(OCRGenerateTests, MalformedURIThrows) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(OCRGenerateTests, ValidImageReturnsResults) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); // May or may not have 
detections depending on image content EXPECT_GE(results.size(), 0u); } @@ -95,7 +95,7 @@ TEST(OCRGenerateTests, ValidImageReturnsResults) { TEST(OCRGenerateTests, DetectionsHaveValidBoundingBoxes) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { // Each bbox should have 4 points @@ -110,7 +110,7 @@ TEST(OCRGenerateTests, DetectionsHaveValidBoundingBoxes) { TEST(OCRGenerateTests, DetectionsHaveValidScores) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -121,7 +121,7 @@ TEST(OCRGenerateTests, DetectionsHaveValidScores) { TEST(OCRGenerateTests, DetectionsHaveNonEmptyText) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index d5427ce61b..532b4c04b2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -1,6 +1,4 @@ #include "BaseModelTests.h" -#include "utils/TestUtils.h" -#include #include #include #include @@ -30,7 +28,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - 
(void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -44,51 +42,34 @@ INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest, // ============================================================================ TEST(StyleTransferGenerateTests, InvalidImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(StyleTransferGenerateTests, EmptyImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(StyleTransferGenerateTests, MalformedURIThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } -TEST(StyleTransferGenerateTests, ValidImageReturnsFilePath) { +TEST(StyleTransferGenerateTests, ValidImageReturnsNonNull) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); - EXPECT_FALSE(result.empty()); -} - -TEST(StyleTransferGenerateTests, ResultIsValidFilePath) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); - test_utils::trimFilePrefix(result); - EXPECT_TRUE(std::filesystem::exists(result)); -} - -TEST(StyleTransferGenerateTests, ResultFileHasContent) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); - test_utils::trimFilePrefix(result); - auto fileSize = std::filesystem::file_size(result); - EXPECT_GT(fileSize, 0u); + auto result = model.generateFromString(kValidTestImagePath); + 
EXPECT_NE(result, nullptr); } TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_NO_THROW((void)model.generate(kValidTestImagePath)); - auto result1 = model.generate(kValidTestImagePath); - auto result2 = model.generate(kValidTestImagePath); - test_utils::trimFilePrefix(result1); - test_utils::trimFilePrefix(result2); - EXPECT_TRUE(std::filesystem::exists(result1)); - EXPECT_TRUE(std::filesystem::exists(result2)); + EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath)); + auto result1 = model.generateFromString(kValidTestImagePath); + auto result2 = model.generateFromString(kValidTestImagePath); + EXPECT_NE(result1, nullptr); + EXPECT_NE(result2, nullptr); } TEST(StyleTransferInheritedTests, GetInputShapeWorks) { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp index 7b1010a81e..56f18d862a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp @@ -43,7 +43,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidVerticalTestImagePath); + (void)model.generateFromString(kValidVerticalTestImagePath); } }; } // namespace model_tests @@ -85,34 +85,34 @@ TEST(VerticalOCRCtorTests, IndependentCharsFalseDoesntThrow) { TEST(VerticalOCRGenerateTests, IndependentCharsInvalidImageThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsEmptyImagePathThrows) { 
VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsMalformedURIThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsValidImageReturnsResults) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); EXPECT_GE(results.size(), 0u); } TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidBBoxes) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_EQ(detection.bbox.size(), 4u); @@ -126,7 +126,7 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidBBoxes) { TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidScores) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -137,7 +137,7 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidScores) { 
TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveNonEmptyText) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); @@ -148,34 +148,34 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveNonEmptyText) { TEST(VerticalOCRGenerateTests, JointCharsInvalidImageThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsEmptyImagePathThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsMalformedURIThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsValidImageReturnsResults) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); EXPECT_GE(results.size(), 0u); } TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidBBoxes) { VerticalOCR 
model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_EQ(detection.bbox.size(), 4u); @@ -189,7 +189,7 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidBBoxes) { TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidScores) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -200,7 +200,7 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidScores) { TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveNonEmptyText) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); @@ -216,8 +216,10 @@ TEST(VerticalOCRStrategyTests, BothStrategiesRunSuccessfully) { kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_NO_THROW((void)independentModel.generate(kValidVerticalTestImagePath)); - EXPECT_NO_THROW((void)jointModel.generate(kValidVerticalTestImagePath)); + EXPECT_NO_THROW( + (void)independentModel.generateFromString(kValidVerticalTestImagePath)); + EXPECT_NO_THROW( + (void)jointModel.generateFromString(kValidVerticalTestImagePath)); } TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { @@ -229,8 +231,9 @@ TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { 
createMockCallInvoker()); auto independentResults = - independentModel.generate(kValidVerticalTestImagePath); - auto jointResults = jointModel.generate(kValidVerticalTestImagePath); + independentModel.generateFromString(kValidVerticalTestImagePath); + auto jointResults = + jointModel.generateFromString(kValidVerticalTestImagePath); // Both should return some results (or none if no text detected) EXPECT_GE(independentResults.size(), 0u); diff --git a/packages/react-native-executorch/src/controllers/BaseOCRController.ts b/packages/react-native-executorch/src/controllers/BaseOCRController.ts index 614d42a212..910e2d5930 100644 --- a/packages/react-native-executorch/src/controllers/BaseOCRController.ts +++ b/packages/react-native-executorch/src/controllers/BaseOCRController.ts @@ -2,10 +2,24 @@ import { Logger } from '../common/Logger'; import { symbols } from '../constants/ocr/symbols'; import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; -import { ResourceSource } from '../types/common'; +import { Frame, PixelData, ResourceSource, ScalarType } from '../types/common'; import { OCRLanguage, OCRDetection } from '../types/ocr'; import { ResourceFetcher } from '../utils/ResourceFetcher'; +function isPixelData(input: unknown): input is PixelData { + return ( + typeof input === 'object' && + input !== null && + 'dataPtr' in input && + input.dataPtr instanceof Uint8Array && + 'sizes' in input && + Array.isArray(input.sizes) && + input.sizes.length === 3 && + 'scalarType' in input && + input.scalarType === ScalarType.BYTE + ); +} + export abstract class BaseOCRController { protected nativeModule: any; public isReady: boolean = false; @@ -87,7 +101,34 @@ export abstract class BaseOCRController { } }; - public forward = async (imageSource: string): Promise => { + get runOnFrame(): ((frame: Frame) => OCRDetection[]) | null { + if (!this.nativeModule?.generateFromFrame) { + return null; + } + + 
const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; + + return (frame: any): OCRDetection[] => { + 'worklet'; + + let nativeBuffer: any = null; + try { + nativeBuffer = frame.getNativeBuffer(); + const frameData = { + nativeBuffer: nativeBuffer.pointer, + }; + return nativeGenerateFromFrame(frameData); + } finally { + if (nativeBuffer?.release) { + nativeBuffer.release(); + } + } + }; + } + + public forward = async ( + input: string | PixelData + ): Promise => { if (!this.isReady) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -104,7 +145,17 @@ export abstract class BaseOCRController { try { this.isGenerating = true; this.isGeneratingCallback(this.isGenerating); - return await this.nativeModule.generate(imageSource); + + if (typeof input === 'string') { + return await this.nativeModule.generateFromString(input); + } else if (isPixelData(input)) { + return await this.nativeModule.generateFromPixels(input); + } else { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidArgument, + 'Invalid input: expected string path or PixelData object. For VisionCamera frames, use runOnFrame instead.' 
+ ); + } } catch (e) { throw parseUnknownError(e); } finally { diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts new file mode 100644 index 0000000000..55b8d85007 --- /dev/null +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -0,0 +1,131 @@ +import { useState, useEffect } from 'react'; +import { + ImageSegmentationModule, + SegmentationLabels, +} from '../../modules/computer_vision/ImageSegmentationModule'; +import { + ImageSegmentationProps, + ImageSegmentationType, + ModelNameOf, + ModelSources, +} from '../../types/imageSegmentation'; +import { Frame } from '../../types/common'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; + +/** + * React hook for managing an Image Segmentation model instance. + * + * @typeParam C - A {@link ModelSources} config specifying which built-in model to load. + * @param props - Configuration object containing `model` config and optional `preventLoad` flag. + * @returns An object with model state (`error`, `isReady`, `isGenerating`, `downloadProgress`) and a typed `forward` function. 
+ * + * @example + * ```ts + * const { isReady, forward } = useImageSegmentation({ + * model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + * }); + * ``` + * + * @category Hooks + */ +export const useImageSegmentation = ({ + model, + preventLoad = false, +}: ImageSegmentationProps): ImageSegmentationType< + SegmentationLabels> +> => { + const [error, setError] = useState(null); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + const [instance, setInstance] = useState + > | null>(null); + const [runOnFrame, setRunOnFrame] = useState< + | (( + frame: Frame, + classesOfInterest?: string[], + resizeToInput?: boolean + ) => any) + | null + >(null); + + useEffect(() => { + if (preventLoad) return; + + let isMounted = true; + let currentInstance: ImageSegmentationModule> | null = null; + + (async () => { + setDownloadProgress(0); + setError(null); + setIsReady(false); + try { + currentInstance = await ImageSegmentationModule.fromModelName( + model, + (progress) => { + if (isMounted) setDownloadProgress(progress); + } + ); + if (isMounted) { + setInstance(currentInstance); + setIsReady(true); + const worklet = currentInstance.runOnFrame; + if (worklet) { + setRunOnFrame(() => worklet); + } + } + } catch (err) { + if (isMounted) setError(parseUnknownError(err)); + } + })(); + + return () => { + isMounted = false; + setIsReady(false); + setRunOnFrame(null); + currentInstance?.delete(); + }; + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [model.modelName, model.modelSource, preventLoad]); + + const forward = async >>( + imageSource: string, + classesOfInterest: K[] = [], + resizeToInput: boolean = true + ) => { + if (!isReady || !instance) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling forward().' 
+ ); + } + if (isGenerating) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModelGenerating, + 'The model is currently generating. Please wait until previous model run is complete.' + ); + } + try { + setIsGenerating(true); + return await instance.forward( + imageSource, + classesOfInterest, + resizeToInput + ); + } finally { + setIsGenerating(false); + } + }; + + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; +}; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts index 473d3631ba..208824b8b8 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts @@ -1,4 +1,5 @@ import { useCallback, useEffect, useState } from 'react'; +import { Frame, PixelData } from '../../types/common'; import { OCRController } from '../../controllers/OCRController'; import { RnExecutorchError } from '../../errors/errorUtils'; import { OCRDetection, OCRProps, OCRType } from '../../types/ocr'; @@ -15,6 +16,9 @@ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); const [error, setError] = useState(null); + const [runOnFrame, setRunOnFrame] = useState< + ((frame: Frame) => OCRDetection[]) | null + >(null); const [controller] = useState( () => @@ -38,7 +42,13 @@ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { setDownloadProgress ); + const worklet = controller.runOnFrame; + if (worklet) { + setRunOnFrame(() => worklet); + } + return () => { + setRunOnFrame(null); if (controller.isReady) { controller.delete(); } @@ -53,10 +63,17 @@ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { ]); const forward = useCallback( - (imageSource: string): 
Promise => + (imageSource: string | PixelData): Promise => controller.forward(imageSource), [controller] ); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + forward, + downloadProgress, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts index 71774198fc..4f72c97d9c 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts @@ -1,4 +1,5 @@ import { useCallback, useEffect, useState } from 'react'; +import { Frame, PixelData } from '../../types/common'; import { VerticalOCRController } from '../../controllers/VerticalOCRController'; import { RnExecutorchError } from '../../errors/errorUtils'; import { OCRDetection, OCRType, VerticalOCRProps } from '../../types/ocr'; @@ -20,6 +21,10 @@ export const useVerticalOCR = ({ const [downloadProgress, setDownloadProgress] = useState(0); const [error, setError] = useState(null); + const [runOnFrame, setRunOnFrame] = useState< + ((frame: Frame) => OCRDetection[]) | null + >(null); + const [controller] = useState( () => new VerticalOCRController({ @@ -43,7 +48,13 @@ export const useVerticalOCR = ({ setDownloadProgress ); + const worklet = controller.runOnFrame; + if (worklet) { + setRunOnFrame(() => worklet); + } + return () => { + setRunOnFrame(null); if (controller.isReady) { controller.delete(); } @@ -59,10 +70,17 @@ export const useVerticalOCR = ({ ]); const forward = useCallback( - (imageSource: string): Promise => + (imageSource: string | PixelData): Promise => controller.forward(imageSource), [controller] ); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + forward, + downloadProgress, + runOnFrame, + }; }; diff --git 
a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index cc1fc1ef2e..632e94d3ea 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ b/packages/react-native-executorch/src/hooks/useModule.ts @@ -76,6 +76,8 @@ export const useModule = < return () => { isMounted = false; + setIsReady(false); + setRunOnFrame(null); moduleInstance.delete(); }; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts index 43691c2047..61d5c48c90 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts @@ -1,17 +1,19 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; +import { PixelData, ResourceSource } from '../../types/common'; import { ClassificationModelName } from '../../types/classification'; -import { BaseModule } from '../BaseModule'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { Logger } from '../../common/Logger'; +import { VisionModule } from './VisionModule'; /** * Module for image classification tasks. * * @category Typescript API */ -export class ClassificationModule extends BaseModule { +export class ClassificationModule extends VisionModule<{ + [category: string]: number; +}> { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; @@ -74,18 +76,9 @@ export class ClassificationModule extends BaseModule { ); } - /** - * Executes the model's forward pass to classify the provided image. - * - * @param imageSource - A string image source (file path, URI, or Base64). 
- * @returns A Promise resolving to an object mapping category labels to confidence scores. - */ - async forward(imageSource: string): Promise<{ [category: string]: number }> { - if (this.nativeModule == null) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - return await this.nativeModule.generate(imageSource); + async forward( + input: string | PixelData + ): Promise<{ [category: string]: number }> { + return super.forward(input); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts index dd81408e36..43fb79c645 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts @@ -1,22 +1,21 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; import { ImageEmbeddingsModelName } from '../../types/imageEmbeddings'; +import { ResourceSource, PixelData } from '../../types/common'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; -import { BaseModule } from '../BaseModule'; import { Logger } from '../../common/Logger'; +import { VisionModule } from './VisionModule'; /** * Module for generating image embeddings from input images. * * @category Typescript API */ -export class ImageEmbeddingsModule extends BaseModule { +export class ImageEmbeddingsModule extends VisionModule { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; } - /** * Creates an image embeddings instance for a built-in model. 
* @@ -74,18 +73,8 @@ export class ImageEmbeddingsModule extends BaseModule { ); } - /** - * Executes the model's forward pass to generate an embedding for the provided image. - * - * @param imageSource - A string image source (file path, URI, or Base64). - * @returns A Promise resolving to a `Float32Array` containing the image embedding vector. - */ - async forward(imageSource: string): Promise { - if (this.nativeModule == null) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - return new Float32Array(await this.nativeModule.generate(imageSource)); + async forward(input: string | PixelData): Promise { + const result = await super.forward(input); + return new Float32Array(result as unknown as ArrayBuffer); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index d24e988930..ffc7203d8b 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -62,6 +62,20 @@ export type SegmentationLabels = type ResolveLabels = ResolveLabelsFor; +function isPixelData(input: unknown): input is PixelData { + return ( + typeof input === 'object' && + input !== null && + 'dataPtr' in input && + (input as any).dataPtr instanceof Uint8Array && + 'sizes' in input && + Array.isArray((input as any).sizes) && + (input as any).sizes.length === 3 && + 'scalarType' in input && + (input as any).scalarType === ScalarType.BYTE + ); +} + /** * Generic semantic segmentation module with type-safe label maps. * Use a model name (e.g. 
`'deeplab-v3-resnet50'`) as the generic parameter for built-in models, @@ -84,6 +98,75 @@ export class SemanticSegmentationModule< super(labelMap, nativeModule); } + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded. + * + * @example + * ```typescript + * const [runOnFrame, setRunOnFrame] = useState(null); + * setRunOnFrame(() => segmentation.runOnFrame); + * + * const frameOutput = useFrameOutput({ + * onFrame(frame) { + * 'worklet'; + * if (!runOnFrame) return; + * const result = runOnFrame(frame, [], true); + * frame.dispose(); + * } + * }); + * ``` + * + * @param frame - VisionCamera Frame object + * @param classesOfInterest - Labels for which to return per-class probability masks. + * @param resizeToInput - Whether to resize masks to original frame dimensions. Defaults to `true`. + */ + get runOnFrame(): + | (( + frame: Frame, + classesOfInterest?: string[], + resizeToInput?: boolean + ) => any) + | null { + if (!this.nativeModule?.generateFromFrame) { + return null; + } + + const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; + const allClassNames = this.allClassNames; + + return ( + frame: any, + classesOfInterest: string[] = [], + resizeToInput: boolean = true + ): any => { + 'worklet'; + + let nativeBuffer: any = null; + try { + nativeBuffer = frame.getNativeBuffer(); + const frameData = { + nativeBuffer: nativeBuffer.pointer, + }; + return nativeGenerateFromFrame( + frameData, + allClassNames, + classesOfInterest, + resizeToInput + ); + } finally { + if (nativeBuffer?.release) { + nativeBuffer.release(); + } + } + }; + } + /** * Creates a segmentation instance for a built-in model. * The config object is discriminated by `modelName` — each model can require different fields. 
@@ -184,14 +267,20 @@ export class SemanticSegmentationModule< /** * Executes the model's forward pass to perform semantic segmentation on the provided image. * - * @param imageSource - A string representing the image source (e.g., a file path, URI, or Base64-encoded string). + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded. 
*/ async forward>( - imageSource: string, + input: string | PixelData, classesOfInterest: K[] = [], resizeToInput: boolean = true ): Promise & Record> { diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 6027d1fd2d..4f5e82c4ec 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -1,22 +1,21 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; import { StyleTransferModelName } from '../../types/styleTransfer'; +import { ResourceSource, PixelData } from '../../types/common'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; -import { BaseModule } from '../BaseModule'; import { Logger } from '../../common/Logger'; +import { VisionModule } from './VisionModule'; /** * Module for style transfer tasks. * * @category Typescript API */ -export class StyleTransferModule extends BaseModule { +export class StyleTransferModule extends VisionModule { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; } - /** * Creates a style transfer instance for a built-in model. * @@ -72,18 +71,7 @@ export class StyleTransferModule extends BaseModule { ); } - /** - * Executes the model's forward pass to apply the selected style to the provided image. - * - * @param imageSource - A string image source (file path, URI, or Base64). - * @returns A Promise resolving to the stylized image as a Base64-encoded string. - */ - async forward(imageSource: string): Promise { - if (this.nativeModule == null) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. 
Please load the model before calling forward().' - ); - return await this.nativeModule.generate(imageSource); + async forward(input: string | PixelData): Promise { + return super.forward(input); } } diff --git a/packages/react-native-executorch/src/types/classification.ts b/packages/react-native-executorch/src/types/classification.ts index 144f2af5ae..d38c664493 100644 --- a/packages/react-native-executorch/src/types/classification.ts +++ b/packages/react-native-executorch/src/types/classification.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { ResourceSource, PixelData, Frame } from './common'; /** * Union of all built-in classification model names. @@ -53,9 +53,46 @@ export interface ClassificationType { /** * Executes the model's forward pass to classify the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be classified. - * @returns A Promise that resolves to the classification result (typically containing labels and confidence scores). + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) + * @returns A Promise that resolves to the classification result (labels and confidence scores). * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ - forward: (imageSource: string) => Promise<{ [category: string]: number }>; + forward: ( + input: string | PixelData + ) => Promise<{ [category: string]: number }>; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. 
+ * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @example + * ```typescript + * const { runOnFrame, isReady } = useClassification({ model: MODEL }); + * + * const frameOutput = useFrameOutput({ + * onFrame(frame) { + * 'worklet'; + * if (!runOnFrame) return; + * const result = runOnFrame(frame); + * frame.dispose(); + * } + * }); + * ``` + * + * @param frame - VisionCamera Frame object + * @returns Object mapping class labels to confidence scores. + */ + runOnFrame: ((frame: Frame) => { [category: string]: number }) | null; } diff --git a/packages/react-native-executorch/src/types/imageEmbeddings.ts b/packages/react-native-executorch/src/types/imageEmbeddings.ts index 88308ddd6f..2963639c26 100644 --- a/packages/react-native-executorch/src/types/imageEmbeddings.ts +++ b/packages/react-native-executorch/src/types/imageEmbeddings.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { ResourceSource, PixelData, Frame } from './common'; /** * Union of all built-in image embeddings model names. @@ -53,9 +53,30 @@ export interface ImageEmbeddingsType { /** * Executes the model's forward pass to generate embeddings (a feature vector) for the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) * @returns A Promise that resolves to a `Float32Array` containing the generated embedding vector. 
* @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ - forward: (imageSource: string) => Promise; + forward: (input: string | PixelData) => Promise; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns Float32Array containing the embedding vector for the frame. + */ + runOnFrame: ((frame: Frame) => Float32Array) | null; } diff --git a/packages/react-native-executorch/src/types/ocr.ts b/packages/react-native-executorch/src/types/ocr.ts index e31d618478..b4cd8c8b69 100644 --- a/packages/react-native-executorch/src/types/ocr.ts +++ b/packages/react-native-executorch/src/types/ocr.ts @@ -1,6 +1,6 @@ import { symbols } from '../constants/ocr/symbols'; import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { Frame, PixelData, ResourceSource } from './common'; /** * OCRDetection represents a single detected text instance in an image, @@ -110,11 +110,35 @@ export interface OCRType { /** * Executes the OCR pipeline (detection and recognition) on the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. - * @returns A Promise that resolves to the OCR results (typically containing the recognized text strings and their bounding boxes). + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. 
+ * + * @param input - Image source (string or PixelData object) + * @returns A Promise that resolves to the OCR results (recognized text and bounding boxes). * @throws {RnExecutorchError} If the models are not loaded or are currently processing another image. */ - forward: (imageSource: string) => Promise; + forward: (input: string | PixelData) => Promise; + + /** + * Synchronous worklet function for VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * **Note**: OCR is a two-stage pipeline (detection + recognition) and may not + * achieve real-time frame rates. Frames may be dropped if inference is still running. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns Array of OCRDetection results for the frame. + */ + runOnFrame: ((frame: Frame) => OCRDetection[]) | null; } /** diff --git a/packages/react-native-executorch/src/types/semanticSegmentation.ts b/packages/react-native-executorch/src/types/semanticSegmentation.ts index 109f6f094e..10784721b9 100644 --- a/packages/react-native-executorch/src/types/semanticSegmentation.ts +++ b/packages/react-native-executorch/src/types/semanticSegmentation.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { LabelEnum, Triple, ResourceSource } from './common'; +import { LabelEnum, Triple, ResourceSource, PixelData, Frame } from './common'; /** * Configuration for a custom semantic segmentation model. @@ -148,15 +148,44 @@ export interface SemanticSegmentationType { /** * Executes the model's forward pass to perform semantic segmentation on the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. + * + * Supports two input types: + * 1. 
**String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) * @param classesOfInterest - An optional array of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. * @returns A Promise resolving to an object with an `'ARGMAX'` `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ forward: ( - imageSource: string, + input: string | PixelData, classesOfInterest?: K[], resizeToInput?: boolean ) => Promise & Record>; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @param classesOfInterest - Labels for which to return per-class probability masks. + * @param resizeToInput - Whether to resize masks to original frame dimensions. Defaults to `true`. + * @returns Object with `ARGMAX` Int32Array and per-class Float32Array masks. 
+ */ + runOnFrame: + | (( + frame: Frame, + classesOfInterest?: string[], + resizeToInput?: boolean + ) => Record<'ARGMAX', Int32Array> & Record) + | null; } diff --git a/packages/react-native-executorch/src/types/styleTransfer.ts b/packages/react-native-executorch/src/types/styleTransfer.ts index 2571203ee1..6d48bfbab3 100644 --- a/packages/react-native-executorch/src/types/styleTransfer.ts +++ b/packages/react-native-executorch/src/types/styleTransfer.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { ResourceSource, PixelData, Frame } from './common'; /** * Union of all built-in style transfer model names. @@ -59,9 +59,30 @@ export interface StyleTransferType { /** * Executes the model's forward pass to apply the specific artistic style to the provided image. - * @param imageSource - A string representing the input image source (e.g., a file path, URI, or base64 string) to be stylized. - * @returns A Promise that resolves to a string containing the stylized image (typically as a base64 string or a file URI). + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) + * @returns A Promise that resolves to `PixelData` containing the stylized image as raw RGB pixel data. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ - forward: (imageSource: string) => Promise; + forward: (input: string | PixelData) => Promise; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. 
+ * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns PixelData containing the stylized frame as raw RGB pixel data. + */ + runOnFrame: ((frame: Frame) => PixelData) | null; } From a2fb20602bb5b3dffdb52d63204caa383cb2e8e9 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 25 Feb 2026 16:54:40 +0100 Subject: [PATCH 25/71] fix: rebase things --- .../app/object_detection/index.tsx | 168 +----------------- .../computer_vision/useImageSegmentation.ts | 4 +- .../computer_vision/ObjectDetectionModule.ts | 7 + 3 files changed, 16 insertions(+), 163 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 5dad96e0bd..a5e36c344a 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -1,72 +1,16 @@ import Spinner from '../../components/Spinner'; +import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; import { Detection, useObjectDetection, RF_DETR_NANO, } from 'react-native-executorch'; -import { View, StyleSheet, Image, TouchableOpacity, Text } from 'react-native'; +import { View, StyleSheet, Image } from 'react-native'; import ImageWithBboxes from '../../components/ImageWithBboxes'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; -import ColorPalette from '../../colors'; -import { Images } from 'react-native-nitro-image'; - -// Helper function to convert BGRA to RGB -function convertBGRAtoRGB( - buffer: ArrayBuffer, - width: number, - height: number -): ArrayBuffer { - const source = new Uint8Array(buffer); - const rgb = new Uint8Array(width * height * 3); - - for (let i = 0; i 
< width * height; i++) { - // BGRA format: [B, G, R, A] → RGB: [R, G, B] - rgb[i * 3 + 0] = source[i * 4 + 2]; // R - rgb[i * 3 + 1] = source[i * 4 + 1]; // G - rgb[i * 3 + 2] = source[i * 4 + 0]; // B - } - - return rgb.buffer; -} - -// Helper function to convert image URI to raw RGB pixel data -async function imageUriToPixelData( - uri: string, - targetWidth: number, - targetHeight: number -): Promise<{ - data: ArrayBuffer; - width: number; - height: number; - channels: number; -}> { - try { - // Load image and resize to target dimensions - const image = await Images.loadFromFileAsync(uri); - const resized = image.resize(targetWidth, targetHeight); - - // Get pixel data as ArrayBuffer (BGRA format from NitroImage) - const rawPixelData = resized.toRawPixelData(); - const buffer = - rawPixelData instanceof ArrayBuffer ? rawPixelData : rawPixelData.buffer; - - // Convert BGRA to RGB as required by the native API - const rgbBuffer = convertBGRAtoRGB(buffer, targetWidth, targetHeight); - - return { - data: rgbBuffer, - width: targetWidth, - height: targetHeight, - channels: 3, // RGB - }; - } catch (error) { - console.error('Error loading image with NitroImage:', error); - throw error; - } -} export default function ObjectDetectionScreen() { const [imageUri, setImageUri] = useState(''); @@ -101,35 +45,7 @@ export default function ObjectDetectionScreen() { const output = await rfDetr.forward(imageUri); setResults(output); } catch (e) { - console.error('Error in runForward:', e); - } - } - }; - - const runForwardPixels = async () => { - if (imageUri && imageDimensions) { - try { - console.log('Converting image to pixel data...'); - // Use original dimensions - let the model resize internally - const pixelData = await imageUriToPixelData( - imageUri, - imageDimensions.width, - imageDimensions.height - ); - - console.log('Running forward with pixel data...', { - width: pixelData.width, - height: pixelData.height, - channels: pixelData.channels, - dataSize: 
pixelData.data.byteLength, - }); - - // Run inference using unified forward() API - const output = await ssdLite.forward(pixelData, 0.3); - console.log('Pixel data result:', output.length, 'detections'); - setResults(output); - } catch (e) { - console.error('Error in runForwardPixels:', e); + console.error(e); } } }; @@ -208,41 +124,10 @@ export default function ObjectDetectionScreen() { )} - - {/* Custom bottom bar with two buttons */} - - - handleCameraPress(false)}> - 📷 Gallery - - - - - - Run (String) - - - - Run (Pixels) - - - + ); } @@ -287,43 +172,4 @@ const styles = StyleSheet.create({ width: '100%', height: '100%', }, - bottomContainer: { - width: '100%', - gap: 15, - alignItems: 'center', - padding: 16, - flex: 1, - }, - bottomIconsContainer: { - flexDirection: 'row', - justifyContent: 'center', - width: '100%', - }, - iconText: { - fontSize: 16, - color: ColorPalette.primary, - }, - buttonsRow: { - flexDirection: 'row', - width: '100%', - gap: 10, - }, - button: { - height: 50, - justifyContent: 'center', - alignItems: 'center', - backgroundColor: ColorPalette.primary, - color: '#fff', - borderRadius: 8, - }, - halfButton: { - flex: 1, - }, - buttonDisabled: { - opacity: 0.5, - }, - buttonText: { - color: '#fff', - fontSize: 16, - }, }); diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index 55b8d85007..26a8042274 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -9,7 +9,7 @@ import { ModelNameOf, ModelSources, } from '../../types/imageSegmentation'; -import { Frame } from '../../types/common'; +import { Frame, PixelData } from '../../types/common'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from 
'../../errors/errorUtils'; @@ -92,7 +92,7 @@ export const useImageSegmentation = ({ }, [model.modelName, model.modelSource, preventLoad]); const forward = async >>( - imageSource: string, + imageSource: string | PixelData, classesOfInterest: K[] = [], resizeToInput: boolean = true ) => { diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index c24bbd1369..bbb990f7b8 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -169,4 +169,11 @@ export class ObjectDetectionModule< nativeModule ); } + + async forward( + input: string | PixelData, + detectionThreshold: number = 0.5 + ): Promise { + return super.forward(input, detectionThreshold); + } } From 13a46e1fe352d776dbb3a4e6887f1961522d7a68 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 25 Feb 2026 19:22:16 +0100 Subject: [PATCH 26/71] chore: remove comment --- apps/computer-vision/app/vision_camera_live/index.tsx | 2 -- 1 file changed, 2 deletions(-) diff --git a/apps/computer-vision/app/vision_camera_live/index.tsx b/apps/computer-vision/app/vision_camera_live/index.tsx index 4c7b425b18..8c5d71d331 100644 --- a/apps/computer-vision/app/vision_camera_live/index.tsx +++ b/apps/computer-vision/app/vision_camera_live/index.tsx @@ -71,8 +71,6 @@ const MODELS: { id: ModelId; label: string }[] = [ { id: 'ocr', label: 'OCR' }, ]; -// ─── Segmentation colors ───────────────────────────────────────────────────── - const CLASS_COLORS: number[][] = [ [0, 0, 0, 0], [51, 255, 87, 180], From 7bcc115226bbf5f25b25eeb8db006d4140dbd782 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 26 Feb 2026 11:41:13 +0100 Subject: [PATCH 27/71] feat: add dedicated vision camera screen showcasing classification/segmentation/object detection --- 
apps/computer-vision/app/_layout.tsx | 57 +- .../app/classification_live/index.tsx | 255 ------ .../app/image_segmentation_live/index.tsx | 292 ------- apps/computer-vision/app/index.tsx | 12 +- .../app/object_detection_live/index.tsx | 298 ------- apps/computer-vision/app/ocr_live/index.tsx | 329 -------- .../app/style_transfer_live/index.tsx | 274 ------ .../app/vision_camera/index.tsx | 665 +++++++++++++++ .../app/vision_camera_live/index.tsx | 796 ------------------ 9 files changed, 680 insertions(+), 2298 deletions(-) delete mode 100644 apps/computer-vision/app/classification_live/index.tsx delete mode 100644 apps/computer-vision/app/image_segmentation_live/index.tsx delete mode 100644 apps/computer-vision/app/object_detection_live/index.tsx delete mode 100644 apps/computer-vision/app/ocr_live/index.tsx delete mode 100644 apps/computer-vision/app/style_transfer_live/index.tsx create mode 100644 apps/computer-vision/app/vision_camera/index.tsx delete mode 100644 apps/computer-vision/app/vision_camera_live/index.tsx diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 3c7fa38ba2..cac2692f07 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -59,6 +59,15 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} > + - - - - - - { - setGlobalGenerating(isGenerating); - }, [isGenerating, setGlobalGenerating]); - - const [topLabel, setTopLabel] = useState(''); - const [topScore, setTopScore] = useState(0); - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? 
devices[0]; - - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - const updateStats = useCallback( - (result: { label: string; score: number }) => { - setTopLabel(result.label); - setTopScore(result.score); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, - [] - ); - - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - onFrame(frame) { - 'worklet'; - if (!runOnFrame) { - frame.dispose(); - return; - } - try { - const result = runOnFrame(frame); - if (result) { - // find the top-1 entry - let bestLabel = ''; - let bestScore = -1; - const entries = Object.entries(result); - for (let i = 0; i < entries.length; i++) { - const [label, score] = entries[i]; - if ((score as number) > bestScore) { - bestScore = score as number; - bestLabel = label; - } - } - scheduleOnRN(updateStats, { label: bestLabel, score: bestScore }); - } - } catch { - // ignore frame errors - } finally { - frame.dispose(); - } - }, - }); - - if (!isReady) { - return ( - - ); - } - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - return ( - - - - - - - - - - {topLabel || '—'} - - - {topLabel ? 
(topScore * 100).toFixed(1) + '%' : ''} - - - - - {fps} - fps - - - - - ); -} - -const styles = StyleSheet.create({ - container: { - flex: 1, - backgroundColor: 'black', - }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { - color: 'white', - fontSize: 18, - }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - buttonText: { - color: 'white', - fontSize: 15, - fontWeight: '600', - letterSpacing: 0.3, - }, - bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - paddingHorizontal: 16, - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0, 0, 0, 0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - maxWidth: '100%', - }, - labelContainer: { - flex: 1, - alignItems: 'flex-start', - }, - labelText: { - color: 'white', - fontSize: 16, - fontWeight: '700', - }, - scoreText: { - color: 'rgba(255,255,255,0.7)', - fontSize: 13, - fontWeight: '500', - }, - statItem: { - alignItems: 'center', - }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - textTransform: 'uppercase', - letterSpacing: 0.8, - }, - statDivider: { - width: 1, - height: 32, - backgroundColor: 'rgba(255,255,255,0.2)', - }, -}); diff --git a/apps/computer-vision/app/image_segmentation_live/index.tsx b/apps/computer-vision/app/image_segmentation_live/index.tsx deleted file mode 100644 index f665c63c59..0000000000 --- a/apps/computer-vision/app/image_segmentation_live/index.tsx +++ /dev/null @@ -1,292 +0,0 @@ -import React, { - useCallback, - useContext, - useEffect, - useMemo, - useRef, - useState, -} from 'react'; -import { - StatusBar, - StyleSheet, - Text, - TouchableOpacity, - 
useWindowDimensions, - View, -} from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; - -import { - Camera, - getCameraFormat, - Templates, - useCameraDevices, - useCameraPermission, - useFrameOutput, -} from 'react-native-vision-camera'; -import { scheduleOnRN } from 'react-native-worklets'; -import { - DEEPLAB_V3_RESNET50, - useImageSegmentation, -} from 'react-native-executorch'; -import { - Canvas, - Image as SkiaImage, - Skia, - AlphaType, - ColorType, - SkImage, -} from '@shopify/react-native-skia'; -import { GeneratingContext } from '../../context'; -import Spinner from '../../components/Spinner'; -import ColorPalette from '../../colors'; - -// RGBA colors for each DeepLab V3 class (alpha = 180 for semi-transparency) -const CLASS_COLORS: number[][] = [ - [0, 0, 0, 0], // 0 background — transparent - [51, 255, 87, 180], // 1 aeroplane - [51, 87, 255, 180], // 2 bicycle - [255, 51, 246, 180], // 3 bird - [51, 255, 246, 180], // 4 boat - [243, 255, 51, 180], // 5 bottle - [141, 51, 255, 180], // 6 bus - [255, 131, 51, 180], // 7 car - [51, 255, 131, 180], // 8 cat - [131, 51, 255, 180], // 9 chair - [255, 255, 51, 180], // 10 cow - [51, 255, 255, 180], // 11 diningtable - [255, 51, 143, 180], // 12 dog - [127, 51, 255, 180], // 13 horse - [51, 255, 175, 180], // 14 motorbike - [255, 175, 51, 180], // 15 person - [179, 255, 51, 180], // 16 pottedplant - [255, 87, 51, 180], // 17 sheep - [255, 51, 162, 180], // 18 sofa - [51, 162, 255, 180], // 19 train - [162, 51, 255, 180], // 20 tvmonitor -]; - -export default function ImageSegmentationLiveScreen() { - const insets = useSafeAreaInsets(); - const { width: screenWidth, height: screenHeight } = useWindowDimensions(); - - const { isReady, isGenerating, downloadProgress, runOnFrame } = - useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); - const { setGlobalGenerating } = useContext(GeneratingContext); - - useEffect(() => { - setGlobalGenerating(isGenerating); - }, 
[isGenerating, setGlobalGenerating]); - - const [maskImage, setMaskImage] = useState(null); - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? devices[0]; - - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - const updateMask = useCallback((img: SkImage) => { - setMaskImage(img); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, []); - - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - dropFramesWhileBusy: true, - onFrame(frame) { - 'worklet'; - if (!runOnFrame) { - frame.dispose(); - return; - } - try { - const result = runOnFrame(frame, [], false); - if (result?.ARGMAX) { - const argmax: Int32Array = result.ARGMAX; - // Model output is always square (modelImageSize × modelImageSize). - // Derive width/height from argmax length (sqrt for square output). - const side = Math.round(Math.sqrt(argmax.length)); - const width = side; - const height = side; - - // Build RGBA pixel buffer on the worklet thread to avoid transferring - // the large Int32Array across the worklet→RN boundary via scheduleOnRN. - const pixels = new Uint8Array(width * height * 4); - for (let i = 0; i < argmax.length; i++) { - const color = CLASS_COLORS[argmax[i]] ?? 
[0, 0, 0, 0]; - pixels[i * 4] = color[0]!; - pixels[i * 4 + 1] = color[1]!; - pixels[i * 4 + 2] = color[2]!; - pixels[i * 4 + 3] = color[3]!; - } - - const skData = Skia.Data.fromBytes(pixels); - const img = Skia.Image.MakeImage( - { - width, - height, - alphaType: AlphaType.Unpremul, - colorType: ColorType.RGBA_8888, - }, - skData, - width * 4 - ); - if (img) { - scheduleOnRN(updateMask, img); - } - } - } catch (e) { - console.log('frame error:', String(e)); - } finally { - frame.dispose(); - } - }, - }); - - if (!isReady) { - return ( - - ); - } - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - return ( - - - - - - {maskImage && ( - - - - )} - - - - - {fps} - fps - - - - - ); -} - -const styles = StyleSheet.create({ - container: { - flex: 1, - backgroundColor: 'black', - }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { - color: 'white', - fontSize: 18, - }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - buttonText: { - color: 'white', - fontSize: 15, - fontWeight: '600', - letterSpacing: 0.3, - }, - bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0, 0, 0, 0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - }, - statItem: { - alignItems: 'center', - }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - textTransform: 'uppercase', - letterSpacing: 0.8, - }, -}); diff --git 
a/apps/computer-vision/app/index.tsx b/apps/computer-vision/app/index.tsx index e2fbb6e023..9fbfd4f3ac 100644 --- a/apps/computer-vision/app/index.tsx +++ b/apps/computer-vision/app/index.tsx @@ -11,6 +11,12 @@ export default function Home() { Select a demo model + router.navigate('vision_camera/')} + > + Vision Camera + router.navigate('classification/')} @@ -29,12 +35,6 @@ export default function Home() { > Object Detection - router.navigate('object_detection_live/')} - > - Object Detection Live - router.navigate('ocr/')} diff --git a/apps/computer-vision/app/object_detection_live/index.tsx b/apps/computer-vision/app/object_detection_live/index.tsx deleted file mode 100644 index b4210b0541..0000000000 --- a/apps/computer-vision/app/object_detection_live/index.tsx +++ /dev/null @@ -1,298 +0,0 @@ -import React, { - useCallback, - useContext, - useEffect, - useMemo, - useRef, - useState, -} from 'react'; -import { - StatusBar, - StyleSheet, - Text, - TouchableOpacity, - View, -} from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; - -import { - Camera, - getCameraFormat, - Templates, - useCameraDevices, - useCameraPermission, - useFrameOutput, -} from 'react-native-vision-camera'; -import { scheduleOnRN } from 'react-native-worklets'; -import { - Detection, - SSDLITE_320_MOBILENET_V3_LARGE, - useObjectDetection, -} from 'react-native-executorch'; -import { GeneratingContext } from '../../context'; -import Spinner from '../../components/Spinner'; -import ColorPalette from '../../colors'; - -export default function ObjectDetectionLiveScreen() { - const insets = useSafeAreaInsets(); - const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); - - const model = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE }); - const { setGlobalGenerating } = useContext(GeneratingContext); - - useEffect(() => { - setGlobalGenerating(model.isGenerating); - }, [model.isGenerating, setGlobalGenerating]); - const 
[detectionCount, setDetectionCount] = useState(0); - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? devices[0]; - - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - const updateDetections = useCallback( - (payload: { - results: Detection[]; - imageWidth: number; - imageHeight: number; - }) => { - setDetections(payload.results); - setImageSize({ width: payload.imageWidth, height: payload.imageHeight }); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, - [] - ); - - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - dropFramesWhileBusy: true, - onFrame(frame) { - 'worklet'; - if (!model.runOnFrame) { - frame.dispose(); - return; - } - // After 90° CW rotation, the image fed to the model has swapped dims. - const imageWidth = - frame.width > frame.height ? frame.height : frame.width; - const imageHeight = - frame.width > frame.height ? 
frame.width : frame.height; - try { - const result = model.runOnFrame(frame, 0.5); - if (result) { - scheduleOnRN(updateDetections, { - results: result, - imageWidth, - imageHeight, - }); - } - } catch { - // ignore frame errors - } finally { - frame.dispose(); - } - }, - }); - - if (!model.isReady) { - return ( - - ); - } - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - return ( - - - - - - {/* Bounding box overlay — measured to match the exact camera preview area */} - - setCanvasSize({ - width: e.nativeEvent.layout.width, - height: e.nativeEvent.layout.height, - }) - } - > - {(() => { - // Cover-fit: camera preview scales to fill the canvas, cropping the - // excess. Compute the same transform so bbox pixel coords map correctly. - const scale = Math.max( - canvasSize.width / imageSize.width, - canvasSize.height / imageSize.height - ); - const offsetX = (canvasSize.width - imageSize.width * scale) / 2; - const offsetY = (canvasSize.height - imageSize.height * scale) / 2; - return detections.map((det, i) => { - const left = det.bbox.x1 * scale + offsetX; - const top = det.bbox.y1 * scale + offsetY; - const width = (det.bbox.x2 - det.bbox.x1) * scale; - const height = (det.bbox.y2 - det.bbox.y1) * scale; - return ( - - - - {det.label} {(det.score * 100).toFixed(0)}% - - - - ); - }); - })()} - - - - - - {detections.length} - objects - - - - {fps} - fps - - - - - ); -} - -const styles = StyleSheet.create({ - container: { - flex: 1, - backgroundColor: 'black', - }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { - color: 'white', - fontSize: 18, - }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - 
buttonText: { - color: 'white', - fontSize: 15, - fontWeight: '600', - letterSpacing: 0.3, - }, - bbox: { - position: 'absolute', - borderWidth: 2, - borderColor: ColorPalette.primary, - borderRadius: 4, - }, - bboxLabel: { - position: 'absolute', - top: -22, - left: -2, - backgroundColor: ColorPalette.primary, - paddingHorizontal: 6, - paddingVertical: 2, - borderRadius: 4, - }, - bboxLabelText: { - color: 'white', - fontSize: 11, - fontWeight: '600', - }, - bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0, 0, 0, 0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - }, - statItem: { - alignItems: 'center', - }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - textTransform: 'uppercase', - letterSpacing: 0.8, - }, - statDivider: { - width: 1, - height: 32, - backgroundColor: 'rgba(255,255,255,0.2)', - }, -}); diff --git a/apps/computer-vision/app/ocr_live/index.tsx b/apps/computer-vision/app/ocr_live/index.tsx deleted file mode 100644 index a0c93899f6..0000000000 --- a/apps/computer-vision/app/ocr_live/index.tsx +++ /dev/null @@ -1,329 +0,0 @@ -import React, { - useCallback, - useContext, - useEffect, - useMemo, - useRef, - useState, -} from 'react'; -import { - StatusBar, - StyleSheet, - Text, - TouchableOpacity, - View, -} from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; - -import { - Camera, - getCameraFormat, - Templates, - useCameraDevices, - useCameraPermission, - useFrameOutput, -} from 'react-native-vision-camera'; -import { scheduleOnRN } from 'react-native-worklets'; -import { OCR_ENGLISH, useOCR, OCRDetection } from 'react-native-executorch'; -import { - Canvas, - Path, - Skia, - Text as SkiaText, 
- matchFont, -} from '@shopify/react-native-skia'; -import { GeneratingContext } from '../../context'; -import Spinner from '../../components/Spinner'; -import ColorPalette from '../../colors'; - -interface FrameDetections { - detections: OCRDetection[]; - frameWidth: number; - frameHeight: number; -} - -export default function OCRLiveScreen() { - const insets = useSafeAreaInsets(); - const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); - - const { isReady, isGenerating, downloadProgress, runOnFrame } = useOCR({ - model: OCR_ENGLISH, - }); - const { setGlobalGenerating } = useContext(GeneratingContext); - - useEffect(() => { - setGlobalGenerating(isGenerating); - }, [isGenerating, setGlobalGenerating]); - - const [frameDetections, setFrameDetections] = useState({ - detections: [], - frameWidth: 1, - frameHeight: 1, - }); - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - - const font = matchFont({ fontFamily: 'Helvetica', fontSize: 11 }); - - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? 
devices[0]; - - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - const updateDetections = useCallback((result: FrameDetections) => { - setFrameDetections(result); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, []); - - const frameOutput = useFrameOutput({ - dropFramesWhileBusy: true, - pixelFormat: 'rgb', - onFrame(frame) { - 'worklet'; - if (!runOnFrame) { - frame.dispose(); - return; - } - const frameWidth = frame.width; - const frameHeight = frame.height; - try { - const result = runOnFrame(frame); - if (result) { - scheduleOnRN(updateDetections, { - detections: result, - frameWidth, - frameHeight, - }); - } - } catch { - // ignore frame errors - } finally { - frame.dispose(); - } - }, - }); - - if (!isReady) { - return ( - - ); - } - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - const { detections, frameWidth, frameHeight } = frameDetections; - - // OCR runs on the raw landscape frame (no rotation applied in native). - // The camera preview displays it as portrait (90° CW rotation applied by iOS). - // After rotation the image dimensions become (frameHeight × frameWidth). - // Cover-fit scale uses post-rotation dims to match what the preview shows. - const isLandscape = frameWidth > frameHeight; - const imageW = isLandscape ? frameHeight : frameWidth; - const imageH = isLandscape ? 
frameWidth : frameHeight; - const scale = Math.max(canvasSize.width / imageW, canvasSize.height / imageH); - const offsetX = (canvasSize.width - imageW * scale) / 2; - const offsetY = (canvasSize.height - imageH * scale) / 2; - - // Map a raw landscape point to screen coords accounting for rotation + cover-fit. - function toScreenX(px: number, py: number) { - // After 90° CW: rotated_x = frameHeight - py, rotated_y = px - const rx = isLandscape ? frameHeight - py : px; - return rx * scale + offsetX; - } - function toScreenY(px: number, py: number) { - const ry = isLandscape ? px : py; - return ry * scale + offsetY; - } - - return ( - - - - - - {/* Measure the overlay area, then draw polygons inside a Canvas */} - - setCanvasSize({ - width: e.nativeEvent.layout.width, - height: e.nativeEvent.layout.height, - }) - } - > - - {detections.map((det, i) => { - if (!det.bbox || det.bbox.length < 2) return null; - - const path = Skia.Path.Make(); - path.moveTo( - toScreenX(det.bbox[0]!.x, det.bbox[0]!.y), - toScreenY(det.bbox[0]!.x, det.bbox[0]!.y) - ); - for (let j = 1; j < det.bbox.length; j++) { - path.lineTo( - toScreenX(det.bbox[j]!.x, det.bbox[j]!.y), - toScreenY(det.bbox[j]!.x, det.bbox[j]!.y) - ); - } - path.close(); - - const labelX = toScreenX(det.bbox[0]!.x, det.bbox[0]!.y); - const labelY = Math.max( - 0, - toScreenY(det.bbox[0]!.x, det.bbox[0]!.y) - 4 - ); - - return ( - - - - {font && ( - - )} - - ); - })} - - - - - - - {detections.length} - regions - - - - {fps} - fps - - - - - ); -} - -const styles = StyleSheet.create({ - container: { - flex: 1, - backgroundColor: 'black', - }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { - color: 'white', - fontSize: 18, - }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - buttonText: { - color: 'white', - fontSize: 15, - fontWeight: '600', - letterSpacing: 
0.3, - }, - bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0, 0, 0, 0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - }, - statItem: { - alignItems: 'center', - }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - textTransform: 'uppercase', - letterSpacing: 0.8, - }, - statDivider: { - width: 1, - height: 32, - backgroundColor: 'rgba(255,255,255,0.2)', - }, -}); diff --git a/apps/computer-vision/app/style_transfer_live/index.tsx b/apps/computer-vision/app/style_transfer_live/index.tsx deleted file mode 100644 index 57889313f8..0000000000 --- a/apps/computer-vision/app/style_transfer_live/index.tsx +++ /dev/null @@ -1,274 +0,0 @@ -import React, { - useCallback, - useContext, - useEffect, - useMemo, - useRef, - useState, -} from 'react'; -import { - StatusBar, - StyleSheet, - Text, - TouchableOpacity, - useWindowDimensions, - View, -} from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; - -import { - Camera, - getCameraFormat, - Templates, - useCameraDevices, - useCameraPermission, - useFrameOutput, -} from 'react-native-vision-camera'; -import { scheduleOnRN } from 'react-native-worklets'; -import { - STYLE_TRANSFER_RAIN_PRINCESS, - useStyleTransfer, -} from 'react-native-executorch'; -import { - Canvas, - Image as SkiaImage, - Skia, - AlphaType, - ColorType, - SkImage, -} from '@shopify/react-native-skia'; -import { GeneratingContext } from '../../context'; -import Spinner from '../../components/Spinner'; -import ColorPalette from '../../colors'; - -export default function StyleTransferLiveScreen() { - const insets = useSafeAreaInsets(); - const { width: screenWidth, height: screenHeight } = 
useWindowDimensions(); - - const { isReady, isGenerating, downloadProgress, runOnFrame } = - useStyleTransfer({ model: STYLE_TRANSFER_RAIN_PRINCESS }); - const { setGlobalGenerating } = useContext(GeneratingContext); - - useEffect(() => { - setGlobalGenerating(isGenerating); - }, [isGenerating, setGlobalGenerating]); - - const [styledImage, setStyledImage] = useState(null); - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? devices[0]; - - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - const updateImage = useCallback((img: SkImage) => { - setStyledImage((prev) => { - prev?.dispose(); - return img; - }); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, []); - - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - dropFramesWhileBusy: true, - onFrame(frame) { - 'worklet'; - if (!runOnFrame) { - frame.dispose(); - return; - } - try { - const result = runOnFrame(frame); - if (result?.dataPtr) { - const { dataPtr, sizes } = result; - const height = sizes[0]; - const width = sizes[1]; - // Build Skia image on the worklet thread — avoids transferring the - // large pixel buffer across the worklet→RN boundary via scheduleOnRN. 
- const skData = Skia.Data.fromBytes(dataPtr); - const img = Skia.Image.MakeImage( - { - width, - height, - alphaType: AlphaType.Opaque, - colorType: ColorType.RGBA_8888, - }, - skData, - width * 4 - ); - if (img) { - scheduleOnRN(updateImage, img); - } - } - } catch (e) { - console.log('frame error:', String(e)); - } finally { - frame.dispose(); - } - }, - }); - - if (!isReady) { - return ( - - ); - } - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - return ( - - - - {/* Camera always runs to keep frame processing active */} - - - {/* Styled output overlays the camera feed once available */} - {styledImage && ( - - - - )} - - - - - {fps} - fps - - - - candy - style - - - - - ); -} - -const styles = StyleSheet.create({ - container: { - flex: 1, - backgroundColor: 'black', - }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { - color: 'white', - fontSize: 18, - }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - buttonText: { - color: 'white', - fontSize: 15, - fontWeight: '600', - letterSpacing: 0.3, - }, - bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0, 0, 0, 0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - }, - statItem: { - alignItems: 'center', - }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - styleLabel: { - color: 'white', - fontSize: 16, - fontWeight: '700', - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - 
textTransform: 'uppercase', - letterSpacing: 0.8, - }, - statDivider: { - width: 1, - height: 32, - backgroundColor: 'rgba(255,255,255,0.2)', - }, -}); diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx new file mode 100644 index 0000000000..6250188498 --- /dev/null +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -0,0 +1,665 @@ +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from 'react'; +import { + ScrollView, + StatusBar, + StyleSheet, + Text, + TouchableOpacity, + View, +} from 'react-native'; +import { useSafeAreaInsets } from 'react-native-safe-area-context'; +import { + Camera, + Frame, + getCameraFormat, + Templates, + useCameraDevices, + useCameraPermission, + useFrameOutput, +} from 'react-native-vision-camera'; +import { createSynchronizable, scheduleOnRN } from 'react-native-worklets'; +import { + DEEPLAB_V3_RESNET50, + Detection, + EFFICIENTNET_V2_S, + SSDLITE_320_MOBILENET_V3_LARGE, + useClassification, + useImageSegmentation, + useObjectDetection, +} from 'react-native-executorch'; +import { + AlphaType, + Canvas, + ColorType, + Image as SkiaImage, + Skia, + SkImage, +} from '@shopify/react-native-skia'; +import { GeneratingContext } from '../../context'; +import Spinner from '../../components/Spinner'; +import ColorPalette from '../../colors'; + +type TaskId = 'classification' | 'objectDetection' | 'segmentation'; +type ModelId = 'classification' | 'objectDetection' | 'segmentation'; + +type TaskVariant = { id: ModelId; label: string }; +type Task = { id: TaskId; label: string; variants: TaskVariant[] }; + +const TASKS: Task[] = [ + { + id: 'classification', + label: 'Classify', + variants: [{ id: 'classification', label: 'EfficientNet V2 S' }], + }, + { + id: 'segmentation', + label: 'Segment', + variants: [{ id: 'segmentation', label: 'DeepLab V3' }], + }, + { + id: 'objectDetection', + label: 'Detect', + variants: [{ id: 
'objectDetection', label: 'SSDLite MobileNet' }], + }, +]; + +const CLASS_COLORS: number[][] = [ + [0, 0, 0, 0], + [51, 255, 87, 180], + [51, 87, 255, 180], + [255, 51, 246, 180], + [51, 255, 246, 180], + [243, 255, 51, 180], + [141, 51, 255, 180], + [255, 131, 51, 180], + [51, 255, 131, 180], + [131, 51, 255, 180], + [255, 255, 51, 180], + [51, 255, 255, 180], + [255, 51, 143, 180], + [127, 51, 255, 180], + [51, 255, 175, 180], + [255, 175, 51, 180], + [179, 255, 51, 180], + [255, 87, 51, 180], + [255, 51, 162, 180], + [51, 162, 255, 180], + [162, 51, 255, 180], +]; + +function hashLabel(label: string): number { + let hash = 5381; + for (let i = 0; i < label.length; i++) { + hash = (hash + hash * 32 + label.charCodeAt(i)) % 1000003; + } + return 1 + (Math.abs(hash) % (CLASS_COLORS.length - 1)); +} + +function labelColor(label: string): string { + const color = CLASS_COLORS[hashLabel(label)]!; + return `rgba(${color[0]},${color[1]},${color[2]},1)`; +} + +function labelColorBg(label: string): string { + const color = CLASS_COLORS[hashLabel(label)]!; + return `rgba(${color[0]},${color[1]},${color[2]},0.75)`; +} + +const frameKillSwitch = createSynchronizable(false); + +export default function VisionCameraScreen() { + const insets = useSafeAreaInsets(); + const [activeTask, setActiveTask] = useState('classification'); + const [activeModel, setActiveModel] = useState('classification'); + const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); + const { setGlobalGenerating } = useContext(GeneratingContext); + + const classification = useClassification({ + model: EFFICIENTNET_V2_S, + preventLoad: activeModel !== 'classification', + }); + const objectDetection = useObjectDetection({ + model: SSDLITE_320_MOBILENET_V3_LARGE, + preventLoad: activeModel !== 'objectDetection', + }); + const segmentation = useImageSegmentation({ + model: DEEPLAB_V3_RESNET50, + preventLoad: activeModel !== 'segmentation', + }); + + const activeIsGenerating = { + classification: 
classification.isGenerating, + objectDetection: objectDetection.isGenerating, + segmentation: segmentation.isGenerating, + }[activeModel]; + + useEffect(() => { + setGlobalGenerating(activeIsGenerating); + }, [activeIsGenerating, setGlobalGenerating]); + + const [fps, setFps] = useState(0); + const [frameMs, setFrameMs] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back') ?? devices[0]; + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + const [classResult, setClassResult] = useState({ label: '', score: 0 }); + const [detections, setDetections] = useState([]); + const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); + const [maskImage, setMaskImage] = useState(null); + + const updateClass = useCallback((r: { label: string; score: number }) => { + setClassResult(r); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) { + setFps(Math.round(1000 / diff)); + setFrameMs(diff); + } + lastFrameTimeRef.current = now; + }, []); + + const updateFps = useCallback(() => { + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) { + setFps(Math.round(1000 / diff)); + setFrameMs(diff); + } + lastFrameTimeRef.current = now; + }, []); + + const updateDetections = useCallback( + (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { + setDetections(p.results); + setImageSize({ width: p.imageWidth, height: p.imageHeight }); + updateFps(); + }, + [updateFps] + ); + + const updateMask = useCallback( + (img: SkImage) => { + setMaskImage((prev) => { + prev?.dispose(); + return img; + }); + updateFps(); + }, + [updateFps] + ); + + const classRof = classification.runOnFrame; + const detRof 
= objectDetection.runOnFrame; + const segRof = segmentation.runOnFrame; + + useEffect(() => { + frameKillSwitch.setBlocking(true); + setMaskImage((prev) => { + prev?.dispose(); + return null; + }); + const id = setTimeout(() => { + frameKillSwitch.setBlocking(false); + }, 300); + return () => clearTimeout(id); + }, [activeModel]); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + + try { + if (activeModel === 'classification') { + if (!classRof) return; + const result = classRof(frame); + if (result) { + let bestLabel = ''; + let bestScore = -1; + const entries = Object.entries(result); + for (let i = 0; i < entries.length; i++) { + const [label, score] = entries[i]!; + if ((score as number) > bestScore) { + bestScore = score as number; + bestLabel = label; + } + } + scheduleOnRN(updateClass, { label: bestLabel, score: bestScore }); + } + } else if (activeModel === 'objectDetection') { + if (!detRof) return; + const iw = frame.width > frame.height ? frame.height : frame.width; + const ih = frame.width > frame.height ? frame.width : frame.height; + const result = detRof(frame, 0.5); + if (result) { + scheduleOnRN(updateDetections, { + results: result, + imageWidth: iw, + imageHeight: ih, + }); + } + } else if (activeModel === 'segmentation') { + if (!segRof) return; + const result = segRof(frame, [], false); + if (result?.ARGMAX) { + const argmax: Int32Array = result.ARGMAX; + const side = Math.round(Math.sqrt(argmax.length)); + const pixels = new Uint8Array(side * side * 4); + for (let i = 0; i < argmax.length; i++) { + const color = CLASS_COLORS[argmax[i]!] ?? 
[0, 0, 0, 0]; + pixels[i * 4] = color[0]!; + pixels[i * 4 + 1] = color[1]!; + pixels[i * 4 + 2] = color[2]!; + pixels[i * 4 + 3] = color[3]!; + } + const skData = Skia.Data.fromBytes(pixels); + const img = Skia.Image.MakeImage( + { + width: side, + height: side, + alphaType: AlphaType.Unpremul, + colorType: ColorType.RGBA_8888, + }, + skData, + side * 4 + ); + if (img) scheduleOnRN(updateMask, img); + } + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [ + activeModel, + classRof, + detRof, + segRof, + updateClass, + updateDetections, + updateMask, + ] + ), + }); + + const activeIsReady = { + classification: classification.isReady, + objectDetection: objectDetection.isReady, + segmentation: segmentation.isReady, + }[activeModel]; + + const activeDownloadProgress = { + classification: classification.downloadProgress, + objectDetection: objectDetection.downloadProgress, + segmentation: segmentation.downloadProgress, + }[activeModel]; + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + function coverFit(imgW: number, imgH: number) { + const scale = Math.max(canvasSize.width / imgW, canvasSize.height / imgH); + return { + scale, + offsetX: (canvasSize.width - imgW * scale) / 2, + offsetY: (canvasSize.height - imgH * scale) / 2, + }; + } + + const { + scale: detScale, + offsetX: detOX, + offsetY: detOY, + } = coverFit(imageSize.width, imageSize.height); + + const activeTaskInfo = TASKS.find((t) => t.id === activeTask)!; + const activeVariantLabel = + activeTaskInfo.variants.find((v) => v.id === activeModel)?.label ?? 
+ activeTaskInfo.variants[0]!.label; + + return ( + + + + + + + setCanvasSize({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + > + {activeModel === 'segmentation' && maskImage && ( + + + + )} + + {activeModel === 'objectDetection' && ( + <> + {detections.map((det, i) => { + const left = det.bbox.x1 * detScale + detOX; + const top = det.bbox.y1 * detScale + detOY; + const w = (det.bbox.x2 - det.bbox.x1) * detScale; + const h = (det.bbox.y2 - det.bbox.y1) * detScale; + return ( + + + + {det.label} {(det.score * 100).toFixed(1)} + + + + ); + })} + + )} + + + {activeModel === 'classification' && classResult.label ? ( + + {classResult.label} + + {(classResult.score * 100).toFixed(1)}% + + + ) : null} + + {!activeIsReady && ( + + + + )} + + + + {activeVariantLabel} + + {fps} FPS – {frameMs.toFixed(0)} ms + + + + + {TASKS.map((t) => ( + { + setActiveTask(t.id); + setActiveModel(t.variants[0]!.id); + }} + > + + {t.label} + + + ))} + + + + {activeTaskInfo.variants.map((v) => ( + setActiveModel(v.id)} + > + + {v.label} + + + ))} + + + + ); +} + +const styles = StyleSheet.create({ + container: { flex: 1, backgroundColor: 'black' }, + centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { color: 'white', fontSize: 18 }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { color: 'white', fontSize: 15, fontWeight: '600' }, + loadingOverlay: { + ...StyleSheet.absoluteFillObject, + backgroundColor: 'rgba(0,0,0,0.6)', + justifyContent: 'center', + alignItems: 'center', + }, + + topOverlay: { + position: 'absolute', + top: 0, + left: 0, + right: 0, + alignItems: 'center', + gap: 8, + }, + titleRow: { + alignItems: 'center', + paddingHorizontal: 16, + }, + modelTitle: { + color: 'white', + fontSize: 22, + fontWeight: '700', + textShadowColor: 'rgba(0,0,0,0.7)', + textShadowOffset: 
{ width: 0, height: 1 }, + textShadowRadius: 4, + }, + fpsText: { + color: 'rgba(255,255,255,0.85)', + fontSize: 14, + fontWeight: '500', + marginTop: 2, + textShadowColor: 'rgba(0,0,0,0.7)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 4, + }, + + tabsContent: { + paddingHorizontal: 12, + gap: 6, + }, + tab: { + paddingHorizontal: 18, + paddingVertical: 7, + borderRadius: 20, + backgroundColor: 'rgba(0,0,0,0.45)', + borderWidth: 1, + borderColor: 'rgba(255,255,255,0.25)', + }, + tabActive: { + backgroundColor: 'rgba(255,255,255,0.2)', + borderColor: 'white', + }, + tabText: { + color: 'rgba(255,255,255,0.7)', + fontSize: 14, + fontWeight: '600', + }, + tabTextActive: { color: 'white' }, + + chipsContent: { + paddingHorizontal: 12, + gap: 6, + }, + variantChip: { + paddingHorizontal: 14, + paddingVertical: 5, + borderRadius: 16, + backgroundColor: 'rgba(0,0,0,0.35)', + borderWidth: 1, + borderColor: 'rgba(255,255,255,0.15)', + }, + variantChipActive: { + backgroundColor: ColorPalette.primary, + borderColor: ColorPalette.primary, + }, + variantChipText: { + color: 'rgba(255,255,255,0.6)', + fontSize: 12, + fontWeight: '500', + }, + variantChipTextActive: { color: 'white' }, + + bbox: { + position: 'absolute', + borderWidth: 2, + borderColor: 'cyan', + borderRadius: 4, + }, + bboxLabel: { + position: 'absolute', + top: -22, + left: -2, + paddingHorizontal: 6, + paddingVertical: 2, + borderRadius: 4, + }, + bboxLabelText: { color: 'white', fontSize: 11, fontWeight: '600' }, + + classResultOverlay: { + ...StyleSheet.absoluteFillObject, + justifyContent: 'center', + alignItems: 'center', + }, + classResultLabel: { + color: 'white', + fontSize: 28, + fontWeight: '700', + textAlign: 'center', + textShadowColor: 'rgba(0,0,0,0.8)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 6, + paddingHorizontal: 24, + }, + classResultScore: { + color: 'rgba(255,255,255,0.75)', + fontSize: 18, + fontWeight: '500', + marginTop: 4, + textShadowColor: 
'rgba(0,0,0,0.8)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 6, + }, +}); diff --git a/apps/computer-vision/app/vision_camera_live/index.tsx b/apps/computer-vision/app/vision_camera_live/index.tsx deleted file mode 100644 index 8c5d71d331..0000000000 --- a/apps/computer-vision/app/vision_camera_live/index.tsx +++ /dev/null @@ -1,796 +0,0 @@ -import React, { - useCallback, - useContext, - useEffect, - useMemo, - useRef, - useState, -} from 'react'; -import { - ScrollView, - StatusBar, - StyleSheet, - Text, - TouchableOpacity, - View, -} from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; -import { - Camera, - Frame, - getCameraFormat, - Templates, - useCameraDevices, - useCameraPermission, - useFrameOutput, -} from 'react-native-vision-camera'; -import { createSynchronizable, runOnJS } from 'react-native-worklets'; -import { - DEEPLAB_V3_RESNET50, - Detection, - EFFICIENTNET_V2_S, - OCRDetection, - OCR_ENGLISH, - SSDLITE_320_MOBILENET_V3_LARGE, - STYLE_TRANSFER_RAIN_PRINCESS, - useClassification, - useImageSegmentation, - useObjectDetection, - useOCR, - useStyleTransfer, -} from 'react-native-executorch'; -import { - AlphaType, - Canvas, - ColorType, - Image as SkiaImage, - matchFont, - Path, - Skia, - SkImage, - Text as SkiaText, -} from '@shopify/react-native-skia'; -import { GeneratingContext } from '../../context'; -import Spinner from '../../components/Spinner'; -import ColorPalette from '../../colors'; - -// ─── Model IDs ─────────────────────────────────────────────────────────────── - -type ModelId = - | 'classification' - | 'object_detection' - | 'segmentation' - | 'style_transfer' - | 'ocr'; - -const MODELS: { id: ModelId; label: string }[] = [ - { id: 'classification', label: 'Classification' }, - { id: 'object_detection', label: 'Object Detection' }, - { id: 'segmentation', label: 'Segmentation' }, - { id: 'style_transfer', label: 'Style Transfer' }, - { id: 'ocr', label: 'OCR' }, -]; - -const 
CLASS_COLORS: number[][] = [ - [0, 0, 0, 0], - [51, 255, 87, 180], - [51, 87, 255, 180], - [255, 51, 246, 180], - [51, 255, 246, 180], - [243, 255, 51, 180], - [141, 51, 255, 180], - [255, 131, 51, 180], - [51, 255, 131, 180], - [131, 51, 255, 180], - [255, 255, 51, 180], - [51, 255, 255, 180], - [255, 51, 143, 180], - [127, 51, 255, 180], - [51, 255, 175, 180], - [255, 175, 51, 180], - [179, 255, 51, 180], - [255, 87, 51, 180], - [255, 51, 162, 180], - [51, 162, 255, 180], - [162, 51, 255, 180], -]; - -// ─── Kill switch — synchronizable boolean shared between JS and worklet thread. -// setBlocking(true) immediately stops the worklet from dispatching new work -// (both in onFrame and inside the async callback) before the old model tears down. -const frameKillSwitch = createSynchronizable(false); - -// ─── Screen ────────────────────────────────────────────────────────────────── - -export default function VisionCameraLiveScreen() { - const insets = useSafeAreaInsets(); - const [activeModel, setActiveModel] = useState('classification'); - const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); - const { setGlobalGenerating } = useContext(GeneratingContext); - - // ── Models (only the active model loads; others are prevented) ── - const classification = useClassification({ - model: EFFICIENTNET_V2_S, - preventLoad: activeModel !== 'classification', - }); - const objectDetection = useObjectDetection({ - model: SSDLITE_320_MOBILENET_V3_LARGE, - preventLoad: activeModel !== 'object_detection', - }); - const segmentation = useImageSegmentation({ - model: DEEPLAB_V3_RESNET50, - preventLoad: activeModel !== 'segmentation', - }); - const styleTransfer = useStyleTransfer({ - model: STYLE_TRANSFER_RAIN_PRINCESS, - preventLoad: activeModel !== 'style_transfer', - }); - const ocr = useOCR({ - model: OCR_ENGLISH, - preventLoad: activeModel !== 'ocr', - }); - - const activeIsGenerating = { - classification: classification.isGenerating, - object_detection: 
objectDetection.isGenerating, - segmentation: segmentation.isGenerating, - style_transfer: styleTransfer.isGenerating, - ocr: ocr.isGenerating, - }[activeModel]; - - useEffect(() => { - setGlobalGenerating(activeIsGenerating); - }, [activeIsGenerating, setGlobalGenerating]); - - // ── Camera ── - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? devices[0]; - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - // ── Per-model result state ── - const [classResult, setClassResult] = useState({ label: '', score: 0 }); - const [detections, setDetections] = useState([]); - const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); - const [maskImage, setMaskImage] = useState(null); - const [styledImage, setStyledImage] = useState(null); - const [ocrData, setOcrData] = useState<{ - detections: OCRDetection[]; - frameWidth: number; - frameHeight: number; - }>({ detections: [], frameWidth: 1, frameHeight: 1 }); - - // ── Stable callbacks ── - function tick() { - const now = Date.now(); - const diff = now - lastFrameTimeRef.current; - if (diff > 0) setFps(Math.round(1000 / diff)); - lastFrameTimeRef.current = now; - } - - const updateClass = useCallback((r: { label: string; score: number }) => { - setClassResult(r); - tick(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - - const updateDetections = useCallback( - (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { - setDetections(p.results); - setImageSize({ width: p.imageWidth, height: p.imageHeight }); - tick(); - }, - // eslint-disable-next-line react-hooks/exhaustive-deps - [] - ); - - const updateMask = useCallback((img: SkImage) => { - 
setMaskImage((prev) => { - prev?.dispose(); - return img; - }); - tick(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - - const updateStyled = useCallback((img: SkImage) => { - setStyledImage((prev) => { - prev?.dispose(); - return img; - }); - tick(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - - const updateOcr = useCallback( - (d: { - detections: OCRDetection[]; - frameWidth: number; - frameHeight: number; - }) => { - setOcrData(d); - tick(); - }, - // eslint-disable-next-line react-hooks/exhaustive-deps - [] - ); - - // ── runOnJS-wrapped callbacks — created on the RN thread so the Babel plugin - // can serialize them into remote functions. These can then be safely called - // from any worklet runtime, including the asyncRunner's worker runtime. - const notifyClass = runOnJS(updateClass); - const notifyDetections = runOnJS(updateDetections); - const notifyMask = runOnJS(updateMask); - const notifyStyled = runOnJS(updateStyled); - const notifyOcr = runOnJS(updateOcr); - - // ── Pull the active model's runOnFrame out of the hook each render. - // These are worklet functions (not plain JS objects), so they CAN be - // captured directly in a useCallback closure — the worklets runtime - // serializes them correctly. A new closure is produced whenever the - // active runOnFrame changes, causing useFrameOutput to re-register. - const classRof = classification.runOnFrame; - const detRof = objectDetection.runOnFrame; - const segRof = segmentation.runOnFrame; - const stRof = styleTransfer.runOnFrame; - const ocrRof = ocr.runOnFrame; - - // When switching models: activate kill switch synchronously so the worklet - // thread stops calling runOnFrame before delete() fires on the old model. - // Then re-enable once the new model's preventLoad has taken effect. 
- useEffect(() => { - frameKillSwitch.setBlocking(true); - setMaskImage((prev) => { - prev?.dispose(); - return null; - }); - setStyledImage((prev) => { - prev?.dispose(); - return null; - }); - const id = setTimeout(() => { - frameKillSwitch.setBlocking(false); - }, 300); - return () => clearTimeout(id); - }, [activeModel]); - - // ── Single frame output. - // onFrame is re-created (and re-registered by useFrameOutput) whenever the - // active model or its runOnFrame worklet changes. The kill switch provides - // synchronous cross-thread protection during the transition window. - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - dropFramesWhileBusy: true, - onFrame: useCallback( - (frame: Frame) => { - 'worklet'; - - // Kill switch is set synchronously from JS when switching models — - // guaranteed visible here before the next frame is dispatched. - if (frameKillSwitch.getDirty()) { - frame.dispose(); - return; - } - - try { - if (activeModel === 'classification') { - if (!classRof) return; - const result = classRof(frame); - if (result) { - let bestLabel = ''; - let bestScore = -1; - const entries = Object.entries(result); - for (let i = 0; i < entries.length; i++) { - const [label, score] = entries[i]!; - if ((score as number) > bestScore) { - bestScore = score as number; - bestLabel = label; - } - } - notifyClass({ - label: bestLabel, - score: bestScore, - }); - } - } else if (activeModel === 'object_detection') { - if (!detRof) return; - const iw = frame.width > frame.height ? frame.height : frame.width; - const ih = frame.width > frame.height ? 
frame.width : frame.height; - const result = detRof(frame, 0.5); - if (result) { - notifyDetections({ - results: result, - imageWidth: iw, - imageHeight: ih, - }); - } - } else if (activeModel === 'segmentation') { - if (!segRof) return; - const result = segRof(frame, [], false); - if (result?.ARGMAX) { - const argmax: Int32Array = result.ARGMAX; - const side = Math.round(Math.sqrt(argmax.length)); - const pixels = new Uint8Array(side * side * 4); - for (let i = 0; i < argmax.length; i++) { - const color = CLASS_COLORS[argmax[i]!] ?? [0, 0, 0, 0]; - pixels[i * 4] = color[0]!; - pixels[i * 4 + 1] = color[1]!; - pixels[i * 4 + 2] = color[2]!; - pixels[i * 4 + 3] = color[3]!; - } - const skData = Skia.Data.fromBytes(pixels); - const img = Skia.Image.MakeImage( - { - width: side, - height: side, - alphaType: AlphaType.Unpremul, - colorType: ColorType.RGBA_8888, - }, - skData, - side * 4 - ); - if (img) notifyMask(img); - } - } else if (activeModel === 'style_transfer') { - if (!stRof) return; - const result = stRof(frame); - if (result?.dataPtr) { - const { dataPtr, sizes } = result; - const h = sizes[0]!; - const w = sizes[1]!; - const skData = Skia.Data.fromBytes(dataPtr); - const img = Skia.Image.MakeImage( - { - width: w, - height: h, - alphaType: AlphaType.Opaque, - colorType: ColorType.RGBA_8888, - }, - skData, - w * 4 - ); - if (img) notifyStyled(img); - } - } else if (activeModel === 'ocr') { - if (!ocrRof) return; - const fw = frame.width; - const fh = frame.height; - const result = ocrRof(frame); - if (result) { - notifyOcr({ - detections: result, - frameWidth: fw, - frameHeight: fh, - }); - } - } - } catch { - // ignore - } finally { - frame.dispose(); - } - }, - [ - activeModel, - classRof, - detRof, - segRof, - stRof, - ocrRof, - notifyClass, - notifyDetections, - notifyMask, - notifyStyled, - notifyOcr, - ] - ), - }); - - // ── Loading state: only care about the active model ── - const activeIsReady = { - classification: classification.isReady, - 
object_detection: objectDetection.isReady, - segmentation: segmentation.isReady, - style_transfer: styleTransfer.isReady, - ocr: ocr.isReady, - }[activeModel]; - - const activeDownloadProgress = { - classification: classification.downloadProgress, - object_detection: objectDetection.downloadProgress, - segmentation: segmentation.downloadProgress, - style_transfer: styleTransfer.downloadProgress, - ocr: ocr.downloadProgress, - }[activeModel]; - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - // ── Cover-fit helpers ── - function coverFit(imgW: number, imgH: number) { - const scale = Math.max(canvasSize.width / imgW, canvasSize.height / imgH); - return { - scale, - offsetX: (canvasSize.width - imgW * scale) / 2, - offsetY: (canvasSize.height - imgH * scale) / 2, - }; - } - - // ── OCR coord transform ── - const { - detections: ocrDets, - frameWidth: ocrFW, - frameHeight: ocrFH, - } = ocrData; - const ocrIsLandscape = ocrFW > ocrFH; - const ocrImgW = ocrIsLandscape ? ocrFH : ocrFW; - const ocrImgH = ocrIsLandscape ? ocrFW : ocrFH; - const { - scale: ocrScale, - offsetX: ocrOX, - offsetY: ocrOY, - } = coverFit(ocrImgW, ocrImgH); - function ocrToX(px: number, py: number) { - return (ocrIsLandscape ? ocrFH - py : px) * ocrScale + ocrOX; - } - function ocrToY(px: number, py: number) { - return (ocrIsLandscape ? 
px : py) * ocrScale + ocrOY; - } - - // ── Object detection cover-fit ── - const { - scale: detScale, - offsetX: detOX, - offsetY: detOY, - } = coverFit(imageSize.width, imageSize.height); - - const font = matchFont({ fontFamily: 'Helvetica', fontSize: 11 }); - - return ( - - - - - - {/* ── Overlays ── */} - - setCanvasSize({ - width: e.nativeEvent.layout.width, - height: e.nativeEvent.layout.height, - }) - } - > - {activeModel === 'segmentation' && maskImage && ( - - - - )} - - {activeModel === 'style_transfer' && styledImage && ( - - - - )} - - {activeModel === 'object_detection' && ( - <> - {detections.map((det, i) => { - const left = det.bbox.x1 * detScale + detOX; - const top = det.bbox.y1 * detScale + detOY; - const w = (det.bbox.x2 - det.bbox.x1) * detScale; - const h = (det.bbox.y2 - det.bbox.y1) * detScale; - return ( - - - - {det.label} {(det.score * 100).toFixed(0)}% - - - - ); - })} - - )} - - {activeModel === 'ocr' && ( - - {ocrDets.map((det, i) => { - if (!det.bbox || det.bbox.length < 2) return null; - const path = Skia.Path.Make(); - path.moveTo( - ocrToX(det.bbox[0]!.x, det.bbox[0]!.y), - ocrToY(det.bbox[0]!.x, det.bbox[0]!.y) - ); - for (let j = 1; j < det.bbox.length; j++) { - path.lineTo( - ocrToX(det.bbox[j]!.x, det.bbox[j]!.y), - ocrToY(det.bbox[j]!.x, det.bbox[j]!.y) - ); - } - path.close(); - const lx = ocrToX(det.bbox[0]!.x, det.bbox[0]!.y); - const ly = Math.max( - 0, - ocrToY(det.bbox[0]!.x, det.bbox[0]!.y) - 4 - ); - return ( - - - - {font && ( - - )} - - ); - })} - - )} - - - {!activeIsReady && ( - - m.id === activeModel)?.label} ${(activeDownloadProgress * 100).toFixed(0)}%`} - /> - - )} - - - - {MODELS.map((m) => ( - setActiveModel(m.id)} - > - - {m.label} - - - ))} - - - - - - {activeModel === 'classification' && ( - - - {classResult.label || '—'} - - {classResult.label ? 
( - - {(classResult.score * 100).toFixed(1)}% - - ) : null} - - )} - {activeModel === 'object_detection' && ( - - {detections.length} - objects - - )} - {activeModel === 'segmentation' && ( - - DeepLab V3 - segmentation - - )} - {activeModel === 'style_transfer' && ( - - Rain Princess - style - - )} - {activeModel === 'ocr' && ( - - {ocrDets.length} - regions - - )} - - - {fps} - fps - - - - - ); -} - -// ─── Styles ────────────────────────────────────────────────────────────────── - -const styles = StyleSheet.create({ - container: { flex: 1, backgroundColor: 'black' }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { color: 'white', fontSize: 18 }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - buttonText: { color: 'white', fontSize: 15, fontWeight: '600' }, - loadingOverlay: { - ...StyleSheet.absoluteFillObject, - backgroundColor: 'rgba(0,0,0,0.6)', - justifyContent: 'center', - alignItems: 'center', - }, - topBarWrapper: { - position: 'absolute', - top: 0, - left: 0, - right: 0, - }, - pickerContent: { - paddingHorizontal: 12, - gap: 8, - }, - chip: { - paddingHorizontal: 16, - paddingVertical: 8, - borderRadius: 20, - backgroundColor: 'rgba(0,0,0,0.55)', - borderWidth: 1, - borderColor: 'rgba(255,255,255,0.2)', - }, - chipActive: { - backgroundColor: ColorPalette.primary, - borderColor: ColorPalette.primary, - }, - chipText: { - color: 'rgba(255,255,255,0.8)', - fontSize: 13, - fontWeight: '600', - }, - chipTextActive: { color: 'white' }, - bbox: { - position: 'absolute', - borderWidth: 2, - borderColor: ColorPalette.primary, - borderRadius: 4, - }, - bboxLabel: { - position: 'absolute', - top: -22, - left: -2, - backgroundColor: ColorPalette.primary, - paddingHorizontal: 6, - paddingVertical: 2, - borderRadius: 4, - }, - bboxLabelText: { color: 'white', fontSize: 11, fontWeight: '600' }, - 
bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0,0,0,0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - }, - resultContainer: { alignItems: 'flex-start', maxWidth: 220 }, - resultText: { - color: 'white', - fontSize: 16, - fontWeight: '700', - }, - resultSub: { - color: 'rgba(255,255,255,0.6)', - fontSize: 12, - fontWeight: '500', - }, - statDivider: { - width: 1, - height: 32, - backgroundColor: 'rgba(255,255,255,0.2)', - }, - statItem: { alignItems: 'center' }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - textTransform: 'uppercase', - letterSpacing: 0.8, - }, -}); From 4643ce0d2f9ee67c67713756971d7cb2267f94e7 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 26 Feb 2026 11:53:27 +0100 Subject: [PATCH 28/71] fix: drawing style transfer image --- .../app/style_transfer/index.tsx | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index 90801cb053..ce84a6c583 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -27,6 +27,7 @@ export default function StyleTransferScreen() { const [imageUri, setImageUri] = useState(''); const [styledImage, setStyledImage] = useState(null); + const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); @@ -43,16 +44,8 @@ export default function StyleTransferScreen() { const output = await model.forward(imageUri); const height = output.sizes[0]; const width = output.sizes[1]; - // Convert RGB 
-> RGBA for Skia - const rgba = new Uint8Array(width * height * 4); - const rgb = output.dataPtr; - for (let i = 0; i < width * height; i++) { - rgba[i * 4] = rgb[i * 3]; - rgba[i * 4 + 1] = rgb[i * 3 + 1]; - rgba[i * 4 + 2] = rgb[i * 3 + 2]; - rgba[i * 4 + 3] = 255; - } - const skData = Skia.Data.fromBytes(rgba); + // Native already returns RGBA uint8 — use directly + const skData = Skia.Data.fromBytes(output.dataPtr); const img = Skia.Image.MakeImage( { width, @@ -83,16 +76,26 @@ export default function StyleTransferScreen() { {styledImage ? ( - - - + + setCanvasSize({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + > + + + + ) : ( Date: Thu, 26 Feb 2026 13:07:48 +0100 Subject: [PATCH 29/71] fix: tests --- .../app/style_transfer/index.tsx | 1 - .../common/rnexecutorch/tests/CMakeLists.txt | 23 +++++++++++++++---- .../tests/integration/StyleTransferTest.cpp | 8 ++++--- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index ce84a6c583..db238d671b 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -44,7 +44,6 @@ export default function StyleTransferScreen() { const output = await model.forward(imageUri); const height = output.sizes[0]; const width = output.sizes[1]; - // Native already returns RGBA uint8 — use directly const skData = Skia.Data.fromBytes(output.dataPtr); const img = Skia.Image.MakeImage( { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 426aafc1f3..6705b687cb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -162,8 +162,11 @@ add_rn_test(BaseModelTests integration/BaseModelTest.cpp) 
add_rn_test(ClassificationTests integration/ClassificationTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/classification/Classification.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp @@ -181,8 +184,11 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp @@ -196,8 +202,11 @@ add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp add_rn_test(StyleTransferTests integration/StyleTransferTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/style_transfer/StyleTransfer.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(VADTests integration/VoiceActivityDetectionTest.cpp @@ -273,8 +282,10 @@ add_rn_test(OCRTests integration/OCRTest.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp @@ -287,6 +298,8 @@ add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp 
${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index 532b4c04b2..c92299cb15 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -60,7 +60,9 @@ TEST(StyleTransferGenerateTests, MalformedURIThrows) { TEST(StyleTransferGenerateTests, ValidImageReturnsNonNull) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); auto result = model.generateFromString(kValidTestImagePath); - EXPECT_NE(result, nullptr); + EXPECT_NE(result.dataPtr, nullptr); + EXPECT_GT(result.width, 0); + EXPECT_GT(result.height, 0); } TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { @@ -68,8 +70,8 @@ TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath)); auto result1 = model.generateFromString(kValidTestImagePath); auto result2 = model.generateFromString(kValidTestImagePath); - EXPECT_NE(result1, nullptr); - EXPECT_NE(result2, nullptr); + EXPECT_NE(result1.dataPtr, nullptr); + EXPECT_NE(result2.dataPtr, nullptr); } TEST(StyleTransferInheritedTests, GetInputShapeWorks) { From 6d04387f8c68c54b85602f3dbe7ba2c653e50dd6 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 26 Feb 2026 13:52:30 +0100 Subject: [PATCH 30/71] feat: add possibility to switch between front/back camera --- .../app/vision_camera/index.tsx | 63 ++++++++++++++++++- 
.../rnexecutorch/models/VisionModel.cpp | 11 +--- 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 6250188498..ccf8e41d64 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -42,6 +42,7 @@ import { Skia, SkImage, } from '@shopify/react-native-skia'; +import Svg, { Path, Polygon } from 'react-native-svg'; import { GeneratingContext } from '../../context'; import Spinner from '../../components/Spinner'; import ColorPalette from '../../colors'; @@ -119,6 +120,9 @@ export default function VisionCameraScreen() { const [activeTask, setActiveTask] = useState('classification'); const [activeModel, setActiveModel] = useState('classification'); const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); + const [cameraPosition, setCameraPosition] = useState<'back' | 'front'>( + 'back' + ); const { setGlobalGenerating } = useContext(GeneratingContext); const classification = useClassification({ @@ -149,7 +153,8 @@ export default function VisionCameraScreen() { const lastFrameTimeRef = useRef(Date.now()); const cameraPermission = useCameraPermission(); const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? devices[0]; + const device = + devices.find((d) => d.position === cameraPosition) ?? devices[0]; const format = useMemo(() => { if (device == null) return undefined; try { @@ -375,7 +380,10 @@ export default function VisionCameraScreen() { /> setCanvasSize({ @@ -422,6 +430,9 @@ export default function VisionCameraScreen() { style={[ styles.bboxLabel, { backgroundColor: labelColorBg(det.label) }, + cameraPosition === 'front' && { + transform: [{ scaleX: -1 }], + }, ]} > @@ -518,6 +529,37 @@ export default function VisionCameraScreen() { ))} + + + + setCameraPosition((p) => (p === 'back' ? 
'front' : 'back')) + } + > + + {/* Camera body */} + + {/* Rotate arrows — arc with arrowhead around the lens */} + + + + + ); } @@ -662,4 +704,21 @@ const styles = StyleSheet.create({ textShadowOffset: { width: 0, height: 1 }, textShadowRadius: 6, }, + bottomOverlay: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + }, + flipButton: { + width: 56, + height: 56, + borderRadius: 28, + backgroundColor: 'rgba(255,255,255,0.2)', + justifyContent: 'center', + alignItems: 'center', + borderWidth: 1.5, + borderColor: 'rgba(255,255,255,0.4)', + }, }); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 8f67175c41..c0ce049f28 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -11,16 +11,7 @@ using namespace facebook; cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const { auto frameObj = frameData.asObject(runtime); - cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); - - // Camera sensors natively deliver frames in landscape orientation. - // Rotate 90° CW so all models receive upright portrait frames. 
- if (frame.cols > frame.rows) { - cv::Mat upright; - cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); - return upright; - } - return frame; + return ::rnexecutorch::utils::extractFrame(runtime, frameObj); } cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { From ca22fa0f9b6a2ac2c10fdb955733f1423a1863a2 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 26 Feb 2026 15:49:29 +0100 Subject: [PATCH 31/71] fix: rotation issue --- apps/computer-vision/app/vision_camera/index.tsx | 1 + .../common/rnexecutorch/models/VisionModel.cpp | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index ccf8e41d64..e09bdcc174 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -377,6 +377,7 @@ export default function VisionCameraScreen() { outputs={[frameOutput]} isActive={true} format={format} + orientationSource="interface" /> frame.rows) { + cv::Mat upright; + cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); + return upright; + } + return frame; } cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { From c2285c88ec03915fbc656eeee7870138a62db3ae Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 11:25:24 +0100 Subject: [PATCH 32/71] fix: issues after rebase --- .../app/vision_camera/index.tsx | 150 ++++++++++++---- apps/computer-vision/package.json | 6 +- .../host_objects/JsiConversions.h | 7 +- .../host_objects/ModelHostObject.h | 34 ++-- .../metaprogramming/TypeConcepts.h | 20 --- .../object_detection/ObjectDetection.cpp | 33 ---- .../BaseSemanticSegmentation.cpp | 74 ++++++-- .../BaseSemanticSegmentation.h | 36 ++-- .../useSemanticSegmentation.ts | 33 +++- packages/react-native-executorch/src/index.ts | 1 + .../SemanticSegmentationModule.ts | 4 +- yarn.lock | 166 ++---------------- 12 files 
changed, 261 insertions(+), 303 deletions(-) diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index e09bdcc174..7d005c5c59 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -26,12 +26,19 @@ import { } from 'react-native-vision-camera'; import { createSynchronizable, scheduleOnRN } from 'react-native-worklets'; import { - DEEPLAB_V3_RESNET50, + DEEPLAB_V3_RESNET50_QUANTIZED, + DEEPLAB_V3_RESNET101_QUANTIZED, + DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + FCN_RESNET50_QUANTIZED, + FCN_RESNET101_QUANTIZED, + SELFIE_SEGMENTATION, Detection, EFFICIENTNET_V2_S, + RF_DETR_NANO, SSDLITE_320_MOBILENET_V3_LARGE, useClassification, - useImageSegmentation, + useSemanticSegmentation, useObjectDetection, } from 'react-native-executorch'; import { @@ -48,7 +55,17 @@ import Spinner from '../../components/Spinner'; import ColorPalette from '../../colors'; type TaskId = 'classification' | 'objectDetection' | 'segmentation'; -type ModelId = 'classification' | 'objectDetection' | 'segmentation'; +type ModelId = + | 'classification' + | 'objectDetection_ssdlite' + | 'objectDetection_rfdetr' + | 'segmentation_deeplab_resnet50' + | 'segmentation_deeplab_resnet101' + | 'segmentation_deeplab_mobilenet' + | 'segmentation_lraspp' + | 'segmentation_fcn_resnet50' + | 'segmentation_fcn_resnet101' + | 'segmentation_selfie'; type TaskVariant = { id: ModelId; label: string }; type Task = { id: TaskId; label: string; variants: TaskVariant[] }; @@ -62,12 +79,23 @@ const TASKS: Task[] = [ { id: 'segmentation', label: 'Segment', - variants: [{ id: 'segmentation', label: 'DeepLab V3' }], + variants: [ + { id: 'segmentation_deeplab_resnet50', label: 'DeepLab ResNet50' }, + { id: 'segmentation_deeplab_resnet101', label: 'DeepLab ResNet101' }, + { id: 'segmentation_deeplab_mobilenet', label: 'DeepLab MobileNet' }, + { id: 
'segmentation_lraspp', label: 'LRASPP MobileNet' }, + { id: 'segmentation_fcn_resnet50', label: 'FCN ResNet50' }, + { id: 'segmentation_fcn_resnet101', label: 'FCN ResNet101' }, + { id: 'segmentation_selfie', label: 'Selfie' }, + ], }, { id: 'objectDetection', label: 'Detect', - variants: [{ id: 'objectDetection', label: 'SSDLite MobileNet' }], + variants: [ + { id: 'objectDetection_ssdlite', label: 'SSDLite MobileNet' }, + { id: 'objectDetection_rfdetr', label: 'RF-DETR Nano' }, + ], }, ]; @@ -129,20 +157,76 @@ export default function VisionCameraScreen() { model: EFFICIENTNET_V2_S, preventLoad: activeModel !== 'classification', }); - const objectDetection = useObjectDetection({ + const objectDetectionSsdlite = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE, - preventLoad: activeModel !== 'objectDetection', + preventLoad: activeModel !== 'objectDetection_ssdlite', }); - const segmentation = useImageSegmentation({ - model: DEEPLAB_V3_RESNET50, - preventLoad: activeModel !== 'segmentation', + const objectDetectionRfdetr = useObjectDetection({ + model: RF_DETR_NANO, + preventLoad: activeModel !== 'objectDetection_rfdetr', }); - const activeIsGenerating = { - classification: classification.isGenerating, - objectDetection: objectDetection.isGenerating, - segmentation: segmentation.isGenerating, - }[activeModel]; + const activeObjectDetection = + { + objectDetection_ssdlite: objectDetectionSsdlite, + objectDetection_rfdetr: objectDetectionRfdetr, + }[activeModel as 'objectDetection_ssdlite' | 'objectDetection_rfdetr'] ?? 
+ null; + const segDeeplabResnet50 = useSemanticSegmentation({ + model: DEEPLAB_V3_RESNET50_QUANTIZED, + preventLoad: activeModel !== 'segmentation_deeplab_resnet50', + }); + const segDeeplabResnet101 = useSemanticSegmentation({ + model: DEEPLAB_V3_RESNET101_QUANTIZED, + preventLoad: activeModel !== 'segmentation_deeplab_resnet101', + }); + const segDeeplabMobilenet = useSemanticSegmentation({ + model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + preventLoad: activeModel !== 'segmentation_deeplab_mobilenet', + }); + const segLraspp = useSemanticSegmentation({ + model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + preventLoad: activeModel !== 'segmentation_lraspp', + }); + const segFcnResnet50 = useSemanticSegmentation({ + model: FCN_RESNET50_QUANTIZED, + preventLoad: activeModel !== 'segmentation_fcn_resnet50', + }); + const segFcnResnet101 = useSemanticSegmentation({ + model: FCN_RESNET101_QUANTIZED, + preventLoad: activeModel !== 'segmentation_fcn_resnet101', + }); + const segSelfie = useSemanticSegmentation({ + model: SELFIE_SEGMENTATION, + preventLoad: activeModel !== 'segmentation_selfie', + }); + + const activeSegmentation = + { + segmentation_deeplab_resnet50: segDeeplabResnet50, + segmentation_deeplab_resnet101: segDeeplabResnet101, + segmentation_deeplab_mobilenet: segDeeplabMobilenet, + segmentation_lraspp: segLraspp, + segmentation_fcn_resnet50: segFcnResnet50, + segmentation_fcn_resnet101: segFcnResnet101, + segmentation_selfie: segSelfie, + }[ + activeModel as + | 'segmentation_deeplab_resnet50' + | 'segmentation_deeplab_resnet101' + | 'segmentation_deeplab_mobilenet' + | 'segmentation_lraspp' + | 'segmentation_fcn_resnet50' + | 'segmentation_fcn_resnet101' + | 'segmentation_selfie' + ] ?? null; + + const activeIsGenerating = + activeModel === 'classification' + ? classification.isGenerating + : activeModel.startsWith('objectDetection') + ? (activeObjectDetection?.isGenerating ?? false) + : (activeSegmentation?.isGenerating ?? 
false); useEffect(() => { setGlobalGenerating(activeIsGenerating); @@ -211,8 +295,8 @@ export default function VisionCameraScreen() { ); const classRof = classification.runOnFrame; - const detRof = objectDetection.runOnFrame; - const segRof = segmentation.runOnFrame; + const detRof = activeObjectDetection?.runOnFrame ?? null; + const segRof = activeSegmentation?.runOnFrame ?? null; useEffect(() => { frameKillSwitch.setBlocking(true); @@ -255,7 +339,7 @@ export default function VisionCameraScreen() { } scheduleOnRN(updateClass, { label: bestLabel, score: bestScore }); } - } else if (activeModel === 'objectDetection') { + } else if (activeModel.startsWith('objectDetection')) { if (!detRof) return; const iw = frame.width > frame.height ? frame.height : frame.width; const ih = frame.width > frame.height ? frame.width : frame.height; @@ -267,7 +351,7 @@ export default function VisionCameraScreen() { imageHeight: ih, }); } - } else if (activeModel === 'segmentation') { + } else if (activeModel.startsWith('segmentation')) { if (!segRof) return; const result = segRof(frame, [], false); if (result?.ARGMAX) { @@ -313,17 +397,19 @@ export default function VisionCameraScreen() { ), }); - const activeIsReady = { - classification: classification.isReady, - objectDetection: objectDetection.isReady, - segmentation: segmentation.isReady, - }[activeModel]; + const activeIsReady = + activeModel === 'classification' + ? classification.isReady + : activeModel.startsWith('objectDetection') + ? (activeObjectDetection?.isReady ?? false) + : (activeSegmentation?.isReady ?? false); - const activeDownloadProgress = { - classification: classification.downloadProgress, - objectDetection: objectDetection.downloadProgress, - segmentation: segmentation.downloadProgress, - }[activeModel]; + const activeDownloadProgress = + activeModel === 'classification' + ? classification.downloadProgress + : activeModel.startsWith('objectDetection') + ? (activeObjectDetection?.downloadProgress ?? 
0) + : (activeSegmentation?.downloadProgress ?? 0); if (!cameraPermission.hasPermission) { return ( @@ -393,7 +479,7 @@ export default function VisionCameraScreen() { }) } > - {activeModel === 'segmentation' && maskImage && ( + {activeModel.startsWith('segmentation') && maskImage && ( )} - {activeModel === 'objectDetection' && ( + {activeModel.startsWith('objectDetection') && ( <> {detections.map((det, i) => { const left = det.bbox.x1 * detScale + detOX; @@ -480,7 +566,6 @@ export default function VisionCameraScreen() { horizontal showsHorizontalScrollIndicator={false} contentContainerStyle={styles.tabsContent} - pointerEvents="box-none" > {TASKS.map((t) => ( {activeTaskInfo.variants.map((v) => ( #include +#include #include #include #include @@ -581,9 +582,9 @@ getJsiValue(const models::style_transfer::PixelDataResult &result, return obj; } -inline jsi::Value -getJsiValue(const models::image_segmentation::SegmentationResult &result, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue( + const rnexecutorch::models::image_segmentation::SegmentationResult &result, + jsi::Runtime &runtime) { jsi::Object dict(runtime); auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, result.argmax); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 7a432a50f6..a2d915e699 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -86,6 +86,16 @@ template class ModelHostObject : public JsiHostObject { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, "streamStop")); + + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, + promiseHostFunction<&Model::generateFromPhonemes>, + "generateFromPhonemes")); + + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, + 
promiseHostFunction<&Model::streamFromPhonemes>, + "streamFromPhonemes")); } if constexpr (meta::SameAs) { @@ -196,23 +206,6 @@ template class ModelHostObject : public JsiHostObject { "generateFromString")); } - if constexpr (meta::HasGenerateFromFrame) { - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamStop>, - "streamStop")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamInsert>, - "streamInsert")); - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generateFromPhonemes>, - "generateFromPhonemes")); - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::streamFromPhonemes>, - "streamFromPhonemes")); - } - if constexpr (meta::HasGenerateFromString) { addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, @@ -232,13 +225,6 @@ template class ModelHostObject : public JsiHostObject { promiseHostFunction<&Model::generateFromPixels>, "generateFromPixels")); } - - if constexpr (meta::HasGenerateFromPixels) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - visionHostFunction<&Model::generateFromPixels>, - "generateFromPixels")); - } } // A generic host function that runs synchronously, works analogously to the diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index 5cf0c79e14..2d7612f250 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -26,26 +26,6 @@ concept HasGenerateFromPixels = requires(T t) { { &T::generateFromPixels }; }; -template -concept HasGenerateFromString = requires(T t) { - { &T::generateFromString }; -}; - -template -concept HasGenerateFromPixels = requires(T t) { - { &T::generateFromPixels }; -}; - -template -concept 
HasGenerateFromString = requires(T t) { - { &T::generateFromString }; -}; - -template -concept HasGenerateFromPixels = requires(T t) { - { &T::generateFromPixels }; -}; - template concept HasGenerateFromFrame = requires(T t) { { &T::generateFromFrame }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index e54f3e9a4a..bf209682ac 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -78,39 +78,6 @@ cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { return rgb; } -cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { - const std::vector tensorDims = getAllInputShapes()[0]; - cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], - tensorDims[tensorDims.size() - 2]); - - cv::Mat rgb; - - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - // Only resize if dimensions don't match - if (rgb.size() != tensorSize) { - cv::Mat resized; - cv::resize(rgb, resized, tensorSize); - return resized; - } - - return rgb; -} - std::vector ObjectDetection::postprocess(const std::vector &tensors, cv::Size originalSize, double detectionThreshold) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp 
b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index 3016f8edf1..cf883728d9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -15,7 +15,8 @@ BaseSemanticSegmentation::BaseSemanticSegmentation( const std::string &modelSource, std::vector normMean, std::vector normStd, std::vector allClasses, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker), allClasses_(std::move(allClasses)) { + : VisionModel(modelSource, callInvoker), + allClasses_(std::move(allClasses)) { initModelImageSize(); if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); @@ -49,6 +50,30 @@ void BaseSemanticSegmentation::initModelImageSize() { numModelPixels = modelImageSize.area(); } +cv::Mat BaseSemanticSegmentation::preprocessFrame(const cv::Mat &frame) const { + cv::Mat rgb; + if (frame.channels() == 4) { +#ifdef __APPLE__ + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + } else if (frame.channels() == 3) { + rgb = frame; + } else { + char msg[64]; + std::snprintf(msg, sizeof(msg), "Unsupported frame format: %d channels", + frame.channels()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, msg); + } + if (rgb.size() != modelImageSize) { + cv::Mat resized; + cv::resize(rgb, resized, modelImageSize); + return resized; + } + return rgb; +} + TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, cv::Size &originalSize) { auto [inputTensor, origSize] = image_processing::readImageToTensor( @@ -62,11 +87,8 @@ std::shared_ptr BaseSemanticSegmentation::generate( std::set> classesOfInterest, bool resize) { std::scoped_lock lock(inference_mutex_); - cv::Mat preprocessed = preprocessFrame(image); - 
- const std::vector tensorDims = getAllInputShapes()[0]; - auto inputTensor = - image_processing::getTensorFromMatrix(tensorDims, preprocessed); + cv::Size originalSize; + auto inputTensor = preprocess(imageSource, originalSize); auto forwardResult = BaseModel::forward(inputTensor); @@ -76,11 +98,39 @@ std::shared_ptr BaseSemanticSegmentation::generate( "Ensure the model input is correct."); } - return postprocess(forwardResult->at(0).toTensor(), originalSize, allClasses_, - classesOfInterest, resize); + auto result = computeResult(forwardResult->at(0).toTensor(), originalSize, + allClasses_, classesOfInterest, resize); + return populateDictionary(result.argmax, result.classBuffers); +} + +image_segmentation::SegmentationResult +BaseSemanticSegmentation::generateFromFrame( + jsi::Runtime &runtime, const jsi::Value &frameData, + std::set> classesOfInterest, bool resize) { + std::scoped_lock lock(inference_mutex_); + + cv::Mat frame = extractFromFrame(runtime, frameData); + cv::Mat preprocessed = preprocessFrame(frame); + cv::Size originalSize = frame.size(); + + const std::vector tensorDims = getAllInputShapes()[0]; + auto inputTensor = + (normMean_ && normStd_) + ? 
image_processing::getTensorFromMatrix(tensorDims, preprocessed, + *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(tensorDims, preprocessed); + + auto forwardResult = BaseModel::forward(inputTensor); + if (!forwardResult.ok()) { + throw RnExecutorchError(forwardResult.error(), + "The model's forward function did not succeed."); + } + + return computeResult(forwardResult->at(0).toTensor(), originalSize, + allClasses_, classesOfInterest, resize); } -std::shared_ptr BaseSemanticSegmentation::postprocess( +image_segmentation::SegmentationResult BaseSemanticSegmentation::computeResult( const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, bool resize) { @@ -189,13 +239,13 @@ std::shared_ptr BaseSemanticSegmentation::postprocess( } } - return populateDictionary(argmax, buffersToReturn); + return image_segmentation::SegmentationResult{argmax, buffersToReturn}; } std::shared_ptr BaseSemanticSegmentation::populateDictionary( std::shared_ptr argmax, - std::shared_ptr>> + std::shared_ptr< + std::unordered_map>> classesToOutput) { auto promisePtr = std::make_shared>(); std::future doneFuture = promisePtr->get_future(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index bd2a6b9e84..97fe3815c8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -9,7 +9,7 @@ #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" #include #include -#include +#include namespace rnexecutorch { namespace models::semantic_segmentation { @@ -18,7 +18,7 @@ using namespace facebook; using executorch::aten::Tensor; using executorch::extension::TensorPtr; -class 
BaseSemanticSegmentation : public BaseModel { +class BaseSemanticSegmentation : public VisionModel { public: BaseSemanticSegmentation(const std::string &modelSource, std::vector normMean, @@ -26,18 +26,29 @@ class BaseSemanticSegmentation : public BaseModel { std::vector allClasses, std::shared_ptr callInvoker); + // Async path: called from promiseHostFunction on a thread-pool thread. + // Returns a jsi::Object via callInvoker (safe to block there). [[nodiscard("Registered non-void function")]] std::shared_ptr generate(std::string imageSource, std::set> classesOfInterest, bool resize); + // Sync path: called from visionHostFunction on the camera worklet thread. + // Must NOT use callInvoker — returns a plain SegmentationResult that + // visionHostFunction converts to JSI via getJsiValue. + [[nodiscard("Registered non-void function")]] + image_segmentation::SegmentationResult + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, + std::set> classesOfInterest, + bool resize); + protected: cv::Mat preprocessFrame(const cv::Mat &frame) const override; - virtual SegmentationResult - postprocess(const Tensor &tensor, cv::Size originalSize, - std::vector &allClasses, - std::set> &classesOfInterest, - bool resize); + virtual image_segmentation::SegmentationResult + computeResult(const Tensor &tensor, cv::Size originalSize, + std::vector &allClasses, + std::set> &classesOfInterest, + bool resize); cv::Size modelImageSize; std::size_t numModelPixels; @@ -48,12 +59,13 @@ class BaseSemanticSegmentation : public BaseModel { private: void initModelImageSize(); - SegmentationResult runInference( - cv::Mat image, cv::Size originalSize, std::vector allClasses, - std::set> classesOfInterest, bool resize); + TensorPtr preprocess(const std::string &imageSource, cv::Size &originalSize); - TensorPtr preprocessFromString(const std::string &imageSource, - cv::Size &originalSize); + std::shared_ptr populateDictionary( + std::shared_ptr argmax, + std::shared_ptr< + 
std::unordered_map>> + classesToOutput); }; } // namespace models::semantic_segmentation diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts index dd43aaf8b3..19a5640318 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts @@ -34,14 +34,20 @@ export const useSemanticSegmentation = < }: SemanticSegmentationProps): SemanticSegmentationType< SegmentationLabels> > => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - SemanticSegmentationModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + instance, + } = useModuleFactory({ + factory: (config, onProgress) => + SemanticSegmentationModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); const forward = >>( imageSource: string, @@ -52,5 +58,14 @@ export const useSemanticSegmentation = < inst.forward(imageSource, classesOfInterest, resizeToInput) ); - return { error, isReady, isGenerating, downloadProgress, forward }; + const runOnFrame = instance?.runOnFrame ?? 
null; + + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 5bb4d3d134..1947fd7269 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -129,6 +129,7 @@ export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; export * from './hooks/computer_vision/useStyleTransfer'; export * from './hooks/computer_vision/useSemanticSegmentation'; +export * from './hooks/computer_vision/useSemanticSegmentation'; export * from './hooks/computer_vision/useOCR'; export * from './hooks/computer_vision/useVerticalOCR'; export * from './hooks/computer_vision/useImageEmbeddings'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index ffc7203d8b..14f2cb2439 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -138,7 +138,6 @@ export class SemanticSegmentationModule< } const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; - const allClassNames = this.allClassNames; return ( frame: any, @@ -155,7 +154,6 @@ export class SemanticSegmentationModule< }; return nativeGenerateFromFrame( frameData, - allClassNames, classesOfInterest, resizeToInput ); @@ -296,7 +294,7 @@ export class SemanticSegmentationModule< ); const nativeResult = await this.nativeModule.generate( - imageSource, + input, classesOfInterestNames, resizeToInput ); diff --git a/yarn.lock b/yarn.lock index b76e118881..0ed00cdfd9 100644 --- a/yarn.lock +++ b/yarn.lock @@ -110,19 +110,6 @@ __metadata: languageName: node linkType: hard 
-"@babel/generator@npm:^7.29.0": - version: 7.29.1 - resolution: "@babel/generator@npm:7.29.1" - dependencies: - "@babel/parser": "npm:^7.29.0" - "@babel/types": "npm:^7.29.0" - "@jridgewell/gen-mapping": "npm:^0.3.12" - "@jridgewell/trace-mapping": "npm:^0.3.28" - jsesc: "npm:^3.0.2" - checksum: 10/61fe4ddd6e817aa312a14963ccdbb5c9a8c57e8b97b98d19a8a99ccab2215fda1a5f52bc8dd8d2e3c064497ddeb3ab8ceb55c76fa0f58f8169c34679d2256fe0 - languageName: node - linkType: hard - "@babel/helper-annotate-as-pure@npm:^7.27.1, @babel/helper-annotate-as-pure@npm:^7.27.3": version: 7.27.3 resolution: "@babel/helper-annotate-as-pure@npm:7.27.3" @@ -255,13 +242,6 @@ __metadata: languageName: node linkType: hard -"@babel/helper-plugin-utils@npm:^7.28.6": - version: 7.28.6 - resolution: "@babel/helper-plugin-utils@npm:7.28.6" - checksum: 10/21c853bbc13dbdddf03309c9a0477270124ad48989e1ad6524b83e83a77524b333f92edd2caae645c5a7ecf264ec6d04a9ebe15aeb54c7f33c037b71ec521e4a - languageName: node - linkType: hard - "@babel/helper-remap-async-to-generator@npm:^7.18.9, @babel/helper-remap-async-to-generator@npm:^7.27.1": version: 7.27.1 resolution: "@babel/helper-remap-async-to-generator@npm:7.27.1" @@ -288,19 +268,6 @@ __metadata: languageName: node linkType: hard -"@babel/helper-replace-supers@npm:^7.28.6": - version: 7.28.6 - resolution: "@babel/helper-replace-supers@npm:7.28.6" - dependencies: - "@babel/helper-member-expression-to-functions": "npm:^7.28.5" - "@babel/helper-optimise-call-expression": "npm:^7.27.1" - "@babel/traverse": "npm:^7.28.6" - peerDependencies: - "@babel/core": ^7.0.0 - checksum: 10/ad2724713a4d983208f509e9607e8f950855f11bd97518a700057eb8bec69d687a8f90dc2da0c3c47281d2e3b79cf1d14ecf1fe3e1ee0a8e90b61aee6759c9a7 - languageName: node - linkType: hard - "@babel/helper-skip-transparent-expression-wrappers@npm:^7.20.0, @babel/helper-skip-transparent-expression-wrappers@npm:^7.27.1": version: 7.27.1 resolution: "@babel/helper-skip-transparent-expression-wrappers@npm:7.27.1" @@ 
-376,17 +343,6 @@ __metadata: languageName: node linkType: hard -"@babel/parser@npm:^7.28.6, @babel/parser@npm:^7.29.0": - version: 7.29.0 - resolution: "@babel/parser@npm:7.29.0" - dependencies: - "@babel/types": "npm:^7.29.0" - bin: - parser: ./bin/babel-parser.js - checksum: 10/b1576dca41074997a33ee740d87b330ae2e647f4b7da9e8d2abd3772b18385d303b0cee962b9b88425e0f30d58358dbb8d63792c1a2d005c823d335f6a029747 - languageName: node - linkType: hard - "@babel/plugin-bugfix-firefox-class-in-computed-class-key@npm:^7.28.5": version: 7.28.5 resolution: "@babel/plugin-bugfix-firefox-class-in-computed-class-key@npm:7.28.5" @@ -811,17 +767,6 @@ __metadata: languageName: node linkType: hard -"@babel/plugin-syntax-typescript@npm:^7.28.6": - version: 7.28.6 - resolution: "@babel/plugin-syntax-typescript@npm:7.28.6" - dependencies: - "@babel/helper-plugin-utils": "npm:^7.28.6" - peerDependencies: - "@babel/core": ^7.0.0-0 - checksum: 10/5c55f9c63bd36cf3d7e8db892294c8f85000f9c1526c3a1cc310d47d1e174f5c6f6605e5cc902c4636d885faba7a9f3d5e5edc6b35e4f3b1fd4c2d58d0304fa5 - languageName: node - linkType: hard - "@babel/plugin-syntax-unicode-sets-regex@npm:^7.18.6": version: 7.18.6 resolution: "@babel/plugin-syntax-unicode-sets-regex@npm:7.18.6" @@ -1564,21 +1509,6 @@ __metadata: languageName: node linkType: hard -"@babel/plugin-transform-typescript@npm:^7.27.1": - version: 7.28.6 - resolution: "@babel/plugin-transform-typescript@npm:7.28.6" - dependencies: - "@babel/helper-annotate-as-pure": "npm:^7.27.3" - "@babel/helper-create-class-features-plugin": "npm:^7.28.6" - "@babel/helper-plugin-utils": "npm:^7.28.6" - "@babel/helper-skip-transparent-expression-wrappers": "npm:^7.27.1" - "@babel/plugin-syntax-typescript": "npm:^7.28.6" - peerDependencies: - "@babel/core": ^7.0.0-0 - checksum: 10/a0bccc531fa8710a45b0b593140273741e0e4a0721b1ef6ef9dfefae0bbe61528440d65aab7936929551fd76793272257d74f60cf66891352f793294930a4b67 - languageName: node - linkType: hard - 
"@babel/plugin-transform-unicode-escapes@npm:^7.27.1": version: 7.27.1 resolution: "@babel/plugin-transform-unicode-escapes@npm:7.27.1" @@ -1808,16 +1738,6 @@ __metadata: languageName: node linkType: hard -"@babel/types@npm:^7.28.6, @babel/types@npm:^7.29.0": - version: 7.29.0 - resolution: "@babel/types@npm:7.29.0" - dependencies: - "@babel/helper-string-parser": "npm:^7.27.1" - "@babel/helper-validator-identifier": "npm:^7.28.5" - checksum: 10/bfc2b211210f3894dcd7e6a33b2d1c32c93495dc1e36b547376aa33441abe551ab4bc1640d4154ee2acd8e46d3bbc925c7224caae02fcaf0e6a771e97fccc661 - languageName: node - linkType: hard - "@bcoe/v8-coverage@npm:^0.2.3": version: 0.2.3 resolution: "@bcoe/v8-coverage@npm:0.2.3" @@ -5990,18 +5910,6 @@ __metadata: languageName: node linkType: hard -"ajv@npm:^8.11.0": - version: 8.18.0 - resolution: "ajv@npm:8.18.0" - dependencies: - fast-deep-equal: "npm:^3.1.3" - fast-uri: "npm:^3.0.1" - json-schema-traverse: "npm:^1.0.0" - require-from-string: "npm:^2.0.2" - checksum: 10/bfed9de827a2b27c6d4084324eda76a4e32bdde27410b3e9b81d06e6f8f5c78370fc6b93fe1d869f1939ff1d7c4ae8896960995acb8425e3e9288c8884247c48 - languageName: node - linkType: hard - "anser@npm:^1.4.9": version: 1.4.10 resolution: "anser@npm:1.4.10" @@ -7362,14 +7270,14 @@ __metadata: react-native-gesture-handler: "npm:~2.28.0" react-native-image-picker: "npm:^7.2.2" react-native-loading-spinner-overlay: "npm:^3.0.1" - react-native-nitro-image: "npm:^0.12.0" - react-native-nitro-modules: "npm:^0.33.9" + react-native-nitro-image: "npm:0.13.0" + react-native-nitro-modules: "npm:0.35.0" react-native-reanimated: "npm:~4.2.2" react-native-safe-area-context: "npm:~5.6.0" react-native-screens: "npm:~4.16.0" react-native-svg: "npm:15.15.3" react-native-svg-transformer: "npm:^1.5.3" - react-native-vision-camera: "npm:5.0.0-beta.2" + react-native-vision-camera: "npm:5.0.0-beta.6" react-native-worklets: "npm:0.7.4" languageName: unknown linkType: soft @@ -14567,45 +14475,24 @@ __metadata: languageName: 
node linkType: hard -"react-native-nitro-image@npm:^0.12.0": - version: 0.12.0 - resolution: "react-native-nitro-image@npm:0.12.0" - peerDependencies: - react: "*" - react-native: "*" - react-native-nitro-modules: "*" - checksum: 10/03f165381c35e060d4d05eae3ce029b32a4009482f327e9526840f306181ca87a862b335e12667c55d4ee9f2069542ca93dd112feb7f1822bf7d2ddc38fe58f0 - languageName: node - linkType: hard - -"react-native-nitro-modules@npm:^0.33.9": - version: 0.33.9 - resolution: "react-native-nitro-modules@npm:0.33.9" - peerDependencies: - react: "*" - react-native: "*" - checksum: 10/4ebf4db46d1e4987a0e52054724081aa9712bcd1d505a6dbdd47aebc6afe72a7abaa0e947651d9f3cc594e4eb3dba47fc6f59db27c5a5ed383946e40d96543a0 - languageName: node - linkType: hard - -"react-native-nitro-image@npm:^0.12.0": - version: 0.12.0 - resolution: "react-native-nitro-image@npm:0.12.0" +"react-native-nitro-image@npm:0.13.0": + version: 0.13.0 + resolution: "react-native-nitro-image@npm:0.13.0" peerDependencies: react: "*" react-native: "*" react-native-nitro-modules: "*" - checksum: 10/03f165381c35e060d4d05eae3ce029b32a4009482f327e9526840f306181ca87a862b335e12667c55d4ee9f2069542ca93dd112feb7f1822bf7d2ddc38fe58f0 + checksum: 10/77f04e0c262fed839aa16276a31cce6d6969788d2aada55594fd083959dec9e00bd75f7bd8333cc59f3768bf329592736da60688e827a8d25d18de8bbda9b2d7 languageName: node linkType: hard -"react-native-nitro-modules@npm:^0.33.9": - version: 0.33.9 - resolution: "react-native-nitro-modules@npm:0.33.9" +"react-native-nitro-modules@npm:0.35.0": + version: 0.35.0 + resolution: "react-native-nitro-modules@npm:0.35.0" peerDependencies: react: "*" react-native: "*" - checksum: 10/4ebf4db46d1e4987a0e52054724081aa9712bcd1d505a6dbdd47aebc6afe72a7abaa0e947651d9f3cc594e4eb3dba47fc6f59db27c5a5ed383946e40d96543a0 + checksum: 10/6c9166a115a03bfc26d3cb9a75761a1fdf33a06bdfb853779539cfe3d7dc2239e242a7fbd4cbb7e9dc0af90a373606fc607fa8a09ef4ff49f7ff29ccff736bbf languageName: node linkType: hard @@ -14701,16 +14588,16 @@ 
__metadata: languageName: node linkType: hard -"react-native-vision-camera@npm:5.0.0-beta.2": - version: 5.0.0-beta.2 - resolution: "react-native-vision-camera@npm:5.0.0-beta.2" +"react-native-vision-camera@npm:5.0.0-beta.6": + version: 5.0.0-beta.6 + resolution: "react-native-vision-camera@npm:5.0.0-beta.6" peerDependencies: react: "*" react-native: "*" react-native-nitro-image: "*" react-native-nitro-modules: "*" react-native-worklets: "*" - checksum: 10/1f38d097d001c10b8544d0b931a9387a91c5df1e0677ae53e639962a90589586af02ca658ca5e99a5ca179af8d86bc8365227cf70750f2df4bfb775f4a26fc6d + checksum: 10/5c1f3104869e51b173d2bba88a69397c08949314a47a14d70e9ea01b6531c8c74ea6e15f8de4bab6c55eb82c673f8231b6be223d8b4b44a8df433f1cd35c0376 languageName: node linkType: hard @@ -14737,29 +14624,6 @@ __metadata: languageName: node linkType: hard -"react-native-worklets@npm:^0.7.2": - version: 0.7.4 - resolution: "react-native-worklets@npm:0.7.4" - dependencies: - "@babel/plugin-transform-arrow-functions": "npm:7.27.1" - "@babel/plugin-transform-class-properties": "npm:7.27.1" - "@babel/plugin-transform-classes": "npm:7.28.4" - "@babel/plugin-transform-nullish-coalescing-operator": "npm:7.27.1" - "@babel/plugin-transform-optional-chaining": "npm:7.27.1" - "@babel/plugin-transform-shorthand-properties": "npm:7.27.1" - "@babel/plugin-transform-template-literals": "npm:7.27.1" - "@babel/plugin-transform-unicode-regex": "npm:7.27.1" - "@babel/preset-typescript": "npm:7.27.1" - convert-source-map: "npm:2.0.0" - semver: "npm:7.7.3" - peerDependencies: - "@babel/core": "*" - react: "*" - react-native: "*" - checksum: 10/922b209940e298d21313d22f8a6eb87ad603442850c7ff8bc9cfef694cb211d7ec9903e24ee20b6bcf6164f8e7c165b65307dcca3d67465fdffda1c45fe05d1d - languageName: node - linkType: hard - "react-native@npm:0.81.5": version: 0.81.5 resolution: "react-native@npm:0.81.5" From 2a6879294c54f6ff25efada6a5ef2265b8d47081 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 13:42:07 
+0100 Subject: [PATCH 33/71] refactor: apply code review fixes for vision camera integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - visionHostFunction: preserve RnExecutorchError code in catch block - OCR/VerticalOCR generateFromFrame: add 90° CW rotation for landscape frames - VisionModel: lift preprocessFrame and modelImageSize from 5 subclasses into base class Co-Authored-By: Claude Sonnet 4.6 --- .../host_objects/ModelHostObject.h | 16 +++++++++ .../rnexecutorch/models/VisionModel.cpp | 28 +++++++++++++++ .../common/rnexecutorch/models/VisionModel.h | 14 ++++---- .../models/classification/Classification.cpp | 30 +--------------- .../models/classification/Classification.h | 5 --- .../embeddings/image/ImageEmbeddings.cpp | 30 +--------------- .../models/embeddings/image/ImageEmbeddings.h | 5 --- .../object_detection/ObjectDetection.cpp | 35 +------------------ .../models/object_detection/ObjectDetection.h | 4 --- .../common/rnexecutorch/models/ocr/OCR.cpp | 6 ++++ .../BaseSemanticSegmentation.cpp | 24 ------------- .../BaseSemanticSegmentation.h | 4 --- .../models/style_transfer/StyleTransfer.cpp | 30 +--------------- .../models/style_transfer/StyleTransfer.h | 5 --- .../models/vertical_ocr/VerticalOCR.cpp | 6 ++++ 15 files changed, 68 insertions(+), 174 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index a2d915e699..b78cbc7a81 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -331,8 +331,24 @@ template class ModelHostObject : public JsiHostObject { return jsi_conversion::getJsiValue(std::move(result), runtime); } + } catch (const RnExecutorchError &e) { + jsi::Object errorData(runtime); + errorData.setProperty(runtime, 
"code", e.getNumericCode()); + errorData.setProperty(runtime, "message", + jsi::String::createFromUtf8(runtime, e.what())); + throw jsi::JSError(runtime, jsi::Value(runtime, std::move(errorData))); + } catch (const std::runtime_error &e) { + // This catch should be merged with the next one + // (std::runtime_error inherits from std::exception) HOWEVER react + // native has broken RTTI which breaks proper exception type + // checking. Remove when the following change is present in our + // version: + // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e + throw jsi::JSError(runtime, e.what()); } catch (const std::exception &e) { throw jsi::JSError(runtime, e.what()); + } catch (...) { + throw jsi::JSError(runtime, "Unknown error in vision function"); } } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 8f67175c41..78b9b042db 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -23,6 +23,34 @@ cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, return frame; } +cv::Mat VisionModel::preprocessFrame(const cv::Mat &frame) const { + cv::Mat rgb; + + if (frame.channels() == 4) { +#ifdef __APPLE__ + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + } else if (frame.channels() == 3) { + rgb = frame; + } else { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unsupported frame format: %d channels", frame.channels()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (rgb.size() != modelImageSize) { + cv::Mat resized; + cv::resize(rgb, resized, modelImageSize); + return resized; + } + + return rgb; +} + cv::Mat VisionModel::extractFromPixels(const 
JSTensorViewIn &tensorView) const { if (tensorView.sizes.size() != 3) { char errorMessage[100]; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 38cf26dead..7b442df3e7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -86,11 +86,9 @@ class VisionModel : public BaseModel { /** * @brief Preprocess a camera frame for model input * - * This method should implement model-specific preprocessing such as: - * - Resizing to the model's expected input size - * - Color space conversion (e.g., BGR to RGB) - * - Normalization - * - Any other model-specific transformations + * Converts 4-channel frames (BGRA on iOS, RGBA on Android) to RGB and + * resizes to modelImageSize if needed. Subclasses may override for + * model-specific preprocessing (e.g., normalisation). * * @param frame Input frame from camera (already extracted and rotated by * FrameExtractor) @@ -99,7 +97,11 @@ class VisionModel : public BaseModel { * @note The input frame is already in RGB format and rotated 90° clockwise * @note This method is called under mutex protection in generateFromFrame() */ - virtual cv::Mat preprocessFrame(const cv::Mat &frame) const = 0; + virtual cv::Mat preprocessFrame(const cv::Mat &frame) const; + + /// Expected input image dimensions derived from the model's input shape. + /// Set by subclass constructors after loading the model. 
+ cv::Size modelImageSize{0, 0}; /** * @brief Extract and preprocess frame from VisionCamera in one call diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 2a00d5dce8..6f27b2f841 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -22,7 +22,7 @@ Classification::Classification(const std::string &modelSource, if (modelInputShape.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", modelInputShape.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, @@ -32,34 +32,6 @@ Classification::Classification(const std::string &modelSource, modelInputShape[modelInputShape.size() - 2]); } -cv::Mat Classification::preprocessFrame(const cv::Mat &frame) const { - cv::Mat rgb; - - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - if (rgb.size() != modelImageSize) { - cv::Mat resized; - cv::resize(rgb, resized, modelImageSize); - return resized; - } - - return rgb; -} - std::unordered_map Classification::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h 
b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h index 473d9b4bb3..9f62864b9e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h @@ -31,15 +31,10 @@ class Classification : public VisionModel { std::string_view, float> generateFromPixels(JSTensorViewIn pixelData); -protected: - cv::Mat preprocessFrame(const cv::Mat &frame) const override; - private: std::unordered_map runInference(cv::Mat image); std::unordered_map postprocess(const Tensor &tensor); - - cv::Size modelImageSize{0, 0}; }; } // namespace models::classification diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index a82fffbb22..c54456c743 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -22,7 +22,7 @@ ImageEmbeddings::ImageEmbeddings( if (modelInputShape.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", modelInputShape.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, @@ -32,34 +32,6 @@ ImageEmbeddings::ImageEmbeddings( modelInputShape[modelInputShape.size() - 2]); } -cv::Mat ImageEmbeddings::preprocessFrame(const cv::Mat &frame) const { - cv::Mat rgb; - - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char 
errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - if (rgb.size() != modelImageSize) { - cv::Mat resized; - cv::resize(rgb, resized, modelImageSize); - return resized; - } - - return rgb; -} - std::shared_ptr ImageEmbeddings::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index ec11ee5c69..3a20301724 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -31,13 +31,8 @@ class ImageEmbeddings final : public VisionModel { "Registered non-void function")]] std::shared_ptr generateFromPixels(JSTensorViewIn pixelData); -protected: - cv::Mat preprocessFrame(const cv::Mat &frame) const override; - private: std::shared_ptr runInference(cv::Mat image); - - cv::Size modelImageSize{0, 0}; }; } // namespace models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index bf209682ac..1c0fec27c2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -23,7 +23,7 @@ ObjectDetection::ObjectDetection( if (modelInputShape.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, 
expected at least 2 dimensions " "but got: %zu.", modelInputShape.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ -45,39 +45,6 @@ ObjectDetection::ObjectDetection( } } -cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { - const std::vector tensorDims = getAllInputShapes()[0]; - cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], - tensorDims[tensorDims.size() - 2]); - - cv::Mat rgb; - - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - // Only resize if dimensions don't match - if (rgb.size() != tensorSize) { - cv::Mat resized; - cv::resize(rgb, resized, tensorSize); - return resized; - } - - return rgb; -} - std::vector ObjectDetection::postprocess(const std::vector &tensors, cv::Size originalSize, double detectionThreshold) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index f1159d88e0..d94087688a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -77,7 +77,6 @@ class ObjectDetection : public VisionModel { protected: std::vector runInference(cv::Mat image, double detectionThreshold); - cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: /** @@ -100,9 +99,6 @@ class ObjectDetection : public VisionModel { cv::Size originalSize, double detectionThreshold); - /// Expected 
input image dimensions derived from the model's input shape. - cv::Size modelImageSize{0, 0}; - /// Optional per-channel mean for input normalisation (set in constructor). std::optional normMean_; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp index 50834a1b82..102d8ee479 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -53,6 +53,12 @@ std::vector OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { auto frameObj = frameData.asObject(runtime); cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); + // Camera sensors deliver landscape frames; rotate to portrait orientation. + if (frame.cols > frame.rows) { + cv::Mat upright; + cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); + frame = std::move(upright); + } // extractFrame returns RGB; convert to BGR for consistency with readImage cv::cvtColor(frame, frame, cv::COLOR_RGB2BGR); return runInference(frame); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index cf883728d9..429f53ba65 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -50,30 +50,6 @@ void BaseSemanticSegmentation::initModelImageSize() { numModelPixels = modelImageSize.area(); } -cv::Mat BaseSemanticSegmentation::preprocessFrame(const cv::Mat &frame) const { - cv::Mat rgb; - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, 
cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char msg[64]; - std::snprintf(msg, sizeof(msg), "Unsupported frame format: %d channels", - frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, msg); - } - if (rgb.size() != modelImageSize) { - cv::Mat resized; - cv::resize(rgb, resized, modelImageSize); - return resized; - } - return rgb; -} - TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, cv::Size &originalSize) { auto [inputTensor, origSize] = image_processing::readImageToTensor( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index 97fe3815c8..00ce9282ab 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -42,15 +42,11 @@ class BaseSemanticSegmentation : public VisionModel { bool resize); protected: - cv::Mat preprocessFrame(const cv::Mat &frame) const override; - virtual image_segmentation::SegmentationResult computeResult(const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, bool resize); - - cv::Size modelImageSize; std::size_t numModelPixels; std::optional normMean_; std::optional normStd_; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index c334f5d842..eeedcdce2f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -24,7 +24,7 @@ 
StyleTransfer::StyleTransfer(const std::string &modelSource, if (modelInputShape.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", modelInputShape.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ -34,34 +34,6 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, modelInputShape[modelInputShape.size() - 2]); } -cv::Mat StyleTransfer::preprocessFrame(const cv::Mat &frame) const { - cv::Mat rgb; - - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - if (rgb.size() != modelImageSize) { - cv::Mat resized; - cv::resize(rgb, resized, modelImageSize); - return resized; - } - - return rgb; -} - PixelDataResult StyleTransfer::postprocess(const Tensor &tensor, cv::Size outputSize) { // Convert tensor output (at modelImageSize) to CV_8UC3 BGR mat diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 99f9f4b3ac..d018e66e0a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -33,9 +33,6 @@ class StyleTransfer : public VisionModel { [[nodiscard("Registered non-void function")]] PixelDataResult generateFromPixels(JSTensorViewIn pixelData); -protected: - cv::Mat 
preprocessFrame(const cv::Mat &frame) const override; - private: // outputSize: size to resize the styled output to before returning. // Pass modelImageSize for real-time frame processing (avoids large allocs). @@ -43,8 +40,6 @@ class StyleTransfer : public VisionModel { PixelDataResult runInference(cv::Mat image, cv::Size outputSize); PixelDataResult postprocess(const Tensor &tensor, cv::Size outputSize); - - cv::Size modelImageSize{0, 0}; }; } // namespace models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp index 71ea737f8e..9534dfeab6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -58,6 +58,12 @@ VerticalOCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { auto frameObj = frameData.asObject(runtime); cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); + // Camera sensors deliver landscape frames; rotate to portrait orientation. 
+ if (frame.cols > frame.rows) { + cv::Mat upright; + cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); + frame = std::move(upright); + } // extractFrame returns RGB; convert to BGR for consistency with readImage cv::cvtColor(frame, frame, cv::COLOR_RGB2BGR); return runInference(frame); From 159f24d46472e52cd928ab5bf2baffeecc52436d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 13:46:18 +0100 Subject: [PATCH 34/71] feat: some improvements --- .../app/object_detection/index.tsx | 43 ------ .../data_processing/ImageProcessing.cpp | 2 +- .../rnexecutorch/utils/FrameExtractor.cpp | 5 +- .../rnexecutorch/utils/FrameExtractor.h | 2 +- .../rnexecutorch/utils/FrameProcessor.cpp | 2 +- .../rnexecutorch/utils/FrameProcessor.h | 2 +- .../src/controllers/BaseOCRController.ts | 17 +-- .../computer_vision/useImageSegmentation.ts | 131 ------------------ .../useSemanticSegmentation.ts | 3 +- .../computer_vision/ObjectDetectionModule.ts | 7 - .../SemanticSegmentationModule.ts | 21 +-- .../modules/computer_vision/VisionModule.ts | 2 +- 12 files changed, 17 insertions(+), 220 deletions(-) delete mode 100644 packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index a5e36c344a..2f8fa6d58e 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -50,49 +50,6 @@ export default function ObjectDetectionScreen() { } }; - const runForwardPixels = async () => { - try { - console.log('Testing with hardcoded pixel data...'); - - // Create a simple 320x320 test image (all zeros - black image) - // In a real scenario, you would load actual image pixel data here - const width = 320; - const height = 320; - const channels = 3; // RGB - - // Create a black image (you can replace this with actual pixel data) - const rgbData = new Uint8Array(width * height 
* channels); - - // Optionally, add some test pattern (e.g., white square in center) - for (let y = 100; y < 220; y++) { - for (let x = 100; x < 220; x++) { - const idx = (y * width + x) * 3; - rgbData[idx + 0] = 255; // R - rgbData[idx + 1] = 255; // G - rgbData[idx + 2] = 255; // B - } - } - - const pixelData: PixelData = { - dataPtr: rgbData, - sizes: [height, width, channels], - scalarType: ScalarType.BYTE, - }; - - console.log('Running forward with hardcoded pixel data...', { - sizes: pixelData.sizes, - dataSize: pixelData.dataPtr.byteLength, - }); - - // Run inference using unified forward() API - const output = await ssdLite.forward(pixelData, 0.3); - console.log('Pixel data result:', output.length, 'detections'); - setResults(output); - } catch (e) { - console.error('Error in runForwardPixels:', e); - } - }; - if (!rfDetr.isReady) { return ( ({ - model, - preventLoad = false, -}: ImageSegmentationProps): ImageSegmentationType< - SegmentationLabels> -> => { - const [error, setError] = useState(null); - const [isReady, setIsReady] = useState(false); - const [isGenerating, setIsGenerating] = useState(false); - const [downloadProgress, setDownloadProgress] = useState(0); - const [instance, setInstance] = useState - > | null>(null); - const [runOnFrame, setRunOnFrame] = useState< - | (( - frame: Frame, - classesOfInterest?: string[], - resizeToInput?: boolean - ) => any) - | null - >(null); - - useEffect(() => { - if (preventLoad) return; - - let isMounted = true; - let currentInstance: ImageSegmentationModule> | null = null; - - (async () => { - setDownloadProgress(0); - setError(null); - setIsReady(false); - try { - currentInstance = await ImageSegmentationModule.fromModelName( - model, - (progress) => { - if (isMounted) setDownloadProgress(progress); - } - ); - if (isMounted) { - setInstance(currentInstance); - setIsReady(true); - const worklet = currentInstance.runOnFrame; - if (worklet) { - setRunOnFrame(() => worklet); - } - } - } catch (err) { - if 
(isMounted) setError(parseUnknownError(err)); - } - })(); - - return () => { - isMounted = false; - setIsReady(false); - setRunOnFrame(null); - currentInstance?.delete(); - }; - - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [model.modelName, model.modelSource, preventLoad]); - - const forward = async >>( - imageSource: string | PixelData, - classesOfInterest: K[] = [], - resizeToInput: boolean = true - ) => { - if (!isReady || !instance) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - } - if (isGenerating) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModelGenerating, - 'The model is currently generating. Please wait until previous model run is complete.' - ); - } - try { - setIsGenerating(true); - return await instance.forward( - imageSource, - classesOfInterest, - resizeToInput - ); - } finally { - setIsGenerating(false); - } - }; - - return { - error, - isReady, - isGenerating, - downloadProgress, - forward, - runOnFrame, - }; -}; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts index 19a5640318..622aa1a541 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts @@ -1,3 +1,4 @@ +import { PixelData } from '../..'; import { SemanticSegmentationModule, SegmentationLabels, @@ -50,7 +51,7 @@ export const useSemanticSegmentation = < }); const forward = >>( - imageSource: string, + imageSource: string | PixelData, classesOfInterest: K[] = [], resizeToInput: boolean = true ) => diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts 
b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index bbb990f7b8..c24bbd1369 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -169,11 +169,4 @@ export class ObjectDetectionModule< nativeModule ); } - - async forward( - input: string | PixelData, - detectionThreshold: number = 0.5 - ): Promise { - return super.forward(input, detectionThreshold); - } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index 14f2cb2439..841306ec5c 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -1,4 +1,9 @@ -import { ResourceSource, LabelEnum } from '../../types/common'; +import { + ResourceSource, + LabelEnum, + Frame, + PixelData, +} from '../../types/common'; import { DeeplabLabel, ModelNameOf, @@ -62,20 +67,6 @@ export type SegmentationLabels = type ResolveLabels = ResolveLabelsFor; -function isPixelData(input: unknown): input is PixelData { - return ( - typeof input === 'object' && - input !== null && - 'dataPtr' in input && - (input as any).dataPtr instanceof Uint8Array && - 'sizes' in input && - Array.isArray((input as any).sizes) && - (input as any).sizes.length === 3 && - 'scalarType' in input && - (input as any).scalarType === ScalarType.BYTE - ); -} - /** * Generic semantic segmentation module with type-safe label maps. * Use a model name (e.g. 
`'deeplab-v3-resnet50'`) as the generic parameter for built-in models, diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts index 762d09987e..a6e7983768 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts @@ -15,7 +15,7 @@ import { Frame, PixelData, ScalarType } from '../../types/common'; * * @category Typescript API */ -function isPixelData(input: unknown): input is PixelData { +export function isPixelData(input: unknown): input is PixelData { return ( typeof input === 'object' && input !== null && From 8317400593321d76c92de71ec8ade4d2aa09ac4a Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 14:20:46 +0100 Subject: [PATCH 35/71] fix: rebase conflict --- .../host_objects/ModelHostObject.h | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index b78cbc7a81..35da066f57 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -86,16 +86,6 @@ template class ModelHostObject : public JsiHostObject { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, "streamStop")); - - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - promiseHostFunction<&Model::generateFromPhonemes>, - "generateFromPhonemes")); - - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - promiseHostFunction<&Model::streamFromPhonemes>, - "streamFromPhonemes")); } if constexpr (meta::SameAs) { @@ -197,6 +187,18 @@ template class ModelHostObject : public JsiHostObject 
{ addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::stream>, "stream")); + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, synchronousHostFunction<&Model::streamStop>, + "streamStop")); + addFunctions( + JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generateFromPhonemes>, + "generateFromPhonemes")); + + addFunctions( + JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::streamFromPhonemes>, + "streamFromPhonemes")); } if constexpr (meta::HasGenerateFromString) { From db503308e7b035f8f44b3a372392b6f328e4b917 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 14:26:36 +0100 Subject: [PATCH 36/71] chore: completely remove api reference --- .../classes/ImageSegmentationModule.md | 356 ------------------ 1 file changed, 356 deletions(-) delete mode 100644 docs/docs/06-api-reference/classes/ImageSegmentationModule.md diff --git a/docs/docs/06-api-reference/classes/ImageSegmentationModule.md b/docs/docs/06-api-reference/classes/ImageSegmentationModule.md deleted file mode 100644 index 6b41289069..0000000000 --- a/docs/docs/06-api-reference/classes/ImageSegmentationModule.md +++ /dev/null @@ -1,356 +0,0 @@ -# Class: ImageSegmentationModule\ - -Defined in: [modules/computer_vision/ImageSegmentationModule.ts:60](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L60) - -Generic image segmentation module with type-safe label maps. -Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, -or a custom label enum for custom configs. 
- -## Extends - -- `BaseModule` - -## Type Parameters - -### T - -`T` _extends_ [`SegmentationModelName`](../type-aliases/SegmentationModelName.md) \| [`LabelEnum`](../type-aliases/LabelEnum.md) - -Either a built-in model name (`'deeplab-v3'`, `'selfie-segmentation'`) -or a custom [LabelEnum](../type-aliases/LabelEnum.md) label map. - -## Properties - -### generateFromFrame() - -> **generateFromFrame**: (`frameData`, ...`args`) => `any` - -Defined in: [modules/BaseModule.ts:56](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L56) - -Process a camera frame directly for real-time inference. - -This method is bound to a native JSI function after calling `load()`, -making it worklet-compatible and safe to call from VisionCamera's -frame processor thread. - -**Performance characteristics:** - -- **Zero-copy path**: When using `frame.getNativeBuffer()` from VisionCamera v5, - frame data is accessed directly without copying (fastest, recommended). -- **Copy path**: When using `frame.toArrayBuffer()`, pixel data is copied - from native to JS, then accessed from native code (slower, fallback). 
- -**Usage with VisionCamera:** - -```typescript -const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - onFrame(frame) { - 'worklet'; - // Zero-copy approach (recommended) - const nativeBuffer = frame.getNativeBuffer(); - const result = model.generateFromFrame( - { - nativeBuffer: nativeBuffer.pointer, - width: frame.width, - height: frame.height, - }, - ...args - ); - nativeBuffer.release(); - frame.dispose(); - }, -}); -``` - -#### Parameters - -##### frameData - -[`Frame`](../interfaces/Frame.md) - -Frame data object with either nativeBuffer (zero-copy) or data (ArrayBuffer) - -##### args - -...`any`[] - -Additional model-specific arguments (e.g., threshold, options) - -#### Returns - -`any` - -Model-specific output (e.g., detections, classifications, embeddings) - -#### See - -[Frame](../interfaces/Frame.md) for frame data format details - -#### Inherited from - -`BaseModule.generateFromFrame` - ---- - -### nativeModule - -> **nativeModule**: `any` = `null` - -Defined in: [modules/BaseModule.ts:17](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L17) - -**`Internal`** - -Native module instance (JSI Host Object) - -#### Inherited from - -`BaseModule.nativeModule` - -## Methods - -### delete() - -> **delete**(): `void` - -Defined in: [modules/BaseModule.ts:100](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L100) - -Unloads the model from memory and releases native resources. - -Always call this method when you're done with a model to prevent memory leaks. 
- -#### Returns - -`void` - -#### Inherited from - -`BaseModule.delete` - ---- - -### forward() - -> **forward**\<`K`\>(`imageSource`, `classesOfInterest`, `resizeToInput`): `Promise`\<`Record`\<`"ARGMAX"`, `Int32Array`\<`ArrayBufferLike`\>\> & `Record`\<`K`, `Float32Array`\<`ArrayBufferLike`\>\>\> - -Defined in: [modules/computer_vision/ImageSegmentationModule.ts:176](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L176) - -Executes the model's forward pass to perform semantic segmentation on the provided image. - -#### Type Parameters - -##### K - -`K` _extends_ `string` \| `number` \| `symbol` - -#### Parameters - -##### imageSource - -`string` - -A string representing the image source (e.g., a file path, URI, or Base64-encoded string). - -##### classesOfInterest - -`K`[] = `[]` - -An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. - -##### resizeToInput - -`boolean` = `true` - -Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. - -#### Returns - -`Promise`\<`Record`\<`"ARGMAX"`, `Int32Array`\<`ArrayBufferLike`\>\> & `Record`\<`K`, `Float32Array`\<`ArrayBufferLike`\>\>\> - -A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. - -#### Throws - -If the model is not loaded. 
- ---- - -### forwardET() - -> `protected` **forwardET**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\> - -Defined in: [modules/BaseModule.ts:80](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L80) - -**`Internal`** - -Runs the model's forward method with the given input tensors. -It returns the output tensors that mimic the structure of output from ExecuTorch. - -#### Parameters - -##### inputTensor - -[`TensorPtr`](../interfaces/TensorPtr.md)[] - -Array of input tensors. - -#### Returns - -`Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\> - -Array of output tensors. - -#### Inherited from - -`BaseModule.forwardET` - ---- - -### getInputShape() - -> **getInputShape**(`methodName`, `index`): `Promise`\<`number`[]\> - -Defined in: [modules/BaseModule.ts:91](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/BaseModule.ts#L91) - -Gets the input shape for a given method and index. - -#### Parameters - -##### methodName - -`string` - -method name - -##### index - -`number` - -index of the argument which shape is requested - -#### Returns - -`Promise`\<`number`[]\> - -The input shape as an array of numbers. - -#### Inherited from - -`BaseModule.getInputShape` - ---- - -### load() - -> **load**(): `Promise`\<`void`\> - -Defined in: [modules/computer_vision/ImageSegmentationModule.ts:76](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L76) - -Load the model and prepare it for inference. 
- -#### Returns - -`Promise`\<`void`\> - -#### Overrides - -`BaseModule.load` - ---- - -### fromCustomConfig() - -> `static` **fromCustomConfig**\<`L`\>(`modelSource`, `config`, `onDownloadProgress`): `Promise`\<`ImageSegmentationModule`\<`L`\>\> - -Defined in: [modules/computer_vision/ImageSegmentationModule.ts:142](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L142) - -Creates a segmentation instance with a user-provided label map and custom config. -Use this when working with a custom-exported segmentation model that is not one of the built-in models. - -#### Type Parameters - -##### L - -`L` _extends_ `Readonly`\<`Record`\<`string`, `string` \| `number`\>\> - -#### Parameters - -##### modelSource - -[`ResourceSource`](../type-aliases/ResourceSource.md) - -A fetchable resource pointing to the model binary. - -##### config - -[`SegmentationConfig`](../type-aliases/SegmentationConfig.md)\<`L`\> - -A [SegmentationConfig](../type-aliases/SegmentationConfig.md) object with the label map and optional preprocessing parameters. - -##### onDownloadProgress - -(`progress`) => `void` - -Optional callback to monitor download progress, receiving a value between 0 and 1. - -#### Returns - -`Promise`\<`ImageSegmentationModule`\<`L`\>\> - -A Promise resolving to an `ImageSegmentationModule` instance typed to the provided label map. 
- -#### Example - -```ts -const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; -const segmentation = await ImageSegmentationModule.fromCustomConfig( - 'https://example.com/custom_model.pte', - { labelMap: MyLabels } -); -``` - ---- - -### fromModelName() - -> `static` **fromModelName**\<`C`\>(`config`, `onDownloadProgress`): `Promise`\<`ImageSegmentationModule`\<[`ModelNameOf`](../type-aliases/ModelNameOf.md)\<`C`\>\>\> - -Defined in: [modules/computer_vision/ImageSegmentationModule.ts:95](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L95) - -Creates a segmentation instance for a built-in model. -The config object is discriminated by `modelName` — each model can require different fields. - -#### Type Parameters - -##### C - -`C` _extends_ [`ModelSources`](../type-aliases/ModelSources.md) - -#### Parameters - -##### config - -`C` - -A [ModelSources](../type-aliases/ModelSources.md) object specifying which model to load and where to fetch it from. - -##### onDownloadProgress - -(`progress`) => `void` - -Optional callback to monitor download progress, receiving a value between 0 and 1. - -#### Returns - -`Promise`\<`ImageSegmentationModule`\<[`ModelNameOf`](../type-aliases/ModelNameOf.md)\<`C`\>\>\> - -A Promise resolving to an `ImageSegmentationModule` instance typed to the chosen model's label map. 
- -#### Example - -```ts -const segmentation = await ImageSegmentationModule.fromModelName({ - modelName: 'deeplab-v3', - modelSource: 'https://example.com/deeplab.pte', -}); -``` From 57e2379abc37c2bb6a4eed40b319929e62b54a45 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 14:29:10 +0100 Subject: [PATCH 37/71] chore: image_segmentation -> semantic_segmentation - namespace --- .../common/rnexecutorch/host_objects/JsiConversions.h | 5 +++-- .../semantic_segmentation/BaseSemanticSegmentation.cpp | 7 ++++--- .../semantic_segmentation/BaseSemanticSegmentation.h | 4 ++-- .../rnexecutorch/models/semantic_segmentation/Types.h | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index f5ed4f804b..169763ac6e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -18,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -583,7 +583,8 @@ getJsiValue(const models::style_transfer::PixelDataResult &result, } inline jsi::Value getJsiValue( - const rnexecutorch::models::image_segmentation::SegmentationResult &result, + const rnexecutorch::models::semantic_segmentation::SegmentationResult + &result, jsi::Runtime &runtime) { jsi::Object dict(runtime); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index 429f53ba65..c5802d09a7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -79,7 +79,7 @@ std::shared_ptr BaseSemanticSegmentation::generate( return populateDictionary(result.argmax, result.classBuffers); } -image_segmentation::SegmentationResult +semantic_segmentation::SegmentationResult BaseSemanticSegmentation::generateFromFrame( jsi::Runtime &runtime, const jsi::Value &frameData, std::set> classesOfInterest, bool resize) { @@ -106,7 +106,8 @@ BaseSemanticSegmentation::generateFromFrame( allClasses_, classesOfInterest, resize); } -image_segmentation::SegmentationResult BaseSemanticSegmentation::computeResult( +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::computeResult( const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, bool resize) { @@ -215,7 +216,7 @@ image_segmentation::SegmentationResult BaseSemanticSegmentation::computeResult( } } - return image_segmentation::SegmentationResult{argmax, buffersToReturn}; + return semantic_segmentation::SegmentationResult{argmax, buffersToReturn}; } std::shared_ptr BaseSemanticSegmentation::populateDictionary( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index 00ce9282ab..6a5a6f9215 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -36,13 +36,13 @@ class BaseSemanticSegmentation : public VisionModel { // Must NOT use callInvoker — returns a plain SegmentationResult that // visionHostFunction converts to JSI via getJsiValue. 
[[nodiscard("Registered non-void function")]] - image_segmentation::SegmentationResult + semantic_segmentation::SegmentationResult generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, std::set> classesOfInterest, bool resize); protected: - virtual image_segmentation::SegmentationResult + virtual semantic_segmentation::SegmentationResult computeResult(const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h index b5d6f5067d..b305b96a70 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h @@ -5,7 +5,7 @@ #include #include -namespace rnexecutorch::models::image_segmentation { +namespace rnexecutorch::models::semantic_segmentation { struct SegmentationResult { std::shared_ptr argmax; @@ -14,4 +14,4 @@ struct SegmentationResult { classBuffers; }; -} // namespace rnexecutorch::models::image_segmentation +} // namespace rnexecutorch::models::semantic_segmentation From 7364d49f39f63fa60a4fa32d3fdecf72096bf807 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 14:43:19 +0100 Subject: [PATCH 38/71] feat: cache getInputShape result across inferences --- .../rnexecutorch/models/VisionModel.cpp | 7 +++--- .../common/rnexecutorch/models/VisionModel.h | 17 ++++++++++---- .../models/classification/Classification.cpp | 12 ++++------ .../embeddings/image/ImageEmbeddings.cpp | 11 ++++------ .../object_detection/ObjectDetection.cpp | 22 +++++++++---------- .../BaseSemanticSegmentation.cpp | 20 ++++++++--------- .../models/style_transfer/StyleTransfer.cpp | 17 ++++++-------- 7 files changed, 51 insertions(+), 55 deletions(-) diff --git 
a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 78b9b042db..64ddad9862 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -42,9 +42,10 @@ cv::Mat VisionModel::preprocessFrame(const cv::Mat &frame) const { errorMessage); } - if (rgb.size() != modelImageSize) { + const cv::Size targetSize = modelInputSize(); + if (rgb.size() != targetSize) { cv::Mat resized; - cv::resize(rgb, resized, modelImageSize); + cv::resize(rgb, resized, targetSize); return resized; } @@ -87,4 +88,4 @@ cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { return image; } -} // namespace rnexecutorch::models \ No newline at end of file +} // namespace rnexecutorch::models diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 7b442df3e7..bb81c97723 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -99,9 +99,18 @@ class VisionModel : public BaseModel { */ virtual cv::Mat preprocessFrame(const cv::Mat &frame) const; - /// Expected input image dimensions derived from the model's input shape. - /// Set by subclass constructors after loading the model. - cv::Size modelImageSize{0, 0}; + /// Cached input tensor shape (getAllInputShapes()[0]). + /// Set once by each subclass constructor to avoid per-frame metadata lookups. + std::vector inputTensorDims_; + + /// Convenience accessor: spatial dimensions of the model input. 
+ cv::Size modelInputSize() const { + if (inputTensorDims_.size() < 2) { + return {0, 0}; + } + return cv::Size(inputTensorDims_[inputTensorDims_.size() - 1], + inputTensorDims_[inputTensorDims_.size() - 2]); + } /** * @brief Extract and preprocess frame from VisionCamera in one call @@ -167,4 +176,4 @@ class VisionModel : public BaseModel { REGISTER_CONSTRUCTOR(models::VisionModel, std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 6f27b2f841..424cb5cb80 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -18,18 +18,16 @@ Classification::Classification(const std::string &modelSource, throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { + inputTensorDims_ = inputShapes[0]; + if (inputTensorDims_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + inputTensorDims_.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); } std::unordered_map @@ -38,9 +36,8 @@ Classification::runInference(cv::Mat image) { cv::Mat preprocessed = preprocessFrame(image); - const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = - image_processing::getTensorFromMatrix(tensorDims, preprocessed); + 
image_processing::getTensorFromMatrix(inputTensorDims_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { @@ -48,7 +45,6 @@ Classification::runInference(cv::Mat image) { "The model's forward function did not succeed. " "Ensure the model input is correct."); } - return postprocess(forwardResult->at(0).toTensor()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index c54456c743..bb8c949348 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -18,18 +18,16 @@ ImageEmbeddings::ImageEmbeddings( throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputTensors[0]; - if (modelInputShape.size() < 2) { + inputTensorDims_ = inputTensors[0]; + if (inputTensorDims_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + inputTensorDims_.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); } std::shared_ptr @@ -38,9 +36,8 @@ ImageEmbeddings::runInference(cv::Mat image) { cv::Mat preprocessed = preprocessFrame(image); - const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = - image_processing::getTensorFromMatrix(tensorDims, preprocessed); + image_processing::getTensorFromMatrix(inputTensorDims_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); diff --git 
a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 1c0fec27c2..49e642b5ed 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -19,18 +19,16 @@ ObjectDetection::ObjectDetection( throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputTensors[0]; - if (modelInputShape.size() < 2) { + inputTensorDims_ = inputTensors[0]; + if (inputTensorDims_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + inputTensorDims_.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { @@ -48,10 +46,10 @@ ObjectDetection::ObjectDetection( std::vector ObjectDetection::postprocess(const std::vector &tensors, cv::Size originalSize, double detectionThreshold) { - float widthRatio = - static_cast(originalSize.width) / modelImageSize.width; + const cv::Size inputSize = modelInputSize(); + float widthRatio = static_cast(originalSize.width) / inputSize.width; float heightRatio = - static_cast(originalSize.height) / modelImageSize.height; + static_cast(originalSize.height) / inputSize.height; std::vector detections; auto bboxTensor = tensors.at(0).toTensor(); @@ -102,12 +100,12 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) { cv::Size 
originalSize = image.size(); cv::Mat preprocessed = preprocessFrame(image); - const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix(tensorDims, preprocessed, - *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(tensorDims, preprocessed); + ? image_processing::getTensorFromMatrix( + inputTensorDims_, preprocessed, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(inputTensorDims_, + preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index c5802d09a7..22ceb6fc49 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -38,22 +38,20 @@ void BaseSemanticSegmentation::initModelImageSize() { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { + inputTensorDims_ = inputShapes[0]; + if (inputTensorDims_.size() < 2) { throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, "Unexpected model input size, expected at least 2 " "dimensions but got: " + - std::to_string(modelInputShape.size()) + "."); + std::to_string(inputTensorDims_.size()) + "."); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); - numModelPixels = modelImageSize.area(); + numModelPixels = modelInputSize().area(); } TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, cv::Size 
&originalSize) { auto [inputTensor, origSize] = image_processing::readImageToTensor( - imageSource, getAllInputShapes()[0], false, normMean_, normStd_); + imageSource, inputTensorDims_, false, normMean_, normStd_); originalSize = origSize; return inputTensor; } @@ -89,12 +87,12 @@ BaseSemanticSegmentation::generateFromFrame( cv::Mat preprocessed = preprocessFrame(frame); cv::Size originalSize = frame.size(); - const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix(tensorDims, preprocessed, - *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(tensorDims, preprocessed); + ? image_processing::getTensorFromMatrix( + inputTensorDims_, preprocessed, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(inputTensorDims_, + preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index eeedcdce2f..83c8ab7cc7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -20,24 +20,22 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors"); } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { + inputTensorDims_ = inputShapes[0]; + if (inputTensorDims_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + inputTensorDims_.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, 
errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); } PixelDataResult StyleTransfer::postprocess(const Tensor &tensor, cv::Size outputSize) { - // Convert tensor output (at modelImageSize) to CV_8UC3 BGR mat - cv::Mat mat = image_processing::getMatrixFromTensor(modelImageSize, tensor); + // Convert tensor output (at model input size) to CV_8UC3 BGR mat + cv::Mat mat = image_processing::getMatrixFromTensor(modelInputSize(), tensor); // Resize only if requested output differs from model output size if (mat.size() != outputSize) { @@ -64,9 +62,8 @@ PixelDataResult StyleTransfer::runInference(cv::Mat image, cv::Mat preprocessed = preprocessFrame(image); - const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = - image_processing::getTensorFromMatrix(tensorDims, preprocessed); + image_processing::getTensorFromMatrix(inputTensorDims_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { @@ -95,7 +92,7 @@ PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, // For real-time frame processing, output at modelImageSize to avoid // allocating large buffers (e.g. 1280x720x3 ~2.7MB) on every frame. 
- return runInference(frame, modelImageSize); + return runInference(frame, modelInputSize()); } PixelDataResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData) { From 55a98adf571b54b494c1b07af66c3dceea44 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 11 Mar 2026 16:16:28 +0100 Subject: [PATCH 39/71] refactor: small fixes --- .../common/rnexecutorch/models/VisionModel.cpp | 1 - .../common/rnexecutorch/utils/FrameExtractor.cpp | 3 +-- packages/react-native-executorch/src/index.ts | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 64ddad9862..9d9ab8a7ea 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -1,7 +1,6 @@ #include "VisionModel.h" #include #include -#include #include namespace rnexecutorch::models { diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp index d574b4bf07..d14c522184 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp @@ -47,8 +47,7 @@ cv::Mat extractFromCVPixelBuffer(void *pixelBuffer) { errorMessage); } - // Note: We don't unlock here - Vision Camera manages the lifecycle - // When frame.dispose() is called, Vision Camera will unlock and release + CVPixelBufferUnlockBaseAddress(buffer, kCVPixelBufferLock_ReadOnly); return mat; } diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 1947fd7269..5bb4d3d134 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -129,7
+129,6 @@ export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; export * from './hooks/computer_vision/useStyleTransfer'; export * from './hooks/computer_vision/useSemanticSegmentation'; -export * from './hooks/computer_vision/useSemanticSegmentation'; export * from './hooks/computer_vision/useOCR'; export * from './hooks/computer_vision/useVerticalOCR'; export * from './hooks/computer_vision/useImageEmbeddings'; From 2d19b6db2d08be260b0f731a492d2b4fe91754a3 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 09:38:48 +0100 Subject: [PATCH 40/71] feat: separate functions for pixelData/Frame -> Mat conversion --- .../host_objects/JSTensorViewIn.h | 3 + .../host_objects/ModelHostObject.h | 7 --- .../metaprogramming/TypeConcepts.h | 5 -- .../rnexecutorch/models/VisionModel.cpp | 46 +--------------- .../common/rnexecutorch/models/ocr/OCR.cpp | 58 +++++------------- .../BaseSemanticSegmentation.cpp | 2 +- .../BaseSemanticSegmentation.h | 5 +- .../models/vertical_ocr/VerticalOCR.cpp | 59 +++++-------------- .../rnexecutorch/utils/FrameProcessor.cpp | 48 +++++++++++++++ .../rnexecutorch/utils/FrameProcessor.h | 29 ++++++--- .../SemanticSegmentationModule.ts | 2 +- 11 files changed, 107 insertions(+), 157 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h index 4057950b23..1eb288ee30 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + namespace rnexecutorch { using executorch::aten::ScalarType; diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 35da066f57..7ef7953c85 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -208,13 +208,6 @@ template class ModelHostObject : public JsiHostObject { "generateFromString")); } - if constexpr (meta::HasGenerateFromString) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generateFromString>, - "generateFromString")); - } - if constexpr (meta::HasGenerateFromFrame) { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, visionHostFunction<&Model::generateFromFrame>, diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index 2d7612f250..f625bf6e76 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -11,11 +11,6 @@ concept DerivedFromOrSameAs = std::is_base_of_v; template concept SameAs = std::is_same_v; -template -concept HasGenerate = requires(T t) { - { &T::generate }; -}; - template concept HasGenerateFromString = requires(T t) { { &T::generateFromString }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 9d9ab8a7ea..63983ab893 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -9,17 +9,7 @@ using namespace facebook; cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const { - auto frameObj = frameData.asObject(runtime); - cv::Mat frame = 
::rnexecutorch::utils::extractFrame(runtime, frameObj); - - // Camera sensors natively deliver frames in landscape orientation. - // Rotate 90° CW so all models receive upright portrait frames. - if (frame.cols > frame.rows) { - cv::Mat upright; - cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); - return upright; - } - return frame; + return ::rnexecutorch::utils::frameToMat(runtime, frameData); } cv::Mat VisionModel::preprocessFrame(const cv::Mat &frame) const { @@ -52,39 +42,7 @@ cv::Mat VisionModel::preprocessFrame(const cv::Mat &frame) const { } cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { - if (tensorView.sizes.size() != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: sizes must have 3 elements " - "[height, width, channels], got %zu", - tensorView.sizes.size()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - int32_t height = tensorView.sizes[0]; - int32_t width = tensorView.sizes[1]; - int32_t channels = tensorView.sizes[2]; - - if (channels != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: expected 3 channels (RGB), got %d", - channels); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - if (tensorView.scalarType != ScalarType::Byte) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); - } - - uint8_t *dataPtr = static_cast(tensorView.dataPtr); - cv::Mat image(height, width, CV_8UC3, dataPtr); - - return image; + return ::rnexecutorch::utils::pixelsToMat(tensorView); } } // namespace rnexecutorch::models diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp index 102d8ee479..8de712edd4 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -51,55 +51,25 @@ std::vector OCR::generateFromString(std::string input) { std::vector OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { - auto frameObj = frameData.asObject(runtime); - cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); - // Camera sensors deliver landscape frames; rotate to portrait orientation. - if (frame.cols > frame.rows) { - cv::Mat upright; - cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); - frame = std::move(upright); - } - // extractFrame returns RGB; convert to BGR for consistency with readImage - cv::cvtColor(frame, frame, cv::COLOR_RGB2BGR); - return runInference(frame); + cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); + cv::Mat bgr; +#ifdef __APPLE__ + cv::cvtColor(frame, bgr, cv::COLOR_BGRA2BGR); +#elif defined(__ANDROID__) + cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR); +#else + throw RnExecutorchError( + RnExecutorchErrorCode::PlatformNotSupported, + "generateFromFrame is not supported on this platform"); +#endif + return runInference(bgr); } std::vector OCR::generateFromPixels(JSTensorViewIn pixelData) { - if (pixelData.sizes.size() != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: sizes must have 3 elements " - "[height, width, channels], got %zu", - pixelData.sizes.size()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - int32_t height = pixelData.sizes[0]; - int32_t width = pixelData.sizes[1]; - int32_t channels = pixelData.sizes[2]; - - if (channels != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: expected 3 channels (RGB), got %d", - channels); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - if 
(pixelData.scalarType != executorch::aten::ScalarType::Byte) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); - } - - uint8_t *dataPtr = static_cast(pixelData.dataPtr); - // Input is RGB from JS; convert to BGR for consistency with readImage - cv::Mat rgbImage(height, width, CV_8UC3, dataPtr); cv::Mat image; - cv::cvtColor(rgbImage, image, cv::COLOR_RGB2BGR); + cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, + cv::COLOR_RGB2BGR); return runInference(image); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index 22ceb6fc49..35d8b1ff78 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -56,7 +56,7 @@ TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, return inputTensor; } -std::shared_ptr BaseSemanticSegmentation::generate( +std::shared_ptr BaseSemanticSegmentation::generateFromString( std::string imageSource, std::set> classesOfInterest, bool resize) { std::scoped_lock lock(inference_mutex_); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index 6a5a6f9215..18533e9bd4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -29,8 +29,9 @@ class BaseSemanticSegmentation : public VisionModel { // 
Async path: called from promiseHostFunction on a thread-pool thread. // Returns a jsi::Object via callInvoker (safe to block there). [[nodiscard("Registered non-void function")]] std::shared_ptr - generate(std::string imageSource, - std::set> classesOfInterest, bool resize); + generateFromString(std::string imageSource, + std::set> classesOfInterest, + bool resize); // Sync path: called from visionHostFunction on the camera worklet thread. // Must NOT use callInvoker — returns a plain SegmentationResult that diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp index 9534dfeab6..8f6cbe5072 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -1,5 +1,4 @@ #include "VerticalOCR.h" -#include #include #include #include @@ -56,55 +55,25 @@ VerticalOCR::generateFromString(std::string input) { std::vector VerticalOCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { - auto frameObj = frameData.asObject(runtime); - cv::Mat frame = ::rnexecutorch::utils::extractFrame(runtime, frameObj); - // Camera sensors deliver landscape frames; rotate to portrait orientation. 
- if (frame.cols > frame.rows) { - cv::Mat upright; - cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); - frame = std::move(upright); - } - // extractFrame returns RGB; convert to BGR for consistency with readImage - cv::cvtColor(frame, frame, cv::COLOR_RGB2BGR); - return runInference(frame); + cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); + cv::Mat bgr; +#ifdef __APPLE__ + cv::cvtColor(frame, bgr, cv::COLOR_BGRA2BGR); +#elif defined(__ANDROID__) + cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR); +#else + throw RnExecutorchError( + RnExecutorchErrorCode::PlatformNotSupported, + "generateFromFrame is not supported on this platform"); +#endif + return runInference(bgr); } std::vector VerticalOCR::generateFromPixels(JSTensorViewIn pixelData) { - if (pixelData.sizes.size() != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: sizes must have 3 elements " - "[height, width, channels], got %zu", - pixelData.sizes.size()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - int32_t height = pixelData.sizes[0]; - int32_t width = pixelData.sizes[1]; - int32_t channels = pixelData.sizes[2]; - - if (channels != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: expected 3 channels (RGB), got %d", - channels); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - if (pixelData.scalarType != executorch::aten::ScalarType::Byte) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); - } - - uint8_t *dataPtr = static_cast(pixelData.dataPtr); - // Input is RGB from JS; convert to BGR for consistency with readImage - cv::Mat rgbImage(height, width, CV_8UC3, dataPtr); cv::Mat image; - cv::cvtColor(rgbImage, image, cv::COLOR_RGB2BGR); + cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), 
image, + cv::COLOR_RGB2BGR); return runInference(image); } diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp index 30238ad5c4..93f645b008 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp @@ -25,4 +25,52 @@ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData) { return extractFromNativeBuffer(bufferPtr); } + +cv::Mat frameToMat(jsi::Runtime &runtime, const jsi::Value &frameData) { + auto frameObj = frameData.asObject(runtime); + cv::Mat frame = extractFrame(runtime, frameObj); + + // Camera sensors deliver landscape frames; rotate to portrait orientation. + if (frame.cols > frame.rows) { + cv::Mat upright; + cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); + return upright; + } + return frame; +} + +cv::Mat pixelsToMat(const JSTensorViewIn &pixelData) { + if (pixelData.sizes.size() != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: sizes must have 3 elements " + "[height, width, channels], got %zu", + pixelData.sizes.size()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + int32_t height = pixelData.sizes[0]; + int32_t width = pixelData.sizes[1]; + int32_t channels = pixelData.sizes[2]; + + if (channels != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: expected 3 channels (RGB), got %d", + channels); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (pixelData.scalarType != executorch::aten::ScalarType::Byte) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); + } + + uint8_t *dataPtr = 
static_cast(pixelData.dataPtr); + return cv::Mat(height, width, CV_8UC3, dataPtr); +} + } // namespace rnexecutorch::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h index 403f4bde91..757fa95bbc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h @@ -2,6 +2,7 @@ #include #include +#include namespace rnexecutorch::utils { @@ -10,18 +11,30 @@ using namespace facebook; /** * @brief Extract cv::Mat from VisionCamera frame data via nativeBuffer * - * @param runtime JSI runtime - * @param frameData JSI object containing frame data from VisionCamera - * Expected properties: - * - nativeBuffer: BigInt pointer to native buffer - * - * @return cv::Mat wrapping the frame data (zero-copy) + * Returns an RGB mat (as delivered by the native buffer). * * @throws RnExecutorchError if nativeBuffer is not present or extraction fails - * * @note The returned cv::Mat does not own the data. - * Caller must ensure the source frame remains valid during use. */ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData); +/** + * @brief Convert a VisionCamera frame to a rotated RGB cv::Mat. + * + * Handles frame extraction and landscape→portrait rotation. + * Callers are responsible for any further colour space conversion. + */ +cv::Mat frameToMat(jsi::Runtime &runtime, const jsi::Value &frameData); + +/** + * @brief Validate a JSTensorViewIn and wrap its data in a RGB cv::Mat. + * + * Validates sizes (must be [H, W, 3]), scalar type (Byte), and returns a + * cv::Mat that wraps the raw pixel buffer without copying. + * Callers are responsible for any further colour space conversion. 
+ * + * @throws RnExecutorchError on invalid input + */ +cv::Mat pixelsToMat(const JSTensorViewIn &pixelData); + } // namespace rnexecutorch::utils diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index 841306ec5c..ce339727d7 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -284,7 +284,7 @@ export class SemanticSegmentationModule< String(label) ); - const nativeResult = await this.nativeModule.generate( + const nativeResult = await this.nativeModule.generateFromString( input, classesOfInterestNames, resizeToInput From 978f9e775049bb644372e4bb001f1482dda6fb4a Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 14:09:21 +0100 Subject: [PATCH 41/71] refactor: unify the contract the visionModels follow --- .../metaprogramming/TypeConcepts.h | 5 + .../rnexecutorch/models/VisionModel.cpp | 34 ++--- .../common/rnexecutorch/models/VisionModel.h | 47 +++---- .../models/classification/Classification.cpp | 10 +- .../embeddings/image/ImageEmbeddings.cpp | 10 +- .../object_detection/ObjectDetection.cpp | 12 +- .../BaseSemanticSegmentation.cpp | 120 ++++++------------ .../BaseSemanticSegmentation.h | 29 ++--- .../models/style_transfer/StyleTransfer.cpp | 10 +- 9 files changed, 103 insertions(+), 174 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index f625bf6e76..2d7612f250 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -11,6 +11,11 @@ concept DerivedFromOrSameAs = 
std::is_base_of_v; template concept SameAs = std::is_same_v; +template +concept HasGenerate = requires(T t) { + { &T::generate }; +}; + template concept HasGenerateFromString = requires(T t) { { &T::generateFromString }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 63983ab893..89727361ca 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -9,36 +9,24 @@ using namespace facebook; cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const { - return ::rnexecutorch::utils::frameToMat(runtime, frameData); -} - -cv::Mat VisionModel::preprocessFrame(const cv::Mat &frame) const { + cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); cv::Mat rgb; - - if (frame.channels() == 4) { #ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); #else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); #endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } + return rgb; +} +cv::Mat VisionModel::preprocess(const cv::Mat &image) const { const cv::Size targetSize = modelInputSize(); - if (rgb.size() != targetSize) { - cv::Mat resized; - cv::resize(rgb, resized, targetSize); - return resized; + if (image.size() == targetSize) { + return image; } - - return rgb; + cv::Mat resized; + cv::resize(image, resized, targetSize); + return resized; } cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { diff --git 
a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index bb81c97723..766e6ff968 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -23,8 +23,8 @@ namespace models { * Usage: * Subclasses should: * 1. Inherit from VisionModel instead of BaseModel - * 2. Implement preprocessFrame() with model-specific preprocessing - * 3. Delegate to runInference() which handles locking internally + * 2. Optionally override preprocess() for model-specific preprocessing + * 3. Implement runInference() which acquires the lock internally * * Example: * @code @@ -84,55 +84,42 @@ class VisionModel : public BaseModel { mutable std::mutex inference_mutex_; /** - * @brief Preprocess a camera frame for model input + * @brief Resize an RGB image to the model's expected input size * - * Converts 4-channel frames (BGRA on iOS, RGBA on Android) to RGB and - * resizes to modelImageSize if needed. Subclasses may override for + * Resizes to modelInputSize() if needed. Subclasses may override for * model-specific preprocessing (e.g., normalisation). * - * @param frame Input frame from camera (already extracted and rotated by - * FrameExtractor) - * @return Preprocessed cv::Mat ready for tensor conversion + * @param image Input image in RGB format + * @return cv::Mat resized to modelInputSize(), in RGB format * - * @note The input frame is already in RGB format and rotated 90° clockwise - * @note This method is called under mutex protection in generateFromFrame() + * @note Called from runInference() under the inference mutex */ - virtual cv::Mat preprocessFrame(const cv::Mat &frame) const; + virtual cv::Mat preprocess(const cv::Mat &image) const; /// Cached input tensor shape (getAllInputShapes()[0]). /// Set once by each subclass constructor to avoid per-frame metadata lookups. 
- std::vector inputTensorDims_; + std::vector modelInputShape_; /// Convenience accessor: spatial dimensions of the model input. cv::Size modelInputSize() const { - if (inputTensorDims_.size() < 2) { + if (modelInputShape_.size() < 2) { return {0, 0}; } - return cv::Size(inputTensorDims_[inputTensorDims_.size() - 1], - inputTensorDims_[inputTensorDims_.size() - 2]); + return cv::Size(modelInputShape_[modelInputShape_.size() - 1], + modelInputShape_[modelInputShape_.size() - 2]); } /** - * @brief Extract and preprocess frame from VisionCamera in one call + * @brief Extract an RGB cv::Mat from a VisionCamera frame * - * This is a convenience method that combines frame extraction and - * preprocessing. It handles both nativeBuffer (zero-copy) and ArrayBuffer - * paths automatically. + * Calls frameToMat() then converts the raw 4-channel frame + * (BGRA on iOS, RGBA on Android) to RGB. * * @param runtime JSI runtime * @param frameData JSI value containing frame data from VisionCamera + * @return cv::Mat in RGB format (3 channels) * - * @return Preprocessed cv::Mat ready for tensor conversion - * - * @throws std::runtime_error if frame extraction fails - * - * @note This method does NOT acquire the inference mutex - caller is - * responsible - * @note Typical usage: - * @code - * cv::Mat preprocessed = extractFromFrame(runtime, frameData); - * auto tensor = image_processing::getTensorFromMatrix(dims, preprocessed); - * @endcode + * @note Does NOT acquire the inference mutex — caller is responsible */ cv::Mat extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 424cb5cb80..f713b59605 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -18,13 +18,13 @@ Classification::Classification(const std::string &modelSource, throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - inputTensorDims_ = inputShapes[0]; - if (inputTensorDims_.size() < 2) { + modelInputShape_ = inputShapes[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - inputTensorDims_.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } @@ -34,10 +34,10 @@ std::unordered_map Classification::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); - cv::Mat preprocessed = preprocessFrame(image); + cv::Mat preprocessed = preprocess(image); auto inputTensor = - image_processing::getTensorFromMatrix(inputTensorDims_, preprocessed); + image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index bb8c949348..f742f13f65 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -18,13 +18,13 @@ ImageEmbeddings::ImageEmbeddings( throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - inputTensorDims_ = inputTensors[0]; - if (inputTensorDims_.size() < 2) { + modelInputShape_ = inputTensors[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; 
std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - inputTensorDims_.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } @@ -34,10 +34,10 @@ std::shared_ptr ImageEmbeddings::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); - cv::Mat preprocessed = preprocessFrame(image); + cv::Mat preprocessed = preprocess(image); auto inputTensor = - image_processing::getTensorFromMatrix(inputTensorDims_, preprocessed); + image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 49e642b5ed..d30fc8aeed 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -19,13 +19,13 @@ ObjectDetection::ObjectDetection( throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - inputTensorDims_ = inputTensors[0]; - if (inputTensorDims_.size() < 2) { + modelInputShape_ = inputTensors[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - inputTensorDims_.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } @@ -98,13 +98,13 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) { std::scoped_lock lock(inference_mutex_); cv::Size originalSize = image.size(); - cv::Mat preprocessed = preprocessFrame(image); + 
cv::Mat preprocessed = preprocess(image); auto inputTensor = (normMean_ && normStd_) ? image_processing::getTensorFromMatrix( - inputTensorDims_, preprocessed, *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(inputTensorDims_, + modelInputShape_, preprocessed, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index 35d8b1ff78..5ecf7493c4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -1,8 +1,6 @@ #include "BaseSemanticSegmentation.h" #include "jsi/jsi.h" -#include - #include #include #include @@ -38,70 +36,67 @@ void BaseSemanticSegmentation::initModelImageSize() { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - inputTensorDims_ = inputShapes[0]; - if (inputTensorDims_.size() < 2) { + modelInputShape_ = inputShapes[0]; + if (modelInputShape_.size() < 2) { throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, "Unexpected model input size, expected at least 2 " "dimensions but got: " + - std::to_string(inputTensorDims_.size()) + "."); + std::to_string(modelInputShape_.size()) + "."); } numModelPixels = modelInputSize().area(); } -TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, - cv::Size &originalSize) { - auto [inputTensor, origSize] = image_processing::readImageToTensor( - imageSource, inputTensorDims_, false, normMean_, normStd_); - originalSize = origSize; - return inputTensor; -} - -std::shared_ptr 
BaseSemanticSegmentation::generateFromString( - std::string imageSource, - std::set> classesOfInterest, bool resize) { +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::runInference( + cv::Mat image, cv::Size originalSize, + std::set> &classesOfInterest, bool resize) { std::scoped_lock lock(inference_mutex_); - cv::Size originalSize; - auto inputTensor = preprocess(imageSource, originalSize); + cv::Mat preprocessed = VisionModel::preprocess(image); + auto inputTensor = + (normMean_ && normStd_) + ? image_processing::getTensorFromMatrix( + modelInputShape_, preprocessed, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, + preprocessed); auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { throw RnExecutorchError(forwardResult.error(), "The model's forward function did not succeed. " "Ensure the model input is correct."); } - auto result = computeResult(forwardResult->at(0).toTensor(), originalSize, - allClasses_, classesOfInterest, resize); - return populateDictionary(result.argmax, result.classBuffers); + return computeResult(forwardResult->at(0).toTensor(), originalSize, + allClasses_, classesOfInterest, resize); } semantic_segmentation::SegmentationResult -BaseSemanticSegmentation::generateFromFrame( - jsi::Runtime &runtime, const jsi::Value &frameData, +BaseSemanticSegmentation::generateFromString( + std::string imageSource, std::set> classesOfInterest, bool resize) { - std::scoped_lock lock(inference_mutex_); + cv::Mat imageBGR = image_processing::readImage(imageSource); + cv::Size originalSize = imageBGR.size(); + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); - cv::Mat frame = extractFromFrame(runtime, frameData); - cv::Mat preprocessed = preprocessFrame(frame); - cv::Size originalSize = frame.size(); - - auto inputTensor = - (normMean_ && normStd_) - ? 
image_processing::getTensorFromMatrix( - inputTensorDims_, preprocessed, *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(inputTensorDims_, - preprocessed); + return runInference(imageRGB, originalSize, classesOfInterest, resize); +} - auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { - throw RnExecutorchError(forwardResult.error(), - "The model's forward function did not succeed."); - } +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::generateFromPixels( + JSTensorViewIn pixelData, + std::set> classesOfInterest, bool resize) { + cv::Mat image = extractFromPixels(pixelData); + return runInference(image, image.size(), classesOfInterest, resize); +} - return computeResult(forwardResult->at(0).toTensor(), originalSize, - allClasses_, classesOfInterest, resize); +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::generateFromFrame( + jsi::Runtime &runtime, const jsi::Value &frameData, + std::set> classesOfInterest, bool resize) { + cv::Mat frame = extractFromFrame(runtime, frameData); + return runInference(frame, frame.size(), classesOfInterest, resize); } semantic_segmentation::SegmentationResult @@ -217,45 +212,4 @@ BaseSemanticSegmentation::computeResult( return semantic_segmentation::SegmentationResult{argmax, buffersToReturn}; } -std::shared_ptr BaseSemanticSegmentation::populateDictionary( - std::shared_ptr argmax, - std::shared_ptr< - std::unordered_map>> - classesToOutput) { - auto promisePtr = std::make_shared>(); - std::future doneFuture = promisePtr->get_future(); - - std::shared_ptr dictPtr = nullptr; - callInvoker->invokeAsync( - [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { - dictPtr = std::make_shared(runtime); - auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); - - auto int32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Int32Array"); - auto int32Array = - int32ArrayCtor.callAsConstructor(runtime, 
argmaxArrayBuffer) - .getObject(runtime); - dictPtr->setProperty(runtime, "ARGMAX", int32Array); - - for (auto &[classLabel, owningBuffer] : *classesToOutput) { - auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); - - auto float32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Float32Array"); - auto float32Array = - float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) - .getObject(runtime); - - dictPtr->setProperty( - runtime, jsi::String::createFromAscii(runtime, classLabel.data()), - float32Array); - } - promisePtr->set_value(); - }); - - doneFuture.wait(); - return dictPtr; -} - } // namespace rnexecutorch::models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index 18533e9bd4..a30ae375bf 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -1,13 +1,10 @@ #pragma once -#include -#include #include #include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include #include #include @@ -16,7 +13,6 @@ namespace models::semantic_segmentation { using namespace facebook; using executorch::aten::Tensor; -using executorch::extension::TensorPtr; class BaseSemanticSegmentation : public VisionModel { public: @@ -26,16 +22,18 @@ class BaseSemanticSegmentation : public VisionModel { std::vector allClasses, std::shared_ptr callInvoker); - // Async path: called from promiseHostFunction on a thread-pool thread. - // Returns a jsi::Object via callInvoker (safe to block there). 
- [[nodiscard("Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] + semantic_segmentation::SegmentationResult generateFromString(std::string imageSource, std::set> classesOfInterest, bool resize); - // Sync path: called from visionHostFunction on the camera worklet thread. - // Must NOT use callInvoker — returns a plain SegmentationResult that - // visionHostFunction converts to JSI via getJsiValue. + [[nodiscard("Registered non-void function")]] + semantic_segmentation::SegmentationResult + generateFromPixels(JSTensorViewIn pixelData, + std::set> classesOfInterest, + bool resize); + [[nodiscard("Registered non-void function")]] semantic_segmentation::SegmentationResult generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, @@ -56,13 +54,10 @@ class BaseSemanticSegmentation : public VisionModel { private: void initModelImageSize(); - TensorPtr preprocess(const std::string &imageSource, cv::Size &originalSize); - - std::shared_ptr populateDictionary( - std::shared_ptr argmax, - std::shared_ptr< - std::unordered_map>> - classesToOutput); + semantic_segmentation::SegmentationResult + runInference(cv::Mat image, cv::Size originalSize, + std::set> &classesOfInterest, + bool resize); }; } // namespace models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 83c8ab7cc7..96b0cdcac5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -20,13 +20,13 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors"); } - inputTensorDims_ = inputShapes[0]; - if 
(inputTensorDims_.size() < 2) { + modelInputShape_ = inputShapes[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - inputTensorDims_.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } @@ -60,10 +60,10 @@ PixelDataResult StyleTransfer::runInference(cv::Mat image, cv::Size originalSize) { std::scoped_lock lock(inference_mutex_); - cv::Mat preprocessed = preprocessFrame(image); + cv::Mat preprocessed = preprocess(image); auto inputTensor = - image_processing::getTensorFromMatrix(inputTensorDims_, preprocessed); + image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { From 87faec71972c3244b5567d8190f6f37ef4850c9d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 14:40:41 +0100 Subject: [PATCH 42/71] feat: remove BaseLabeldModule and use VisionLabeledModule --- .../src/modules/BaseLabeledModule.ts | 59 ---------- .../computer_vision/ObjectDetectionModule.ts | 4 +- .../SemanticSegmentationModule.ts | 108 ++---------------- .../computer_vision/VisionLabeledModule.ts | 45 +++++++- 4 files changed, 54 insertions(+), 162 deletions(-) delete mode 100644 packages/react-native-executorch/src/modules/BaseLabeledModule.ts diff --git a/packages/react-native-executorch/src/modules/BaseLabeledModule.ts b/packages/react-native-executorch/src/modules/BaseLabeledModule.ts deleted file mode 100644 index 01678f83d5..0000000000 --- a/packages/react-native-executorch/src/modules/BaseLabeledModule.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { ResourceFetcher } from '../utils/ResourceFetcher'; -import { LabelEnum, ResourceSource } from '../types/common'; -import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; -import { RnExecutorchError } from 
'../errors/errorUtils'; -import { BaseModule } from './BaseModule'; - -/** - * Fetches a model binary and returns its local path, throwing if the download - * was interrupted (paused or cancelled). - * - * @internal - */ -export async function fetchModelPath( - source: ResourceSource, - onDownloadProgress: (progress: number) => void -): Promise { - const paths = await ResourceFetcher.fetch(onDownloadProgress, source); - if (!paths?.[0]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. Please retry.' - ); - } - return paths[0]; -} - -/** - * Given a model configs record (mapping model names to `{ labelMap }`) and a - * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to - * the label map for that model or `T` itself. - * - * @internal - */ -export type ResolveLabels< - T, - Configs extends Record, -> = T extends keyof Configs - ? Configs[T]['labelMap'] - : T extends LabelEnum - ? T - : never; - -/** - * Base class for vision modules that carry a type-safe label map. - * - * @typeParam LabelMap - The resolved {@link LabelEnum} for the model's output classes. 
- * @internal - */ -export abstract class BaseLabeledModule< - LabelMap extends LabelEnum, -> extends BaseModule { - protected readonly labelMap: LabelMap; - - protected constructor(labelMap: LabelMap, nativeModule: unknown) { - super(); - this.labelMap = labelMap; - this.nativeModule = nativeModule; - } -} diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index c24bbd1369..20eb51db56 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -13,8 +13,8 @@ import { import { fetchModelPath, ResolveLabels as ResolveLabelsFor, -} from '../BaseLabeledModule'; -import { VisionLabeledModule } from './VisionLabeledModule'; + VisionLabeledModule, +} from './VisionLabeledModule'; const ModelConfigs = { 'ssdlite-320-mobilenet-v3-large': { diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index ce339727d7..a35df92d83 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -1,9 +1,4 @@ -import { - ResourceSource, - LabelEnum, - Frame, - PixelData, -} from '../../types/common'; +import { ResourceSource, LabelEnum, PixelData } from '../../types/common'; import { DeeplabLabel, ModelNameOf, @@ -12,14 +7,12 @@ import { SemanticSegmentationModelName, SelfieSegmentationLabel, } from '../../types/semanticSegmentation'; -import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; -import { RnExecutorchError } from '../../errors/errorUtils'; import { IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision'; 
import { - BaseLabeledModule, fetchModelPath, ResolveLabels as ResolveLabelsFor, -} from '../BaseLabeledModule'; + VisionLabeledModule, +} from './VisionLabeledModule'; const PascalVocSegmentationConfig = { labelMap: DeeplabLabel, @@ -84,78 +77,14 @@ type ResolveLabels = */ export class SemanticSegmentationModule< T extends SemanticSegmentationModelName | LabelEnum, -> extends BaseLabeledModule> { +> extends VisionLabeledModule< + Record<'ARGMAX', Int32Array> & Record, Float32Array>, + ResolveLabels +> { private constructor(labelMap: ResolveLabels, nativeModule: unknown) { super(labelMap, nativeModule); } - /** - * Synchronous worklet function for real-time VisionCamera frame processing. - * Automatically handles native buffer extraction and cleanup. - * - * **Use this for VisionCamera frame processing in worklets.** - * For async processing, use `forward()` instead. - * - * Available after model is loaded. - * - * @example - * ```typescript - * const [runOnFrame, setRunOnFrame] = useState(null); - * setRunOnFrame(() => segmentation.runOnFrame); - * - * const frameOutput = useFrameOutput({ - * onFrame(frame) { - * 'worklet'; - * if (!runOnFrame) return; - * const result = runOnFrame(frame, [], true); - * frame.dispose(); - * } - * }); - * ``` - * - * @param frame - VisionCamera Frame object - * @param classesOfInterest - Labels for which to return per-class probability masks. - * @param resizeToInput - Whether to resize masks to original frame dimensions. Defaults to `true`. 
- */ - get runOnFrame(): - | (( - frame: Frame, - classesOfInterest?: string[], - resizeToInput?: boolean - ) => any) - | null { - if (!this.nativeModule?.generateFromFrame) { - return null; - } - - const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; - - return ( - frame: any, - classesOfInterest: string[] = [], - resizeToInput: boolean = true - ): any => { - 'worklet'; - - let nativeBuffer: any = null; - try { - nativeBuffer = frame.getNativeBuffer(); - const frameData = { - nativeBuffer: nativeBuffer.pointer, - }; - return nativeGenerateFromFrame( - frameData, - classesOfInterest, - resizeToInput - ); - } finally { - if (nativeBuffer?.release) { - nativeBuffer.release(); - } - } - }; - } - /** * Creates a segmentation instance for a built-in model. * The config object is discriminated by `modelName` — each model can require different fields. @@ -268,29 +197,12 @@ export class SemanticSegmentationModule< * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded. */ - async forward>( + override async forward>( input: string | PixelData, classesOfInterest: K[] = [], resizeToInput: boolean = true ): Promise & Record> { - if (this.nativeModule == null) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded.' 
- ); - } - - const classesOfInterestNames = classesOfInterest.map((label) => - String(label) - ); - - const nativeResult = await this.nativeModule.generateFromString( - input, - classesOfInterestNames, - resizeToInput - ); - - return nativeResult as Record<'ARGMAX', Int32Array> & - Record; + const classesOfInterestNames = classesOfInterest.map(String); + return super.forward(input, classesOfInterestNames, resizeToInput); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts index 61a0bab091..1cda359db5 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts @@ -1,6 +1,45 @@ -import { LabelEnum } from '../../types/common'; +import { LabelEnum, ResourceSource } from '../../types/common'; +import { ResourceFetcher } from '../../utils/ResourceFetcher'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError } from '../../errors/errorUtils'; import { VisionModule } from './VisionModule'; +/** + * Fetches a model binary and returns its local path, throwing if the download + * was interrupted (paused or cancelled). + * + * @internal + */ +export async function fetchModelPath( + source: ResourceSource, + onDownloadProgress: (progress: number) => void +): Promise { + const paths = await ResourceFetcher.fetch(onDownloadProgress, source); + if (!paths?.[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. Please retry.' + ); + } + return paths[0]; +} + +/** + * Given a model configs record (mapping model names to `{ labelMap }`) and a + * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to + * the label map for that model or `T` itself. 
+ * + * @internal + */ +export type ResolveLabels< + T, + Configs extends Record, +> = T extends keyof Configs + ? Configs[T]['labelMap'] + : T extends LabelEnum + ? T + : never; + /** * Base class for computer vision modules that carry a type-safe label map * and support the full VisionModule API (string/PixelData forward + runOnFrame). @@ -10,8 +49,8 @@ import { VisionModule } from './VisionModule'; * @internal */ export abstract class VisionLabeledModule< - TOutput, - LabelMap extends LabelEnum, + TOutput = unknown, + LabelMap extends LabelEnum = LabelEnum, > extends VisionModule { protected readonly labelMap: LabelMap; From cd3c946e9aba6b1c5d3fbd9a3a57689556720ff8 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 14:55:54 +0100 Subject: [PATCH 43/71] feat: useModuleFactory handles runOnFrame, update stale comments --- .../computer_vision/useClassification.ts | 2 +- .../computer_vision/useImageEmbeddings.ts | 2 +- .../src/hooks/computer_vision/useOCR.ts | 2 +- .../computer_vision/useObjectDetection.ts | 5 +--- .../useSemanticSegmentation.ts | 4 +--- .../src/hooks/useModuleFactory.ts | 11 ++++++++- .../computer_vision/ImageEmbeddingsModule.ts | 3 +-- .../modules/computer_vision/VisionModule.ts | 24 +++++++++---------- 8 files changed, 28 insertions(+), 25 deletions(-) diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts index c014d6b0ed..a6ef5c6a14 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts @@ -9,7 +9,7 @@ import { useModuleFactory } from '../useModuleFactory'; * React hook for managing a Classification model instance. * * @category Hooks - * @param ClassificationProps - Configuration object containing `model` source and optional `preventLoad` flag. 
+ * @param props - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use Classification model. */ export const useClassification = ({ diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts index b4e79c9263..07376bddee 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts @@ -9,7 +9,7 @@ import { useModuleFactory } from '../useModuleFactory'; * React hook for managing an Image Embeddings model instance. * * @category Hooks - * @param ImageEmbeddingsProps - Configuration object containing `model` source and optional `preventLoad` flag. + * @param props - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use Image Embeddings model. */ export const useImageEmbeddings = ({ diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts index 208824b8b8..31061d2b64 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts @@ -8,7 +8,7 @@ import { OCRDetection, OCRProps, OCRType } from '../../types/ocr'; * React hook for managing an OCR instance. * * @category Hooks - * @param OCRProps - Configuration object containing `model` sources and optional `preventLoad` flag. + * @param props - Configuration object containing `model` sources and optional `preventLoad` flag. * @returns Ready to use OCR model. 
*/ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts index 81c81ce22f..5dfc552b63 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts @@ -7,7 +7,6 @@ import { ObjectDetectionProps, ObjectDetectionType, } from '../../types/objectDetection'; -import { useMemo } from 'react'; import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; @@ -31,7 +30,7 @@ export const useObjectDetection = ({ isGenerating, downloadProgress, runForward, - instance, + runOnFrame, } = useModuleFactory({ factory: (config, onProgress) => ObjectDetectionModule.fromModelName(config, onProgress), @@ -43,8 +42,6 @@ export const useObjectDetection = ({ const forward = (input: string | PixelData, detectionThreshold?: number) => runForward((inst) => inst.forward(input, detectionThreshold)); - const runOnFrame = useMemo(() => instance?.runOnFrame ?? 
null, [instance]); - return { error, isReady, diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts index 622aa1a541..ae6ebed938 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts @@ -41,7 +41,7 @@ export const useSemanticSegmentation = < isGenerating, downloadProgress, runForward, - instance, + runOnFrame, } = useModuleFactory({ factory: (config, onProgress) => SemanticSegmentationModule.fromModelName(config, onProgress), @@ -59,8 +59,6 @@ export const useSemanticSegmentation = < inst.forward(imageSource, classesOfInterest, resizeToInput) ); - const runOnFrame = instance?.runOnFrame ?? null; - return { error, isReady, diff --git a/packages/react-native-executorch/src/hooks/useModuleFactory.ts b/packages/react-native-executorch/src/hooks/useModuleFactory.ts index 3d7f474052..e17de50f4f 100644 --- a/packages/react-native-executorch/src/hooks/useModuleFactory.ts +++ b/packages/react-native-executorch/src/hooks/useModuleFactory.ts @@ -1,4 +1,4 @@ -import { useState, useEffect } from 'react'; +import { useState, useEffect, useMemo } from 'react'; import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; @@ -92,6 +92,14 @@ export function useModuleFactory({ } }; + const runOnFrame = useMemo( + () => + instance && 'runOnFrame' in instance + ? 
(instance.runOnFrame as ((...args: any[]) => any) | null) + : null, + [instance] + ); + return { error, isReady, @@ -99,5 +107,6 @@ export function useModuleFactory({ downloadProgress, runForward, instance, + runOnFrame, }; } diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts index 43fb79c645..c4cd57b889 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts @@ -74,7 +74,6 @@ export class ImageEmbeddingsModule extends VisionModule { } async forward(input: string | PixelData): Promise { - const result = await super.forward(input); - return new Float32Array(result as unknown as ArrayBuffer); + return super.forward(input); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts index a6e7983768..e486c03e0e 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts @@ -3,18 +3,6 @@ import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError } from '../../errors/errorUtils'; import { Frame, PixelData, ScalarType } from '../../types/common'; -/** - * Base class for computer vision models that support multiple input types. - * - * VisionModule extends BaseModule with: - * - Unified `forward()` API accepting string paths or raw pixel data - * - `runOnFrame` getter for real-time VisionCamera frame processing - * - Shared frame processor creation logic - * - * Subclasses should only implement model-specific loading logic. 
- * - * @category Typescript API - */ export function isPixelData(input: unknown): input is PixelData { return ( typeof input === 'object' && @@ -29,6 +17,18 @@ export function isPixelData(input: unknown): input is PixelData { ); } +/** + * Base class for computer vision models that support multiple input types. + * + * VisionModule extends BaseModule with: + * - Unified `forward()` API accepting string paths or raw pixel data + * - `runOnFrame` getter for real-time VisionCamera frame processing + * - Shared frame processor creation logic + * + * Subclasses implement model-specific loading logic and may override `forward` for typed signatures. + * + * @category Typescript API + */ export abstract class VisionModule extends BaseModule { /** * Synchronous worklet function for real-time VisionCamera frame processing. From 6bc907c2b02890f8ab3324448845fcd4f9b4f043 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 15:51:14 +0100 Subject: [PATCH 44/71] feat: style transfer returns raw pixels or uri --- .../app/style_transfer/index.tsx | 73 ++++--------------- .../host_objects/JsiConversions.h | 9 +++ .../models/style_transfer/StyleTransfer.cpp | 69 +++++++++--------- .../models/style_transfer/StyleTransfer.h | 17 ++--- .../models/style_transfer/Types.h | 4 + .../hooks/computer_vision/useStyleTransfer.ts | 13 +++- .../computer_vision/StyleTransferModule.ts | 2 +- .../src/types/styleTransfer.ts | 8 +- 8 files changed, 83 insertions(+), 112 deletions(-) diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index db238d671b..46ae3e814a 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -5,14 +5,6 @@ import { useStyleTransfer, STYLE_TRANSFER_CANDY_QUANTIZED, } from 'react-native-executorch'; -import { - Canvas, - Image as SkiaImage, - Skia, - AlphaType, - ColorType, - SkImage, -} from '@shopify/react-native-skia'; import { 
View, StyleSheet, Image } from 'react-native'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; @@ -26,36 +18,22 @@ export default function StyleTransferScreen() { }, [model.isGenerating, setGlobalGenerating]); const [imageUri, setImageUri] = useState(''); - const [styledImage, setStyledImage] = useState(null); - const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); + const [styledUri, setStyledUri] = useState(''); const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); const uri = image?.uri; if (typeof uri === 'string') { setImageUri(uri); - setStyledImage(null); + setStyledUri(''); } }; const runForward = async () => { if (imageUri) { try { - const output = await model.forward(imageUri); - const height = output.sizes[0]; - const width = output.sizes[1]; - const skData = Skia.Data.fromBytes(output.dataPtr); - const img = Skia.Image.MakeImage( - { - width, - height, - alphaType: AlphaType.Opaque, - colorType: ColorType.RGBA_8888, - }, - skData, - width * 4 - ); - setStyledImage(img); + const uri = await model.forward(imageUri, 'url'); + setStyledUri(uri); } catch (e) { console.error(e); } @@ -74,38 +52,17 @@ export default function StyleTransferScreen() { return ( - {styledImage ? 
( - - setCanvasSize({ - width: e.nativeEvent.layout.width, - height: e.nativeEvent.layout.height, - }) - } - > - - - - - ) : ( - - )} + } + /> #include #include +#include #include #include @@ -609,4 +610,12 @@ inline jsi::Value getJsiValue( return dict; } +inline jsi::Value +getJsiValue(const models::style_transfer::StyleTransferResult &result, + jsi::Runtime &runtime) { + return std::visit( + [&runtime](const auto &value) { return getJsiValue(value, runtime); }, + result); +} + } // namespace rnexecutorch::jsi_conversion diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 96b0cdcac5..e51f952b91 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -32,32 +32,9 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, } } -PixelDataResult StyleTransfer::postprocess(const Tensor &tensor, - cv::Size outputSize) { - // Convert tensor output (at model input size) to CV_8UC3 BGR mat - cv::Mat mat = image_processing::getMatrixFromTensor(modelInputSize(), tensor); - - // Resize only if requested output differs from model output size - if (mat.size() != outputSize) { - cv::resize(mat, mat, outputSize); - } - - // Convert BGR -> RGBA so JS can pass the buffer directly to Skia - cv::Mat rgba; - cv::cvtColor(mat, rgba, cv::COLOR_BGR2RGBA); - - std::size_t dataSize = - static_cast(outputSize.width) * outputSize.height * 4; - auto pixelBuffer = std::make_shared(rgba.data, dataSize); - log(LOG_LEVEL::Debug, - "[StyleTransfer] postprocess: RGBA buffer size:", dataSize, - "w:", outputSize.width, "h:", outputSize.height); - - return PixelDataResult{pixelBuffer, outputSize.width, outputSize.height}; -} - -PixelDataResult StyleTransfer::runInference(cv::Mat image, - cv::Size 
originalSize) { +// Runs inference and returns the styled BGR cv::Mat resized to outputSize. +// Acquires inference_mutex_ for the duration. +cv::Mat StyleTransfer::runInference(cv::Mat image, cv::Size outputSize) { std::scoped_lock lock(inference_mutex_); cv::Mat preprocessed = preprocess(image); @@ -72,17 +49,37 @@ PixelDataResult StyleTransfer::runInference(cv::Mat image, "Ensure the model input is correct."); } - return postprocess(forwardResult->at(0).toTensor(), originalSize); + cv::Mat mat = image_processing::getMatrixFromTensor( + modelInputSize(), forwardResult->at(0).toTensor()); + if (mat.size() != outputSize) { + cv::resize(mat, mat, outputSize); + } + return mat; } -PixelDataResult StyleTransfer::generateFromString(std::string imageSource) { +PixelDataResult toPixelDataResult(const cv::Mat &bgrMat) { + cv::Size size = bgrMat.size(); + // Convert BGR -> RGBA so JS can pass the buffer directly to Skia + cv::Mat rgba; + cv::cvtColor(bgrMat, rgba, cv::COLOR_BGR2RGBA); + std::size_t dataSize = static_cast(size.width) * size.height * 4; + auto pixelBuffer = std::make_shared(rgba.data, dataSize); + return PixelDataResult{pixelBuffer, size.width, size.height}; +} + +StyleTransferResult StyleTransfer::generateFromString(std::string imageSource, + bool saveToFile) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Size originalSize = imageBGR.size(); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); - return runInference(imageRGB, originalSize); + cv::Mat result = runInference(imageRGB, originalSize); + if (saveToFile) { + return image_processing::saveToTempFile(result); + } + return toPixelDataResult(result); } PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, @@ -90,16 +87,20 @@ PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, // extractFromFrame rotates landscape frames 90° CW automatically. 
cv::Mat frame = extractFromFrame(runtime, frameData); - // For real-time frame processing, output at modelImageSize to avoid + // For real-time frame processing, output at modelInputSize to avoid // allocating large buffers (e.g. 1280x720x3 ~2.7MB) on every frame. - return runInference(frame, modelInputSize()); + return toPixelDataResult(runInference(frame, modelInputSize())); } -PixelDataResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData) { +StyleTransferResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData, + bool saveToFile) { cv::Mat image = extractFromPixels(pixelData); - cv::Size originalSize = image.size(); - return runInference(image, originalSize); + cv::Mat result = runInference(image, image.size()); + if (saveToFile) { + return image_processing::saveToTempFile(result); + } + return toPixelDataResult(result); } } // namespace rnexecutorch::models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index d018e66e0a..c15095bf5b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -16,30 +16,23 @@ namespace rnexecutorch { namespace models::style_transfer { using namespace facebook; -using executorch::aten::Tensor; -using executorch::extension::TensorPtr; class StyleTransfer : public VisionModel { public: StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] PixelDataResult - generateFromString(std::string imageSource); + [[nodiscard("Registered non-void function")]] StyleTransferResult + generateFromString(std::string imageSource, bool saveToFile); [[nodiscard("Registered non-void function")]] PixelDataResult generateFromFrame(jsi::Runtime &runtime, const 
jsi::Value &frameData); - [[nodiscard("Registered non-void function")]] PixelDataResult - generateFromPixels(JSTensorViewIn pixelData); + [[nodiscard("Registered non-void function")]] StyleTransferResult + generateFromPixels(JSTensorViewIn pixelData, bool saveToFile); private: - // outputSize: size to resize the styled output to before returning. - // Pass modelImageSize for real-time frame processing (avoids large allocs). - // Pass the source image size for generateFromString/generateFromPixels. - PixelDataResult runInference(cv::Mat image, cv::Size outputSize); - - PixelDataResult postprocess(const Tensor &tensor, cv::Size outputSize); + cv::Mat runInference(cv::Mat image, cv::Size outputSize); }; } // namespace models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h index f677183a64..57e69eb730 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h @@ -2,6 +2,8 @@ #include #include +#include +#include namespace rnexecutorch::models::style_transfer { @@ -11,4 +13,6 @@ struct PixelDataResult { int height; }; +using StyleTransferResult = std::variant; + } // namespace rnexecutorch::models::style_transfer diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts index bfa42eee71..b5903d10dd 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts @@ -1,4 +1,5 @@ import { StyleTransferModule } from '../../modules/computer_vision/StyleTransferModule'; +import { PixelData } from '../../types/common'; import { StyleTransferProps, StyleTransferType, @@ -9,7 +10,7 @@ 
import { useModuleFactory } from '../useModuleFactory'; * React hook for managing a Style Transfer model instance. * * @category Hooks - * @param StyleTransferProps - Configuration object containing `model` source and optional `preventLoad` flag. + * @param props - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use Style Transfer model. */ export const useStyleTransfer = ({ @@ -25,8 +26,14 @@ export const useStyleTransfer = ({ preventLoad, }); - const forward = (imageSource: string) => + const forward = (imageSource: string | PixelData) => runForward((inst) => inst.forward(imageSource)); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + } as StyleTransferType; }; diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 4f5e82c4ec..b06a2f9b09 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -1,8 +1,8 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { StyleTransferModelName } from '../../types/styleTransfer'; import { ResourceSource, PixelData } from '../../types/common'; -import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { Logger } from '../../common/Logger'; import { VisionModule } from './VisionModule'; diff --git a/packages/react-native-executorch/src/types/styleTransfer.ts b/packages/react-native-executorch/src/types/styleTransfer.ts index 6d48bfbab3..f14412e4a8 100644 --- a/packages/react-native-executorch/src/types/styleTransfer.ts +++ 
b/packages/react-native-executorch/src/types/styleTransfer.ts
@@ -67,10 +67,14 @@ export interface StyleTransferType {
    * **Note**: For VisionCamera frame processing, use `runOnFrame` instead.
    *
    * @param input - Image source (string or PixelData object)
-   * @returns A Promise that resolves to `PixelData` containing the stylized image as raw RGB pixel data.
+   * @param output - Output format: `'pixelData'` (default) returns raw RGBA pixel data; `'url'` saves the result to a temp file and returns its `file://` path.
+   * @returns A Promise resolving to `PixelData` when `output` is `'pixelData'` (default), or a `file://` URL string when `output` is `'url'`.
    * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image.
    */
-  forward: (input: string | PixelData) => Promise;
+  forward(
+    input: string | PixelData,
+    output?: O
+  ): Promise;
 
   /**
    * Synchronous worklet function for real-time VisionCamera frame processing.

From 61c493c737a37ec2317aa9fda13cf28681bf63ef Mon Sep 17 00:00:00 2001
From: Norbert Klockiewicz
Date: Thu, 12 Mar 2026 16:21:50 +0100
Subject: [PATCH 45/71] refactor: things caught in review

---
 .../common/rnexecutorch/host_objects/ModelHostObject.h | 2 +-
 .../common/rnexecutorch/models/VisionModel.h           | 2 +-
 .../common/rnexecutorch/models/ocr/OCR.cpp             | 1 +
 .../rnexecutorch/models/vertical_ocr/VerticalOCR.cpp   | 1 +
 .../src/hooks/useModuleFactory.ts                      | 4 +++-
 .../src/modules/computer_vision/StyleTransferModule.ts | 9 +++++++--
 6 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
index 7ef7953c85..e4361273d5 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -480,4 +480,4 @@ template class ModelHostObject : public 
JsiHostObject { std::shared_ptr callInvoker; }; -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 766e6ff968..772ed40e61 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -33,7 +33,7 @@ namespace models { * std::unordered_map * generateFromFrame(jsi::Runtime& runtime, const jsi::Value& frameValue) { * auto frameObject = frameValue.asObject(runtime); - * cv::Mat frame = utils::extractFrame(runtime, frameObject); + * cv::Mat frame = extractFromFrame(runtime, frameObject); * return runInference(frame); * } * }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp index 8de712edd4..3c64ba115f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -79,6 +79,7 @@ std::size_t OCR::getMemoryLowerBound() const noexcept { } void OCR::unload() noexcept { + std::scoped_lock lock(inference_mutex_); detector.unload(); recognitionHandler.unload(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp index 8f6cbe5072..fef78d7953 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -209,6 +209,7 @@ types::OCRDetection VerticalOCR::_processSingleTextBox( } void VerticalOCR::unload() noexcept { + std::scoped_lock lock(inference_mutex_); detector.unload(); 
recognizer.unload(); } diff --git a/packages/react-native-executorch/src/hooks/useModuleFactory.ts b/packages/react-native-executorch/src/hooks/useModuleFactory.ts index e17de50f4f..bb3140518d 100644 --- a/packages/react-native-executorch/src/hooks/useModuleFactory.ts +++ b/packages/react-native-executorch/src/hooks/useModuleFactory.ts @@ -4,6 +4,8 @@ import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; type Deletable = { delete: () => void }; +type RunOnFrame = M extends { runOnFrame: infer R } ? R : never; + /** * Shared hook for modules that are instantiated via an async static factory * (i.e. `SomeModule.fromModelName(config, onProgress)`). @@ -95,7 +97,7 @@ export function useModuleFactory({ const runOnFrame = useMemo( () => instance && 'runOnFrame' in instance - ? (instance.runOnFrame as ((...args: any[]) => any) | null) + ? (instance.runOnFrame as RunOnFrame | null) : null, [instance] ); diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index b06a2f9b09..24fb2ddfd6 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -71,7 +71,12 @@ export class StyleTransferModule extends VisionModule { ); } - async forward(input: string | PixelData): Promise { - return super.forward(input); + async forward( + input: string | PixelData, + output?: O + ): Promise { + return super.forward(input, output === 'url') as Promise< + O extends 'url' ? 
string : PixelData + >; } } From a7cb75807a3e8ed6fb5da179b62b57e5e6f5eadf Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 16:37:39 +0100 Subject: [PATCH 46/71] fix: small adjustments to style transfer --- .eslintrc.js | 2 +- apps/computer-vision/app/vision_camera/index.tsx | 4 +++- .../src/modules/computer_vision/StyleTransferModule.ts | 2 +- .../src/modules/computer_vision/VisionModule.ts | 1 - 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.eslintrc.js b/.eslintrc.js index a9613d48ed..26f3b92475 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -32,7 +32,7 @@ module.exports = { customWordListFile: path.resolve(__dirname, '.cspell-wordlist.txt'), }, ], - 'camelcase': 'error', + 'camelcase': ['error', { properties: 'never' }], }, plugins: ['prettier', 'markdown'], overrides: [ diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 7d005c5c59..2277565fbb 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -15,6 +15,7 @@ import { View, } from 'react-native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; +import { useIsFocused } from '@react-navigation/native'; import { Camera, Frame, @@ -235,6 +236,7 @@ export default function VisionCameraScreen() { const [fps, setFps] = useState(0); const [frameMs, setFrameMs] = useState(0); const lastFrameTimeRef = useRef(Date.now()); + const isFocused = useIsFocused(); const cameraPermission = useCameraPermission(); const devices = useCameraDevices(); const device = @@ -461,7 +463,7 @@ export default function VisionCameraScreen() { style={StyleSheet.absoluteFill} device={device} outputs={[frameOutput]} - isActive={true} + isActive={isFocused} format={format} orientationSource="interface" /> diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts 
b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 24fb2ddfd6..8dbc2ec8e6 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -11,7 +11,7 @@ import { VisionModule } from './VisionModule'; * * @category Typescript API */ -export class StyleTransferModule extends VisionModule { +export class StyleTransferModule extends VisionModule { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts index e486c03e0e..31d3baba7b 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts @@ -127,7 +127,6 @@ export abstract class VisionModule extends BaseModule { RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling forward().' 
); - // Type detection and routing if (typeof input === 'string') { return await this.nativeModule.generateFromString(input, ...args); From c65d610f7062deccb43d3a2489996adf7909eae0 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 17:57:14 +0100 Subject: [PATCH 47/71] tests: vision models --- .../common/rnexecutorch/tests/CMakeLists.txt | 17 ++ .../integration/SemanticSegmentationTest.cpp | 132 ++++++++++- .../tests/integration/StyleTransferTest.cpp | 221 ++++++++++++++++-- .../tests/integration/VisionModelTest.cpp | 121 ++++++++++ .../tests/unit/FrameProcessorTest.cpp | 101 ++++++++ 5 files changed, 575 insertions(+), 17 deletions(-) create mode 100644 packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp create mode 100644 packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 6705b687cb..bd359dcb8f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -157,8 +157,25 @@ add_rn_test(ImageProcessingTest unit/ImageProcessingTest.cpp LIBS opencv_deps ) +add_rn_test(FrameProcessorTests unit/FrameProcessorTest.cpp + SOURCES + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${IMAGE_UTILS_SOURCES} + LIBS opencv_deps android +) + add_rn_test(BaseModelTests integration/BaseModelTest.cpp) +add_rn_test(VisionModelTests integration/VisionModelTest.cpp + SOURCES + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${IMAGE_UTILS_SOURCES} + LIBS opencv_deps android +) + add_rn_test(ClassificationTests integration/ClassificationTest.cpp SOURCES 
${RNEXECUTORCH_DIR}/models/classification/Classification.cpp diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp index 76b213ca8f..96069769cb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp @@ -1,12 +1,13 @@ +#include +#include #include #include +#include #include #include #include #include -#include - using namespace rnexecutorch; using namespace rnexecutorch::models::semantic_segmentation; using executorch::extension::make_tensor_ptr; @@ -15,6 +16,15 @@ using executorch::runtime::EValue; constexpr auto kValidSemanticSegmentationModelPath = "deeplabV3_xnnpack_fp32.pte"; +constexpr auto kValidTestImagePath = + "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; + +static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, + int32_t w) { + buf.assign(static_cast(h * w * 3), 128); + return JSTensorViewIn{ + buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; +} // Test fixture for tests that need dummy input data class SemanticSegmentationForwardTest : public ::testing::Test { @@ -94,6 +104,121 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) { EXPECT_THROW((void)model->forward(EValue(inputTensor)), RnExecutorchError); } +// ============================================================================ +// generateFromString tests +// ============================================================================ +TEST(SemanticSegmentationGenerateTests, InvalidImagePathThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + EXPECT_THROW( + (void)model.generateFromString("nonexistent_image.jpg", {}, true), + RnExecutorchError); +} + +TEST(SemanticSegmentationGenerateTests, 
EmptyImagePathThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + EXPECT_THROW((void)model.generateFromString("", {}, true), RnExecutorchError); +} + +TEST(SemanticSegmentationGenerateTests, MalformedURIThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + EXPECT_THROW( + (void)model.generateFromString("not_a_valid_uri://bad", {}, true), + RnExecutorchError); +} + +TEST(SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, {}, true); + EXPECT_NE(result.argmax, nullptr); + EXPECT_NE(result.classBuffers, nullptr); +} + +TEST(SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, {}, true); + ASSERT_NE(result.classBuffers, nullptr); + EXPECT_EQ(result.classBuffers->size(), 21u); +} + +TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::set> filter = {"PERSON", "CAT"}; + auto result = model.generateFromString(kValidTestImagePath, filter, true); + ASSERT_NE(result.classBuffers, nullptr); + // Only the requested classes should appear in classBuffers + for (const auto &[label, _] : *result.classBuffers) { + EXPECT_TRUE(filter.count(label) > 0); + } +} + +TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, {}, false); + EXPECT_NE(result.argmax, nullptr); +} + +// ============================================================================ +// generateFromPixels tests +// ============================================================================ 
+TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, {}, true); + EXPECT_NE(result.argmax, nullptr); + EXPECT_NE(result.classBuffers, nullptr); +} + +TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, {}, true); + ASSERT_NE(result.classBuffers, nullptr); + EXPECT_EQ(result.classBuffers->size(), 21u); +} + +TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + std::set> filter = {"PERSON"}; + auto result = model.generateFromPixels(view, filter, true); + ASSERT_NE(result.classBuffers, nullptr); + for (const auto &[label, _] : *result.classBuffers) { + EXPECT_EQ(label, "PERSON"); + } +} + +TEST(SemanticSegmentationPixelTests, WrongSizesThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(view, {}, true), + RnExecutorchError); +} + +TEST(SemanticSegmentationPixelTests, WrongChannelsThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf(64, 0); + JSTensorViewIn view{ + buf.data(), {4, 4, 4}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(view, {}, true), + RnExecutorchError); +} + +TEST(SemanticSegmentationPixelTests, WrongScalarTypeThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf(48, 0); + JSTensorViewIn 
view{ + buf.data(), {4, 4, 3}, executorch::aten::ScalarType::Float}; + EXPECT_THROW((void)model.generateFromPixels(view, {}, true), + RnExecutorchError); +} + +// ============================================================================ +// Inherited BaseModel tests +// ============================================================================ TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) { SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); auto shape = model.getInputShape("forward", 0); @@ -125,6 +250,9 @@ TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) { EXPECT_EQ(shape[2], shape[3]); // Height == Width for DeepLabV3 } +// ============================================================================ +// Constants tests +// ============================================================================ TEST(SemanticSegmentationConstantsTests, ClassLabelsHas21Entries) { EXPECT_EQ(constants::kDeeplabV3Resnet50Labels.size(), 21u); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index c92299cb15..758ef63368 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -1,7 +1,12 @@ #include "BaseModelTests.h" +#include +#include #include #include +#include #include +#include +#include using namespace rnexecutorch; using namespace rnexecutorch::models::style_transfer; @@ -12,6 +17,13 @@ constexpr auto kValidStyleTransferModelPath = constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, + int32_t w) { + buf.assign(static_cast(h * w * 3), 128); + return JSTensorViewIn{ + buf.data(), {h, w, 3}, 
executorch::aten::ScalarType::Byte}; +} + // ============================================================================ // Common tests via typed test suite // ============================================================================ @@ -28,7 +40,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generateFromString(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath, false); } }; } // namespace model_tests @@ -38,42 +50,221 @@ INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest, StyleTransferTypes); // ============================================================================ -// Model-specific tests +// generateFromString tests // ============================================================================ TEST(StyleTransferGenerateTests, InvalidImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, EmptyImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, MalformedURIThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", false), RnExecutorchError); } -TEST(StyleTransferGenerateTests, ValidImageReturnsNonNull) { +TEST(StyleTransferGenerateTests, ValidImageReturnsFilePath) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generateFromString(kValidTestImagePath); - EXPECT_NE(result.dataPtr, nullptr); - EXPECT_GT(result.width, 0); - EXPECT_GT(result.height, 0); 
+ auto result = model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_NE(pr.dataPtr, nullptr); + EXPECT_GT(pr.width, 0); + EXPECT_GT(pr.height, 0); } TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath)); - auto result1 = model.generateFromString(kValidTestImagePath); - auto result2 = model.generateFromString(kValidTestImagePath); - EXPECT_NE(result1.dataPtr, nullptr); - EXPECT_NE(result2.dataPtr, nullptr); + EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, false)); + auto result1 = model.generateFromString(kValidTestImagePath, false); + auto result2 = model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result1)); + ASSERT_TRUE(std::holds_alternative(result2)); + EXPECT_NE(std::get(result1).dataPtr, nullptr); + EXPECT_NE(std::get(result2).dataPtr, nullptr); +} + +// ============================================================================ +// generateFromString saveToFile tests +// ============================================================================ +TEST(StyleTransferSaveToFileTests, SaveToFileFalseReturnsPixelDataVariant) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, false); + EXPECT_TRUE(std::holds_alternative(result)); +} + +TEST(StyleTransferSaveToFileTests, SaveToFileFalsePixelDataIsNonNull) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_NE(std::get(result).dataPtr, nullptr); +} + +TEST(StyleTransferSaveToFileTests, SaveToFileFalseHasPositiveDimensions) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = 
model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_GT(pr.width, 0); + EXPECT_GT(pr.height, 0); +} + +TEST(StyleTransferSaveToFileTests, SaveToFileTrueReturnsStringVariant) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, true); + EXPECT_TRUE(std::holds_alternative(result)); +} + +TEST(StyleTransferSaveToFileTests, SaveToFileTrueStringIsNonEmpty) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, true); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_FALSE(std::get(result).empty()); +} + +TEST(StyleTransferSaveToFileTests, SaveToFileTrueStringHasFileScheme) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, true); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_TRUE(std::get(result).starts_with("file://")); +} + +// ============================================================================ +// generateFromPixels tests +// ============================================================================ +TEST(StyleTransferPixelTests, ValidPixelsSaveToFileFalseReturnsPixelData) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, false); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_NE(std::get(result).dataPtr, nullptr); +} + +TEST(StyleTransferPixelTests, ValidPixelsSaveToFileFalseHasPositiveDimensions) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_GT(pr.width, 0); + EXPECT_GT(pr.height, 0); +} + 
+TEST(StyleTransferPixelTests, ValidPixelsSaveToFileTrueReturnsString) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, true); + EXPECT_TRUE(std::holds_alternative(result)); +} + +TEST(StyleTransferPixelTests, ValidPixelsSaveToFileTrueHasFileScheme) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, true); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_TRUE(std::get(result).starts_with("file://")); +} + +TEST(StyleTransferPixelTests, OutputDimensionsMatchInputSize) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_EQ(pr.width, 64); + EXPECT_EQ(pr.height, 64); +} + +// ============================================================================ +// generateFromPixels invalid input tests +// ============================================================================ +TEST(StyleTransferPixelInvalidTests, WrongSizesLengthThrows) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); +} + +TEST(StyleTransferPixelInvalidTests, FourChannelsThrows) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf(64, 0); + JSTensorViewIn view{ + buf.data(), {4, 4, 4}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); +} + +TEST(StyleTransferPixelInvalidTests, OneChannelThrows) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + 
std::vector buf(16, 0); + JSTensorViewIn view{ + buf.data(), {4, 4, 1}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); +} + +TEST(StyleTransferPixelInvalidTests, WrongScalarTypeThrows) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf(48, 0); + JSTensorViewIn view{ + buf.data(), {4, 4, 3}, executorch::aten::ScalarType::Float}; + EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); } +// ============================================================================ +// Thread safety tests +// ============================================================================ +TEST(StyleTransferThreadSafetyTests, TwoConcurrentGeneratesDoNotCrash) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::atomic successCount{0}; + std::atomic exceptionCount{0}; + + auto task = [&]() { + try { + (void)model.generateFromString(kValidTestImagePath, false); + successCount++; + } catch (const RnExecutorchError &) { + exceptionCount++; + } + }; + + std::thread a(task); + std::thread b(task); + a.join(); + b.join(); + + EXPECT_EQ(successCount + exceptionCount, 2); +} + +TEST(StyleTransferThreadSafetyTests, + GenerateAndUnloadConcurrentlyDoesNotCrash) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + + std::thread a([&]() { + try { + (void)model.generateFromString(kValidTestImagePath, false); + } catch (const RnExecutorchError &) { + } + }); + std::thread b([&]() { model.unload(); }); + + a.join(); + b.join(); + // If we reach here without crashing, the mutex serialized correctly. 
+} + +// ============================================================================ +// Inherited BaseModel tests +// ============================================================================ TEST(StyleTransferInheritedTests, GetInputShapeWorks) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); auto shape = model.getInputShape("forward", 0); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp new file mode 100644 index 0000000000..6736454d6f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include +#include +#include + +using namespace rnexecutorch; +using namespace rnexecutorch::models; +using executorch::aten::ScalarType; + +// ============================================================================ +// TestableVisionModel — exposes protected methods for testing +// ============================================================================ +class TestableVisionModel : public VisionModel { +public: + explicit TestableVisionModel(const std::string &path) + : VisionModel(path, nullptr) {} + + cv::Mat preprocessPublic(const cv::Mat &img) const { return preprocess(img); } + + cv::Mat extractFromPixelsPublic(const JSTensorViewIn &v) const { + return extractFromPixels(v); + } + + void setInputShape(std::vector shape) { + modelInputShape_ = std::move(shape); + } +}; + +// Reuse the style_transfer .pte as a vehicle — we never call forward(). 
+constexpr auto kModelPath = "style_transfer_candy_xnnpack_fp32.pte"; + +// ============================================================================ +// preprocess() tests +// ============================================================================ +class VisionModelPreprocessTest : public ::testing::Test { +protected: + void SetUp() override { + model = std::make_unique(kModelPath); + } + std::unique_ptr model; +}; + +TEST_F(VisionModelPreprocessTest, CorrectSizeImageReturnedAsIs) { + model->setInputShape({1, 3, 64, 64}); + cv::Mat img(64, 64, CV_8UC3, cv::Scalar(100, 150, 200)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.size(), cv::Size(64, 64)); +} + +TEST_F(VisionModelPreprocessTest, CorrectSizeDataUnchanged) { + model->setInputShape({1, 3, 8, 8}); + cv::Mat img(8, 8, CV_8UC3, cv::Scalar(255, 0, 0)); + auto result = model->preprocessPublic(img); + auto pixel = result.at(0, 0); + EXPECT_EQ(pixel[0], 255); + EXPECT_EQ(pixel[1], 0); + EXPECT_EQ(pixel[2], 0); +} + +TEST_F(VisionModelPreprocessTest, LargerImageIsResizedDown) { + model->setInputShape({1, 3, 32, 32}); + cv::Mat img(128, 128, CV_8UC3, cv::Scalar(0)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.size(), cv::Size(32, 32)); +} + +TEST_F(VisionModelPreprocessTest, SmallerImageIsResizedUp) { + model->setInputShape({1, 3, 128, 128}); + cv::Mat img(32, 32, CV_8UC3, cv::Scalar(0)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.size(), cv::Size(128, 128)); +} + +TEST_F(VisionModelPreprocessTest, NonSquareTargetSize) { + model->setInputShape({1, 3, 48, 96}); + cv::Mat img(200, 100, CV_8UC3, cv::Scalar(0)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.rows, 48); + EXPECT_EQ(result.cols, 96); +} + +// ============================================================================ +// extractFromPixels() tests +// ============================================================================ +class VisionModelExtractFromPixelsTest 
: public ::testing::Test { +protected: + void SetUp() override { + model = std::make_unique(kModelPath); + } + std::unique_ptr model; +}; + +TEST_F(VisionModelExtractFromPixelsTest, ValidInputReturnsCorrectDimensions) { + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{buf.data(), {64, 64, 3}, ScalarType::Byte}; + auto mat = model->extractFromPixelsPublic(view); + EXPECT_EQ(mat.rows, 64); + EXPECT_EQ(mat.cols, 64); + EXPECT_EQ(mat.channels(), 3); +} + +TEST_F(VisionModelExtractFromPixelsTest, TwoDimensionalSizesThrows) { + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4}, ScalarType::Byte}; + EXPECT_THROW(model->extractFromPixelsPublic(view), RnExecutorchError); +} + +TEST_F(VisionModelExtractFromPixelsTest, WrongChannelsThrows) { + std::vector buf(64, 0); + JSTensorViewIn view{buf.data(), {4, 4, 4}, ScalarType::Byte}; + EXPECT_THROW(model->extractFromPixelsPublic(view), RnExecutorchError); +} + +TEST_F(VisionModelExtractFromPixelsTest, WrongScalarTypeThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {4, 4, 3}, ScalarType::Float}; + EXPECT_THROW(model->extractFromPixelsPublic(view), RnExecutorchError); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp new file mode 100644 index 0000000000..6465db2a75 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp @@ -0,0 +1,101 @@ +#include +#include +#include +#include +#include +#include + +using namespace rnexecutorch; +using namespace rnexecutorch::utils; +using executorch::aten::ScalarType; + +static JSTensorViewIn makeValidView(std::vector &buf, int32_t h, + int32_t w) { + buf.assign(static_cast(h * w * 3), 128); + return JSTensorViewIn{buf.data(), {h, w, 3}, ScalarType::Byte}; +} + +// ============================================================================ +// Valid input +// 
============================================================================ +TEST(PixelsToMatValidInput, ProducesCorrectRows) { + std::vector buf; + auto view = makeValidView(buf, 48, 64); + EXPECT_EQ(pixelsToMat(view).rows, 48); +} + +TEST(PixelsToMatValidInput, ProducesCorrectCols) { + std::vector buf; + auto view = makeValidView(buf, 48, 64); + EXPECT_EQ(pixelsToMat(view).cols, 64); +} + +TEST(PixelsToMatValidInput, ProducesThreeChannelMat) { + std::vector buf; + auto view = makeValidView(buf, 4, 4); + EXPECT_EQ(pixelsToMat(view).channels(), 3); +} + +TEST(PixelsToMatValidInput, MatTypeIsCV_8UC3) { + std::vector buf; + auto view = makeValidView(buf, 4, 4); + EXPECT_EQ(pixelsToMat(view).type(), CV_8UC3); +} + +TEST(PixelsToMatValidInput, MatWrapsOriginalData) { + std::vector buf; + auto view = makeValidView(buf, 4, 4); + auto mat = pixelsToMat(view); + EXPECT_EQ(mat.data, buf.data()); +} + +// ============================================================================ +// Invalid sizes dimensionality +// ============================================================================ +TEST(PixelsToMatInvalidSizes, TwoDimensionalThrows) { + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidSizes, FourDimensionalThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {1, 4, 4, 3}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidSizes, EmptySizesThrows) { + std::vector buf(4, 0); + JSTensorViewIn view{buf.data(), {}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +// ============================================================================ +// Invalid channel count +// ============================================================================ +TEST(PixelsToMatInvalidChannels, OneChannelThrows) { + std::vector buf(16, 0); + JSTensorViewIn 
view{buf.data(), {4, 4, 1}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidChannels, FourChannelsThrows) { + std::vector buf(64, 0); + JSTensorViewIn view{buf.data(), {4, 4, 4}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +// ============================================================================ +// Invalid scalar type +// ============================================================================ +TEST(PixelsToMatInvalidScalarType, FloatScalarTypeThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {4, 4, 3}, ScalarType::Float}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidScalarType, IntScalarTypeThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {4, 4, 3}, ScalarType::Int}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} From c18ebad89ae2ddfba4134b94194afbfa03067038 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 18:10:46 +0100 Subject: [PATCH 48/71] docs: update documentation --- .../02-computer-vision/useClassification.md | 8 +- .../02-computer-vision/useImageEmbeddings.md | 8 +- .../03-hooks/02-computer-vision/useOCR.md | 8 +- .../02-computer-vision/useObjectDetection.md | 8 +- .../useSemanticSegmentation.md | 8 +- .../02-computer-vision/useStyleTransfer.md | 55 ++++++-- .../visioncamera-integration.md | 125 ++++++++++++++++++ .../ClassificationModule.md | 4 +- .../ImageEmbeddingsModule.md | 4 +- .../02-computer-vision/OCRModule.md | 4 +- .../ObjectDetectionModule.md | 9 +- .../SemanticSegmentationModule.md | 4 +- .../02-computer-vision/StyleTransferModule.md | 9 +- 13 files changed, 230 insertions(+), 24 deletions(-) create mode 100644 docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md diff --git a/docs/docs/03-hooks/02-computer-vision/useClassification.md b/docs/docs/03-hooks/02-computer-vision/useClassification.md index 
e9c2eebfab..d627d95364 100644 --- a/docs/docs/03-hooks/02-computer-vision/useClassification.md +++ b/docs/docs/03-hooks/02-computer-vision/useClassification.md @@ -52,12 +52,18 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an object containing categories with their probabilities. +To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object containing categories with their probabilities. :::info Images from external sources are stored in your application's temporary directory. ::: +## VisionCamera integration + +For real-time classification on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `{ [category: string]: number }`. + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). + ## Example ```typescript diff --git a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md index caef87cdf2..8ce07e0ba4 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md @@ -63,7 +63,13 @@ You need more details? 
Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ImageEmbeddingsType.md#forward) method. It accepts one argument which is a URI/URL to an image you want to encode or base64 (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding. +To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageEmbeddingsType.md#forward) method. It accepts one argument — the image to embed. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to a `Float32Array` representing the embedding. + +## VisionCamera integration + +For real-time embedding on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `Float32Array`. + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Example diff --git a/docs/docs/03-hooks/02-computer-vision/useOCR.md b/docs/docs/03-hooks/02-computer-vision/useOCR.md index 76e7ad6956..f622414651 100644 --- a/docs/docs/03-hooks/02-computer-vision/useOCR.md +++ b/docs/docs/03-hooks/02-computer-vision/useOCR.md @@ -50,7 +50,13 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the text recognized within the box, and the confidence score. 
For more information, please refer to the reference or type definitions. +To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +## VisionCamera integration + +For real-time text recognition on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `OCRDetection[]`. + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Detection object diff --git a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md index 69ac3c79ac..a910645d32 100644 --- a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md +++ b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md @@ -66,7 +66,7 @@ You need more details? Check the following resources: To run the model, use the [`forward`](../../06-api-reference/interfaces/ObjectDetectionType.md#forward) method. It accepts two arguments: -- `imageSource` (required) - The image to process. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). - `detectionThreshold` (optional) - A number between 0 and 1 representing the minimum confidence score for a detection to be included in the results. Defaults to `0.7`. 
`forward` returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing: @@ -107,6 +107,12 @@ function App() { } ``` +## VisionCamera integration + +For real-time object detection on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `Detection[]`. + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). + ## Supported models | Model | Number of classes | Class list | diff --git a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md index 3a6fa46553..1ca0987361 100644 --- a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md @@ -66,7 +66,7 @@ You need more details? Check the following resources: To run the model, use the [`forward`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) method. It accepts three arguments: -- [`imageSource`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`input`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). - [`classesOfInterest`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter. 
- [`resizeToInput`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). @@ -115,6 +115,12 @@ function App() { } ``` +## VisionCamera integration + +For real-time segmentation on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns the same segmentation result object as `forward`. + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). + ## Supported models | Model | Number of classes | Class list | Quantized | diff --git a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md index 471bde35e9..636589d7c7 100644 --- a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md +++ b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md @@ -23,10 +23,13 @@ import { const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); -const imageUri = 'file::///Users/.../cute_cat.png'; +const imageUri = 'file:///Users/.../cute_cat.png'; try { - const generatedImageUrl = await model.forward(imageUri); + // Returns a file URI string + const uri = await model.forward(imageUri, 'url'); + // Or returns raw PixelData (default) + const pixels = await model.forward(imageUri); } catch (error) { console.error(error); } @@ -51,30 +54,56 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise which can resolve either to an error or a URL to generated image. 
+To run the model, use the [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts two arguments: + +- `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `output` (optional) — Controls the return format: + - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written. + - `'url'` — Saves the result to a temp file and returns its URI as a `string`. :::info -Images from external sources and the generated image are stored in your application's temporary directory. +When `output` is `'url'`, the generated image is stored in your application's temporary directory. ::: ## Example ```typescript +import { + useStyleTransfer, + STYLE_TRANSFER_CANDY, +} from 'react-native-executorch'; + function App() { const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); - // ... - const imageUri = 'file::///Users/.../cute_cat.png'; - - try { - const generatedImageUrl = await model.forward(imageUri); - } catch (error) { - console.error(error); - } - // ... + // Returns a file URI — easy to pass to + const runWithUrl = async (imageUri: string) => { + try { + const uri = await model.forward(imageUri, 'url'); + console.log('Styled image saved at:', uri); + } catch (error) { + console.error(error); + } + }; + + // Returns raw PixelData — useful for further processing or frame pipelines + const runWithPixelData = async (imageUri: string) => { + try { + const pixels = await model.forward(imageUri); + // pixels.dataPtr is a Uint8Array of RGB bytes + } catch (error) { + console.error(error); + } + }; } ``` +## VisionCamera integration + +For real-time style transfer on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and always returns `PixelData`. 
+ +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). + ## Supported models - [Candy](https://github.com/pytorch/examples/tree/main/fast_neural_style) diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md new file mode 100644 index 0000000000..51cca28450 --- /dev/null +++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md @@ -0,0 +1,125 @@ +--- +title: VisionCamera Integration +--- + +React Native ExecuTorch vision models support real-time frame processing via [VisionCamera](https://react-native-vision-camera.com/) using the `runOnFrame` worklet. This page explains how `runOnFrame` works and how to use it with any supported model. + +## Which models support runOnFrame? + +The following hooks expose `runOnFrame`: + +- [`useClassification`](./useClassification.md) +- [`useImageEmbeddings`](./useImageEmbeddings.md) +- [`useOCR`](./useOCR.md) +- [`useObjectDetection`](./useObjectDetection.md) +- [`useSemanticSegmentation`](./useSemanticSegmentation.md) +- [`useStyleTransfer`](./useStyleTransfer.md) + +## runOnFrame vs forward + +| | `runOnFrame` | `forward` | +| -------- | -------------------- | -------------------------- | +| Thread | JS (worklet) | Background thread | +| Input | VisionCamera `Frame` | `string` URI / `PixelData` | +| Output | Model result (sync) | `Promise` | +| Use case | Real-time camera | Single image | + +Use `runOnFrame` when you need to process every camera frame. Use `forward` for one-off image inference. + +:::warning +`runOnFrame` runs synchronously on the JS thread. Keep processing time low to avoid dropping camera frames. +::: + +## Setup + +### 1. Store runOnFrame in state + +`runOnFrame` is a worklet function. Passing it directly to `useState` would cause React to invoke it as a state-updater function. 
Use the functional form of `setState` instead: + +```tsx +import React, { useState, useEffect } from 'react'; +import { Camera, useFrameProcessor } from 'react-native-vision-camera'; +import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; + +export default function App() { + const model = useClassification({ model: EFFICIENTNET_V2_S }); + + const [runOnFrame, setRunOnFrame] = useState(null); + + useEffect(() => { + if (model.isReady) { + setRunOnFrame(() => model.runOnFrame); + } + }, [model.isReady, model.runOnFrame]); + + const frameProcessor = useFrameProcessor( + (frame) => { + 'worklet'; + if (!runOnFrame) return; + + runOnFrame(frame); + // use the returned result ... + }, + [runOnFrame] + ); + + return ; +} +``` + +## Full example (Classification) + +```tsx +import React, { useState, useEffect } from 'react'; +import { Text, StyleSheet } from 'react-native'; +import { + Camera, + useCameraDevice, + useFrameProcessor, +} from 'react-native-vision-camera'; +import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; + +export default function App() { + const device = useCameraDevice('back'); + const model = useClassification({ model: EFFICIENTNET_V2_S }); + + const [runOnFrame, setRunOnFrame] = useState(null); + const [topLabel, setTopLabel] = useState(''); + + useEffect(() => { + if (model.isReady) { + setRunOnFrame(() => model.runOnFrame); + } + }, [model.isReady, model.runOnFrame]); + + const frameProcessor = useFrameProcessor( + (frame) => { + 'worklet'; + if (!runOnFrame) return; + + const scores = runOnFrame(frame); + const top = Object.entries(scores).sort(([, a], [, b]) => b - a)[0]; + if (top) setTopLabel(top[0]); + }, + [runOnFrame] + ); + + if (!device) return null; + + return ( + <> + + {topLabel} + + ); +} + +const styles = StyleSheet.create({ + camera: { flex: 1 }, +}); +``` diff --git a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md 
b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md index 4234fa865d..a9bf6ed77d 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md @@ -47,7 +47,9 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an object containing categories with their probabilities. +To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object containing categories with their probabilities. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. ## Managing memory diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md index 7388416334..47eceef00a 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md @@ -48,4 +48,6 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -[`forward`](../../06-api-reference/classes/ImageEmbeddingsModule.md#forward) accepts one argument: image. 
The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding. +[`forward`](../../06-api-reference/classes/ImageEmbeddingsModule.md#forward) accepts one argument — the image to embed. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to a `Float32Array` representing the embedding. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. diff --git a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md index 1524859ed2..1391982173 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md @@ -41,4 +41,6 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/OCRModule.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +To run the model, use the [`forward`](../../06-api-reference/classes/OCRModule.md#forward) method. It accepts one argument — the image to recognize. 
The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. diff --git a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md index d942eded65..b56cb47713 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md @@ -40,7 +40,14 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +To run the model, use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method. It accepts two arguments: + +- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). 
+- `detectionThreshold` (optional) - A number between 0 and 1. Defaults to `0.7`. + +The method returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing the bounding box, label, and confidence score. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. ## Using a custom model diff --git a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md index bf88690bdb..4bf2129ac1 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md @@ -86,7 +86,7 @@ For more information on loading resources, take a look at [loading models](../.. To run the model, use the [`forward`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) method. It accepts three arguments: -- [`imageSource`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`input`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). - [`classesOfInterest`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless. 
- [`resizeToInput`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. @@ -113,6 +113,8 @@ result.CAT; // Float32Array result.DOG; // Float32Array ``` +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. + ## Managing memory The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/SemanticSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/SemanticSegmentationModule.md#delete) unless you create a new instance. diff --git a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md index 0e70a24796..7475398971 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md @@ -47,7 +47,14 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). 
The method returns a promise, which can resolve either to an error or a URL to generated image. +To run the model, use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method. It accepts two arguments: + +- `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `output` (optional) — Controls the return format: + - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written. + - `'url'` — Saves the result to a temp file and returns its URI as a `string`. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. ## Managing memory From 5c50292de39a4bef06d715e05a428cde753a8a7d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 12 Mar 2026 18:12:02 +0100 Subject: [PATCH 49/71] docs: update docs link --- .../03-hooks/02-computer-vision/visioncamera-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md index 51cca28450..c8462f309e 100644 --- a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md +++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md @@ -2,7 +2,7 @@ title: VisionCamera Integration --- -React Native ExecuTorch vision models support real-time frame processing via [VisionCamera](https://react-native-vision-camera.com/) using the `runOnFrame` worklet. This page explains how `runOnFrame` works and how to use it with any supported model. +React Native ExecuTorch vision models support real-time frame processing via [VisionCamera](https://react-native-vision-camera-v5-docs.vercel.app) using the `runOnFrame` worklet. 
This page explains how `runOnFrame` works and how to use it with any supported model. ## Which models support runOnFrame? From 35411bf77ae96b70ea345284eefcb40c954a6776 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 10:55:04 +0100 Subject: [PATCH 50/71] chore: tests, docs, comments etc. --- .../02-computer-vision/useVerticalOCR.md | 8 +++- .../visioncamera-integration.md | 39 +++++-------------- .../02-computer-vision/VerticalOCRModule.md | 4 +- .../models/style_transfer/StyleTransfer.cpp | 5 --- .../tests/integration/ClassificationTest.cpp | 14 +++++++ .../tests/integration/ImageEmbeddingsTest.cpp | 15 +++++++ .../tests/integration/OCRTest.cpp | 15 +++++++ .../tests/integration/ObjectDetectionTest.cpp | 34 ---------------- .../integration/SemanticSegmentationTest.cpp | 26 ------------- .../tests/integration/StyleTransferTest.cpp | 34 ---------------- .../tests/integration/VerticalOCRTest.cpp | 15 +++++++ .../computer_vision/useClassification.ts | 2 +- .../computer_vision/useImageEmbeddings.ts | 2 +- .../src/hooks/computer_vision/useOCR.ts | 2 +- .../hooks/computer_vision/useStyleTransfer.ts | 2 +- 15 files changed, 83 insertions(+), 134 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md index b9d29fc423..53d0e8b7ff 100644 --- a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md +++ b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md @@ -58,7 +58,13 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. 
Each object contains coordinates of the bounding box, the text recognized within the box, and the confidence score. For more information, please refer to the reference or type definitions. +To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +## VisionCamera integration + +For real-time text recognition on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `OCRDetection[]`. + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Detection object diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md index c8462f309e..8b5931b2e0 100644 --- a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md +++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md @@ -11,6 +11,7 @@ The following hooks expose `runOnFrame`: - [`useClassification`](./useClassification.md) - [`useImageEmbeddings`](./useImageEmbeddings.md) - [`useOCR`](./useOCR.md) +- [`useVerticalOCR`](./useVerticalOCR.md) - [`useObjectDetection`](./useObjectDetection.md) - [`useSemanticSegmentation`](./useSemanticSegmentation.md) - [`useStyleTransfer`](./useStyleTransfer.md) @@ -32,35 +33,23 @@ Use `runOnFrame` when you need to process every camera frame. Use `forward` for ## Setup -### 1. Store runOnFrame in state - -`runOnFrame` is a worklet function. Passing it directly to `useState` would cause React to invoke it as a state-updater function. 
Use the functional form of `setState` instead: +`runOnFrame` is a stable worklet function exposed directly from the hook. Pass it to `useFrameProcessor` and guard with `model.isReady` — no need to store it in state: ```tsx -import React, { useState, useEffect } from 'react'; -import { Camera, useFrameProcessor } from 'react-native-vision-camera'; +import { useFrameProcessor } from 'react-native-vision-camera'; import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; export default function App() { const model = useClassification({ model: EFFICIENTNET_V2_S }); - const [runOnFrame, setRunOnFrame] = useState(null); - - useEffect(() => { - if (model.isReady) { - setRunOnFrame(() => model.runOnFrame); - } - }, [model.isReady, model.runOnFrame]); - const frameProcessor = useFrameProcessor( (frame) => { 'worklet'; - if (!runOnFrame) return; + if (!model.isReady) return; - runOnFrame(frame); - // use the returned result ... + model.runOnFrame(frame); // use the returned result }, - [runOnFrame] + [model.isReady, model.runOnFrame] ); return ; @@ -70,7 +59,7 @@ export default function App() { ## Full example (Classification) ```tsx -import React, { useState, useEffect } from 'react'; +import React, { useState } from 'react'; import { Text, StyleSheet } from 'react-native'; import { Camera, @@ -82,26 +71,18 @@ import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; export default function App() { const device = useCameraDevice('back'); const model = useClassification({ model: EFFICIENTNET_V2_S }); - - const [runOnFrame, setRunOnFrame] = useState(null); const [topLabel, setTopLabel] = useState(''); - useEffect(() => { - if (model.isReady) { - setRunOnFrame(() => model.runOnFrame); - } - }, [model.isReady, model.runOnFrame]); - const frameProcessor = useFrameProcessor( (frame) => { 'worklet'; - if (!runOnFrame) return; + if (!model.isReady) return; - const scores = runOnFrame(frame); + const scores = model.runOnFrame(frame); const 
top = Object.entries(scores).sort(([, a], [, b]) => b - a)[0]; if (top) setTopLabel(top[0]); }, - [runOnFrame] + [model.isReady, model.runOnFrame] ); if (!device) return null; diff --git a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md index eb47efa857..cc1cdba51d 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md @@ -43,4 +43,6 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/VerticalOCRModule.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +To run the model, use the [`forward`](../../06-api-reference/classes/VerticalOCRModule.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. 
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index e51f952b91..82c078eb02 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -32,8 +32,6 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, } } -// Runs inference and returns the styled BGR cv::Mat resized to outputSize. -// Acquires inference_mutex_ for the duration. cv::Mat StyleTransfer::runInference(cv::Mat image, cv::Size outputSize) { std::scoped_lock lock(inference_mutex_); @@ -84,11 +82,8 @@ StyleTransferResult StyleTransfer::generateFromString(std::string imageSource, PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { - // extractFromFrame rotates landscape frames 90° CW automatically. cv::Mat frame = extractFromFrame(runtime, frameData); - // For real-time frame processing, output at modelInputSize to avoid - // allocating large buffers (e.g. 1280x720x3 ~2.7MB) on every frame. 
return toPixelDataResult(runInference(frame, modelInputSize())); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp index b64f167c90..5725778def 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp @@ -1,6 +1,8 @@ #include "BaseModelTests.h" +#include #include #include +#include #include #include @@ -115,3 +117,15 @@ TEST(ClassificationInheritedTests, GetMethodMetaWorks) { auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(ClassificationPixelTests, ValidPixelsReturnsResults) { + Classification model(kValidClassificationModelPath, nullptr); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(view); + EXPECT_FALSE(results.empty()); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp index ba76939a8e..87d37908b1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp @@ -1,7 +1,9 @@ #include "BaseModelTests.h" #include +#include #include #include +#include #include using namespace rnexecutorch; @@ -122,3 +124,16 @@ TEST(ImageEmbeddingsInheritedTests, GetMethodMetaWorks) { auto result = model.getMethodMeta("forward"); 
EXPECT_TRUE(result.ok()); } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(ImageEmbeddingsPixelTests, ValidPixelsReturnsEmbedding) { + ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto result = model.generateFromPixels(view); + EXPECT_NE(result, nullptr); + EXPECT_GT(result->size(), 0u); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp index 6f6f708be2..072c761164 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp @@ -1,6 +1,8 @@ #include "BaseModelTests.h" +#include #include #include +#include #include #include @@ -126,3 +128,16 @@ TEST(OCRGenerateTests, DetectionsHaveNonEmptyText) { EXPECT_FALSE(detection.text.empty()); } } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(OCRPixelTests, ValidPixelsReturnsResults) { + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, + createMockCallInvoker()); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(view); + EXPECT_GE(results.size(), 0u); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp index 8964f20133..c983f2fc77 
100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp @@ -163,40 +163,6 @@ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) { EXPECT_GE(results.size(), 0u); } -TEST(ObjectDetectionPixelTests, WrongSizesLengthThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - std::vector pixelData(16, 0); - JSTensorViewIn tensorView{ - pixelData.data(), {4, 4}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5), - RnExecutorchError); -} - -TEST(ObjectDetectionPixelTests, WrongChannelCountThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - constexpr int32_t width = 4, height = 4, channels = 4; - std::vector pixelData(width * height * channels, 0); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5), - RnExecutorchError); -} - -TEST(ObjectDetectionPixelTests, WrongScalarTypeThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - constexpr int32_t width = 4, height = 4, channels = 3; - std::vector pixelData(width * height * channels, 0); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Float}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5), - RnExecutorchError); -} - TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp index 
96069769cb..957421f091 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp @@ -190,32 +190,6 @@ TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { } } -TEST(SemanticSegmentationPixelTests, WrongSizesThrows) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); - std::vector buf(16, 0); - JSTensorViewIn view{buf.data(), {4, 4}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(view, {}, true), - RnExecutorchError); -} - -TEST(SemanticSegmentationPixelTests, WrongChannelsThrows) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); - std::vector buf(64, 0); - JSTensorViewIn view{ - buf.data(), {4, 4, 4}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(view, {}, true), - RnExecutorchError); -} - -TEST(SemanticSegmentationPixelTests, WrongScalarTypeThrows) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); - std::vector buf(48, 0); - JSTensorViewIn view{ - buf.data(), {4, 4, 3}, executorch::aten::ScalarType::Float}; - EXPECT_THROW((void)model.generateFromPixels(view, {}, true), - RnExecutorchError); -} - // ============================================================================ // Inherited BaseModel tests // ============================================================================ diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index 758ef63368..c18131002a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -186,40 +186,6 @@ 
TEST(StyleTransferPixelTests, OutputDimensionsMatchInputSize) { EXPECT_EQ(pr.height, 64); } -// ============================================================================ -// generateFromPixels invalid input tests -// ============================================================================ -TEST(StyleTransferPixelInvalidTests, WrongSizesLengthThrows) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - std::vector buf(16, 0); - JSTensorViewIn view{buf.data(), {4, 4}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); -} - -TEST(StyleTransferPixelInvalidTests, FourChannelsThrows) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - std::vector buf(64, 0); - JSTensorViewIn view{ - buf.data(), {4, 4, 4}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); -} - -TEST(StyleTransferPixelInvalidTests, OneChannelThrows) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - std::vector buf(16, 0); - JSTensorViewIn view{ - buf.data(), {4, 4, 1}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); -} - -TEST(StyleTransferPixelInvalidTests, WrongScalarTypeThrows) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - std::vector buf(48, 0); - JSTensorViewIn view{ - buf.data(), {4, 4, 3}, executorch::aten::ScalarType::Float}; - EXPECT_THROW((void)model.generateFromPixels(view, false), RnExecutorchError); -} - // ============================================================================ // Thread safety tests // ============================================================================ diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp index 56f18d862a..fd6d59441d 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp @@ -1,6 +1,8 @@ #include "BaseModelTests.h" +#include #include #include +#include #include #include @@ -239,3 +241,16 @@ TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { EXPECT_GE(independentResults.size(), 0u); EXPECT_GE(jointResults.size(), 0u); } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(VerticalOCRPixelTests, ValidPixelsReturnsResults) { + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, + ENGLISH_SYMBOLS, false, createMockCallInvoker()); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(view); + EXPECT_GE(results.size(), 0u); +} diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts index a6ef5c6a14..c014d6b0ed 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts @@ -9,7 +9,7 @@ import { useModuleFactory } from '../useModuleFactory'; * React hook for managing a Classification model instance. * * @category Hooks - * @param props - Configuration object containing `model` source and optional `preventLoad` flag. + * @param ClassificationProps - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use Classification model. 
*/ export const useClassification = ({ diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts index 07376bddee..b4e79c9263 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts @@ -9,7 +9,7 @@ import { useModuleFactory } from '../useModuleFactory'; * React hook for managing an Image Embeddings model instance. * * @category Hooks - * @param props - Configuration object containing `model` source and optional `preventLoad` flag. + * @param ImageEmbeddingsProps - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use Image Embeddings model. */ export const useImageEmbeddings = ({ diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts index 31061d2b64..208824b8b8 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts @@ -8,7 +8,7 @@ import { OCRDetection, OCRProps, OCRType } from '../../types/ocr'; * React hook for managing an OCR instance. * * @category Hooks - * @param props - Configuration object containing `model` sources and optional `preventLoad` flag. + * @param OCRProps - Configuration object containing `model` sources and optional `preventLoad` flag. * @returns Ready to use OCR model. 
*/ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts index b5903d10dd..6ebd1ccd6c 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts @@ -10,7 +10,7 @@ import { useModuleFactory } from '../useModuleFactory'; * React hook for managing a Style Transfer model instance. * * @category Hooks - * @param props - Configuration object containing `model` source and optional `preventLoad` flag. + * @param StyleTransferProps - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use Style Transfer model. */ export const useStyleTransfer = ({ From 35b722ba5d3121371cef50f01d97f1a9495f5e04 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:14:30 +0100 Subject: [PATCH 51/71] docs: update vision camera docs page --- .../visioncamera-integration.md | 189 ++++++++++++++---- 1 file changed, 145 insertions(+), 44 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md index 8b5931b2e0..79f6c5aad8 100644 --- a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md +++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md @@ -2,7 +2,14 @@ title: VisionCamera Integration --- -React Native ExecuTorch vision models support real-time frame processing via [VisionCamera](https://react-native-vision-camera-v5-docs.vercel.app) using the `runOnFrame` worklet. This page explains how `runOnFrame` works and how to use it with any supported model. 
+React Native ExecuTorch vision models support real-time frame processing via [VisionCamera v5](https://react-native-vision-camera-v5-docs.vercel.app) using the `runOnFrame` worklet. This page explains how to set it up and what to watch out for. + +## Prerequisites + +Make sure you have the following packages installed: + +- [`react-native-vision-camera`](https://react-native-vision-camera-v5-docs.vercel.app) v5 +- [`react-native-worklets`](https://docs.swmansion.com/react-native-worklets/) ## Which models support runOnFrame? @@ -27,63 +34,79 @@ The following hooks expose `runOnFrame`: Use `runOnFrame` when you need to process every camera frame. Use `forward` for one-off image inference. -:::warning -`runOnFrame` runs synchronously on the JS thread. Keep processing time low to avoid dropping camera frames. -::: - -## Setup +## How it works -`runOnFrame` is a stable worklet function exposed directly from the hook. Pass it to `useFrameProcessor` and guard with `model.isReady` — no need to store it in state: - -```tsx -import { useFrameProcessor } from 'react-native-vision-camera'; -import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; - -export default function App() { - const model = useClassification({ model: EFFICIENTNET_V2_S }); +VisionCamera v5 delivers frames via [`useFrameOutput`](https://react-native-vision-camera-v5-docs.vercel.app/docs/frame-output). Inside the `onFrame` worklet you call `runOnFrame(frame)` synchronously, then use `scheduleOnRN` from `react-native-worklets` to post the result back to React state on the main thread. - const frameProcessor = useFrameProcessor( - (frame) => { - 'worklet'; - if (!model.isReady) return; +:::warning +You **must** set `pixelFormat: 'rgb'` in `useFrameOutput`. Our extraction pipeline expect RGB pixel data — any other format (e.g. the default `yuv`) will produce incorrect results. 
+::: - model.runOnFrame(frame); // use the returned result - }, - [model.isReady, model.runOnFrame] - ); +:::warning +`runOnFrame` is synchronous and runs on the JS worklet thread. For models with longer inference times, use `dropFramesWhileBusy: true` to skip frames and avoid blocking the camera pipeline. For more control, see VisionCamera's [async frame processing guide](https://react-native-vision-camera-v5-docs.vercel.app/docs/async-frame-processing). +::: - return ; -} -``` +:::note +Always call `frame.dispose()` after processing to release the frame buffer. Wrap your inference in a `try/finally` to ensure it's always called even if `runOnFrame` throws. +::: ## Full example (Classification) ```tsx -import React, { useState } from 'react'; +import { useState, useCallback } from 'react'; import { Text, StyleSheet } from 'react-native'; import { Camera, - useCameraDevice, - useFrameProcessor, + Frame, + useCameraDevices, + useCameraPermission, + useFrameOutput, } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; export default function App() { - const device = useCameraDevice('back'); + const { hasPermission, requestPermission } = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back'); const model = useClassification({ model: EFFICIENTNET_V2_S }); - const [topLabel, setTopLabel] = useState(''); - - const frameProcessor = useFrameProcessor( - (frame) => { - 'worklet'; - if (!model.isReady) return; - - const scores = model.runOnFrame(frame); - const top = Object.entries(scores).sort(([, a], [, b]) => b - a)[0]; - if (top) setTopLabel(top[0]); - }, - [model.isReady, model.runOnFrame] - ); + const [topLabel, setTopLabel] = useState(''); + + // Extract runOnFrame so it can be captured by the useCallback dependency array + const runOnFrame = model.runOnFrame; + + const frameOutput = 
useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (!runOnFrame) return; + try { + const scores = runOnFrame(frame); + if (scores) { + let best = ''; + let bestScore = -1; + for (const [label, score] of Object.entries(scores)) { + if ((score as number) > bestScore) { + bestScore = score as number; + best = label; + } + } + scheduleOnRN(setTopLabel, best); + } + } finally { + frame.dispose(); + } + }, + [runOnFrame] + ), + }); + + if (!hasPermission) { + requestPermission(); + return null; + } if (!device) return null; @@ -92,15 +115,93 @@ export default function App() { - {topLabel} + {topLabel} ); } const styles = StyleSheet.create({ camera: { flex: 1 }, + label: { + position: 'absolute', + bottom: 40, + alignSelf: 'center', + color: 'white', + fontSize: 20, + }, }); ``` + +## Using the Module API + +If you use the TypeScript Module API (e.g. `ClassificationModule`) directly instead of a hook, `runOnFrame` is a worklet function and **cannot** be passed directly to `useState` — React would invoke it as a state initializer. 
Use the functional updater form `() => module.runOnFrame`: + +```tsx +import { useState, useEffect, useCallback } from 'react'; +import { Camera, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + ClassificationModule, + EFFICIENTNET_V2_S, +} from 'react-native-executorch'; + +export default function App() { + const [module] = useState(() => new ClassificationModule()); + const [runOnFrame, setRunOnFrame] = useState( + null + ); + + useEffect(() => { + module.load(EFFICIENTNET_V2_S).then(() => { + // () => module.runOnFrame is required — passing module.runOnFrame directly + // would cause React to call it as a state initializer function + setRunOnFrame(() => module.runOnFrame); + }); + }, [module]); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame) => { + 'worklet'; + if (!runOnFrame) return; + try { + const result = runOnFrame(frame); + if (result) scheduleOnRN(setResult, result); + } finally { + frame.dispose(); + } + }, + [runOnFrame] + ), + }); + + return ; +} +``` + +## Common issues + +### Results look wrong or scrambled + +You forgot to set `pixelFormat: 'rgb'`. The default VisionCamera pixel format is `yuv` — our frame extraction works only with RGB data. + +### App freezes or camera drops frames + +Your model's inference time exceeds the frame interval. Enable `dropFramesWhileBusy: true` in `useFrameOutput`, or move inference off the worklet thread using VisionCamera's [async frame processing](https://react-native-vision-camera-v5-docs.vercel.app/docs/async-frame-processing). + +### Memory leak / crash after many frames + +You are not calling `frame.dispose()`. Always dispose the frame in a `finally` block. + +### `runOnFrame` is always null + +The model hasn't finished loading yet. Guard with `if (!runOnFrame) return` inside `onFrame`, or check `model.isReady` before enabling the camera. 
+ +### TypeError: `module.runOnFrame` is not a function (Module API) + +You passed `module.runOnFrame` directly to `setState` instead of `() => module.runOnFrame`. React invoked it as a state initializer — see the [Module API section](#using-the-module-api) above. From 43b12951735771eee1b1f313450848b7de6ae581 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:47:52 +0100 Subject: [PATCH 52/71] refactor: unused include --- .../common/rnexecutorch/models/style_transfer/StyleTransfer.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 82c078eb02..4aaa774cc4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -6,7 +6,6 @@ #include #include #include -#include namespace rnexecutorch::models::style_transfer { using namespace facebook; From 312b45ae5f807aedc09782335d85a6bf109b548a Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:51:13 +0100 Subject: [PATCH 53/71] refactor: extract vision camera color utils Co-Authored-By: Claude Sonnet 4.6 --- .../app/vision_camera/index.tsx | 43 +------------------ .../app/vision_camera/utils/colors.ts | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 42 deletions(-) create mode 100644 apps/computer-vision/app/vision_camera/utils/colors.ts diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 2277565fbb..2a0bb8dd90 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -54,6 +54,7 @@ import Svg, { Path, Polygon } from 'react-native-svg'; import { GeneratingContext } from '../../context'; import Spinner 
from '../../components/Spinner'; import ColorPalette from '../../colors'; +import { CLASS_COLORS, labelColor, labelColorBg } from './utils/colors'; type TaskId = 'classification' | 'objectDetection' | 'segmentation'; type ModelId = @@ -100,48 +101,6 @@ const TASKS: Task[] = [ }, ]; -const CLASS_COLORS: number[][] = [ - [0, 0, 0, 0], - [51, 255, 87, 180], - [51, 87, 255, 180], - [255, 51, 246, 180], - [51, 255, 246, 180], - [243, 255, 51, 180], - [141, 51, 255, 180], - [255, 131, 51, 180], - [51, 255, 131, 180], - [131, 51, 255, 180], - [255, 255, 51, 180], - [51, 255, 255, 180], - [255, 51, 143, 180], - [127, 51, 255, 180], - [51, 255, 175, 180], - [255, 175, 51, 180], - [179, 255, 51, 180], - [255, 87, 51, 180], - [255, 51, 162, 180], - [51, 162, 255, 180], - [162, 51, 255, 180], -]; - -function hashLabel(label: string): number { - let hash = 5381; - for (let i = 0; i < label.length; i++) { - hash = (hash + hash * 32 + label.charCodeAt(i)) % 1000003; - } - return 1 + (Math.abs(hash) % (CLASS_COLORS.length - 1)); -} - -function labelColor(label: string): string { - const color = CLASS_COLORS[hashLabel(label)]!; - return `rgba(${color[0]},${color[1]},${color[2]},1)`; -} - -function labelColorBg(label: string): string { - const color = CLASS_COLORS[hashLabel(label)]!; - return `rgba(${color[0]},${color[1]},${color[2]},0.75)`; -} - const frameKillSwitch = createSynchronizable(false); export default function VisionCameraScreen() { diff --git a/apps/computer-vision/app/vision_camera/utils/colors.ts b/apps/computer-vision/app/vision_camera/utils/colors.ts new file mode 100644 index 0000000000..5d03ca65cc --- /dev/null +++ b/apps/computer-vision/app/vision_camera/utils/colors.ts @@ -0,0 +1,43 @@ +// apps/computer-vision/app/vision_camera/utils/colors.ts + +export const CLASS_COLORS: number[][] = [ + [0, 0, 0, 0], + [51, 255, 87, 180], + [51, 87, 255, 180], + [255, 51, 246, 180], + [51, 255, 246, 180], + [243, 255, 51, 180], + [141, 51, 255, 180], + [255, 131, 51, 180], + 
[51, 255, 131, 180], + [131, 51, 255, 180], + [255, 255, 51, 180], + [51, 255, 255, 180], + [255, 51, 143, 180], + [127, 51, 255, 180], + [51, 255, 175, 180], + [255, 175, 51, 180], + [179, 255, 51, 180], + [255, 87, 51, 180], + [255, 51, 162, 180], + [51, 162, 255, 180], + [162, 51, 255, 180], +]; + +export function hashLabel(label: string): number { + let hash = 5381; + for (let i = 0; i < label.length; i++) { + hash = (hash + hash * 32 + label.charCodeAt(i)) % 1000003; + } + return 1 + (Math.abs(hash) % (CLASS_COLORS.length - 1)); +} + +export function labelColor(label: string): string { + const color = CLASS_COLORS[hashLabel(label)]!; + return `rgba(${color[0]},${color[1]},${color[2]},1)`; +} + +export function labelColorBg(label: string): string { + const color = CLASS_COLORS[hashLabel(label)]!; + return `rgba(${color[0]},${color[1]},${color[2]},0.75)`; +} From 035dbab012699f9fddd5c8eecf42e772a7482635 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:52:21 +0100 Subject: [PATCH 54/71] refactor: add ClassificationTask component --- .../tasks/ClassificationTask.tsx | 121 ++++++++++++++++++ .../app/vision_camera/tasks/types.ts | 15 +++ 2 files changed, 136 insertions(+) create mode 100644 apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx create mode 100644 apps/computer-vision/app/vision_camera/tasks/types.ts diff --git a/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx b/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx new file mode 100644 index 0000000000..7111caf9c2 --- /dev/null +++ b/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx @@ -0,0 +1,121 @@ +// apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { StyleSheet, Text, View } from 'react-native'; +import { Frame, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 
'react-native-worklets'; +import { EFFICIENTNET_V2_S, useClassification } from 'react-native-executorch'; +import { TaskProps } from './types'; + +type Props = Omit; + +export default function ClassificationTask({ + frameKillSwitch, + onFrameOutputChange, + onReadyChange, + onProgressChange, + onGeneratingChange, + onFpsChange, +}: Props) { + const model = useClassification({ model: EFFICIENTNET_V2_S }); + const [classResult, setClassResult] = useState({ label: '', score: 0 }); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onReadyChange(model.isReady); + }, [model.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(model.downloadProgress); + }, [model.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(model.isGenerating); + }, [model.isGenerating, onGeneratingChange]); + + const classRof = model.runOnFrame; + + const updateClass = useCallback( + (r: { label: string; score: number }) => { + setClassResult(r); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!classRof) return; + const result = classRof(frame); + if (result) { + let bestLabel = ''; + let bestScore = -1; + const entries = Object.entries(result); + for (let i = 0; i < entries.length; i++) { + const [label, score] = entries[i]!; + if ((score as number) > bestScore) { + bestScore = score as number; + bestLabel = label; + } + } + scheduleOnRN(updateClass, { label: bestLabel, score: bestScore }); + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [classRof, frameKillSwitch, updateClass] + ), + }); + + useEffect(() => { + 
onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + return classResult.label ? ( + + {classResult.label} + {(classResult.score * 100).toFixed(1)}% + + ) : null; +} + +const styles = StyleSheet.create({ + overlay: { + ...StyleSheet.absoluteFillObject, + justifyContent: 'center', + alignItems: 'center', + }, + label: { + color: 'white', + fontSize: 28, + fontWeight: '700', + textAlign: 'center', + textShadowColor: 'rgba(0,0,0,0.8)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 6, + paddingHorizontal: 24, + }, + score: { + color: 'rgba(255,255,255,0.75)', + fontSize: 18, + fontWeight: '500', + marginTop: 4, + textShadowColor: 'rgba(0,0,0,0.8)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 6, + }, +}); diff --git a/apps/computer-vision/app/vision_camera/tasks/types.ts b/apps/computer-vision/app/vision_camera/tasks/types.ts new file mode 100644 index 0000000000..fabc85c576 --- /dev/null +++ b/apps/computer-vision/app/vision_camera/tasks/types.ts @@ -0,0 +1,15 @@ +// apps/computer-vision/app/vision_camera/tasks/types.ts +import { useFrameOutput } from 'react-native-vision-camera'; +import { createSynchronizable } from 'react-native-worklets'; + +export type TaskProps = { + activeModel: string; + canvasSize: { width: number; height: number }; + cameraPosition: 'front' | 'back'; + frameKillSwitch: ReturnType>; + onFrameOutputChange: (frameOutput: ReturnType) => void; + onReadyChange: (isReady: boolean) => void; + onProgressChange: (progress: number) => void; + onGeneratingChange: (isGenerating: boolean) => void; + onFpsChange: (fps: number, frameMs: number) => void; +}; From c68d90992bb1abf19f6e979624d541591511ae6b Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:54:34 +0100 Subject: [PATCH 55/71] refactor: add ObjectDetectionTask component --- .../tasks/ObjectDetectionTask.tsx | 175 ++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 
apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx diff --git a/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx new file mode 100644 index 0000000000..f9edf2f3d0 --- /dev/null +++ b/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx @@ -0,0 +1,175 @@ +// apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { StyleSheet, Text, View } from 'react-native'; +import { Frame, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + Detection, + RF_DETR_NANO, + SSDLITE_320_MOBILENET_V3_LARGE, + useObjectDetection, +} from 'react-native-executorch'; +import { labelColor, labelColorBg } from '../utils/colors'; +import { TaskProps } from './types'; + +type ObjModelId = 'objectDetection_ssdlite' | 'objectDetection_rfdetr'; + +type Props = TaskProps & { activeModel: ObjModelId }; + +export default function ObjectDetectionTask({ + activeModel, + canvasSize, + cameraPosition, + frameKillSwitch, + onFrameOutputChange, + onReadyChange, + onProgressChange, + onGeneratingChange, + onFpsChange, +}: Props) { + const ssdlite = useObjectDetection({ + model: SSDLITE_320_MOBILENET_V3_LARGE, + preventLoad: activeModel !== 'objectDetection_ssdlite', + }); + const rfdetr = useObjectDetection({ + model: RF_DETR_NANO, + preventLoad: activeModel !== 'objectDetection_rfdetr', + }); + + const active = activeModel === 'objectDetection_ssdlite' ? 
ssdlite : rfdetr; + + const [detections, setDetections] = useState([]); + const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onReadyChange(active.isReady); + }, [active.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(active.downloadProgress); + }, [active.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(active.isGenerating); + }, [active.isGenerating, onGeneratingChange]); + + const detRof = active.runOnFrame; + + const updateDetections = useCallback( + (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { + setDetections(p.results); + setImageSize({ width: p.imageWidth, height: p.imageHeight }); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!detRof) return; + const iw = frame.width > frame.height ? frame.height : frame.width; + const ih = frame.width > frame.height ? 
frame.width : frame.height; + const result = detRof(frame, 0.5); + if (result) { + scheduleOnRN(updateDetections, { + results: result, + imageWidth: iw, + imageHeight: ih, + }); + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [detRof, frameKillSwitch, updateDetections] + ), + }); + + useEffect(() => { + onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + const scale = Math.max( + canvasSize.width / imageSize.width, + canvasSize.height / imageSize.height + ); + const offsetX = (canvasSize.width - imageSize.width * scale) / 2; + const offsetY = (canvasSize.height - imageSize.height * scale) / 2; + + return ( + + {detections.map((det, i) => { + const left = det.bbox.x1 * scale + offsetX; + const top = det.bbox.y1 * scale + offsetY; + const w = (det.bbox.x2 - det.bbox.x1) * scale; + const h = (det.bbox.y2 - det.bbox.y1) * scale; + return ( + + + + {det.label} {(det.score * 100).toFixed(1)} + + + + ); + })} + + ); +} + +const styles = StyleSheet.create({ + bbox: { + position: 'absolute', + borderWidth: 2, + borderColor: 'cyan', + borderRadius: 4, + }, + bboxLabel: { + position: 'absolute', + top: -22, + left: -2, + paddingHorizontal: 6, + paddingVertical: 2, + borderRadius: 4, + }, + bboxLabelText: { color: 'white', fontSize: 11, fontWeight: '600' }, +}); From 1267ec4148c55d3fb0c3ca66c7337b0bcbfbed06 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:55:51 +0100 Subject: [PATCH 56/71] refactor: add SegmentationTask component --- .../vision_camera/tasks/SegmentationTask.tsx | 213 ++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx diff --git a/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx b/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx new file mode 100644 index 0000000000..9f7c2a67d7 --- /dev/null +++ b/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx 
@@ -0,0 +1,213 @@ +// apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { StyleSheet, View } from 'react-native'; +import { Frame, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + DEEPLAB_V3_RESNET50_QUANTIZED, + DEEPLAB_V3_RESNET101_QUANTIZED, + DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + FCN_RESNET50_QUANTIZED, + FCN_RESNET101_QUANTIZED, + LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + SELFIE_SEGMENTATION, + useSemanticSegmentation, +} from 'react-native-executorch'; +import { + AlphaType, + Canvas, + ColorType, + Image as SkiaImage, + Skia, + SkImage, +} from '@shopify/react-native-skia'; +import { CLASS_COLORS } from '../utils/colors'; +import { TaskProps } from './types'; + +type SegModelId = + | 'segmentation_deeplab_resnet50' + | 'segmentation_deeplab_resnet101' + | 'segmentation_deeplab_mobilenet' + | 'segmentation_lraspp' + | 'segmentation_fcn_resnet50' + | 'segmentation_fcn_resnet101' + | 'segmentation_selfie'; + +type Props = TaskProps & { activeModel: SegModelId }; + +export default function SegmentationTask({ + activeModel, + canvasSize, + cameraPosition, + frameKillSwitch, + onFrameOutputChange, + onReadyChange, + onProgressChange, + onGeneratingChange, + onFpsChange, +}: Props) { + const segDeeplabResnet50 = useSemanticSegmentation({ + model: DEEPLAB_V3_RESNET50_QUANTIZED, + preventLoad: activeModel !== 'segmentation_deeplab_resnet50', + }); + const segDeeplabResnet101 = useSemanticSegmentation({ + model: DEEPLAB_V3_RESNET101_QUANTIZED, + preventLoad: activeModel !== 'segmentation_deeplab_resnet101', + }); + const segDeeplabMobilenet = useSemanticSegmentation({ + model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + preventLoad: activeModel !== 'segmentation_deeplab_mobilenet', + }); + const segLraspp = useSemanticSegmentation({ + model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + preventLoad: 
activeModel !== 'segmentation_lraspp', + }); + const segFcnResnet50 = useSemanticSegmentation({ + model: FCN_RESNET50_QUANTIZED, + preventLoad: activeModel !== 'segmentation_fcn_resnet50', + }); + const segFcnResnet101 = useSemanticSegmentation({ + model: FCN_RESNET101_QUANTIZED, + preventLoad: activeModel !== 'segmentation_fcn_resnet101', + }); + const segSelfie = useSemanticSegmentation({ + model: SELFIE_SEGMENTATION, + preventLoad: activeModel !== 'segmentation_selfie', + }); + + const active = { + segmentation_deeplab_resnet50: segDeeplabResnet50, + segmentation_deeplab_resnet101: segDeeplabResnet101, + segmentation_deeplab_mobilenet: segDeeplabMobilenet, + segmentation_lraspp: segLraspp, + segmentation_fcn_resnet50: segFcnResnet50, + segmentation_fcn_resnet101: segFcnResnet101, + segmentation_selfie: segSelfie, + }[activeModel]; + + const [maskImage, setMaskImage] = useState(null); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onReadyChange(active.isReady); + }, [active.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(active.downloadProgress); + }, [active.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(active.isGenerating); + }, [active.isGenerating, onGeneratingChange]); + + // Clear stale mask when the segmentation model variant changes + useEffect(() => { + setMaskImage((prev) => { + prev?.dispose(); + return null; + }); + }, [activeModel]); + + // Dispose native Skia image on unmount to prevent memory leaks + useEffect(() => { + return () => { + setMaskImage((prev) => { + prev?.dispose(); + return null; + }); + }; + }, []); + + const segRof = active.runOnFrame; + + const updateMask = useCallback( + (img: SkImage) => { + setMaskImage((prev) => { + prev?.dispose(); + return img; + }); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + // 
CLASS_COLORS captured directly in closure — worklets cannot import modules + const colors = CLASS_COLORS; + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!segRof) return; + const result = segRof(frame, [], false); + if (result?.ARGMAX) { + const argmax: Int32Array = result.ARGMAX; + const side = Math.round(Math.sqrt(argmax.length)); + const pixels = new Uint8Array(side * side * 4); + for (let i = 0; i < argmax.length; i++) { + const color = colors[argmax[i]!] ?? [0, 0, 0, 0]; + pixels[i * 4] = color[0]!; + pixels[i * 4 + 1] = color[1]!; + pixels[i * 4 + 2] = color[2]!; + pixels[i * 4 + 3] = color[3]!; + } + const skData = Skia.Data.fromBytes(pixels); + const img = Skia.Image.MakeImage( + { + width: side, + height: side, + alphaType: AlphaType.Unpremul, + colorType: ColorType.RGBA_8888, + }, + skData, + side * 4 + ); + if (img) scheduleOnRN(updateMask, img); + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [colors, frameKillSwitch, segRof, updateMask] + ), + }); + + useEffect(() => { + onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + if (!maskImage) return null; + + return ( + + + + + + ); +} From eb8cccf673735e1a5a4095914dc87d5410054cff Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 11:58:24 +0100 Subject: [PATCH 57/71] refactor: simplify vision camera screen to shell + task components Co-Authored-By: Claude Sonnet 4.6 --- .../app/vision_camera/index.tsx | 461 +++--------------- 1 file changed, 67 insertions(+), 394 deletions(-) diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 2a0bb8dd90..c134455042 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -1,9 +1,9 @@ +// 
apps/computer-vision/app/vision_camera/index.tsx import React, { useCallback, useContext, useEffect, useMemo, - useRef, useState, } from 'react'; import { @@ -18,43 +18,20 @@ import { useSafeAreaInsets } from 'react-native-safe-area-context'; import { useIsFocused } from '@react-navigation/native'; import { Camera, - Frame, getCameraFormat, Templates, useCameraDevices, useCameraPermission, useFrameOutput, } from 'react-native-vision-camera'; -import { createSynchronizable, scheduleOnRN } from 'react-native-worklets'; -import { - DEEPLAB_V3_RESNET50_QUANTIZED, - DEEPLAB_V3_RESNET101_QUANTIZED, - DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, - LRASPP_MOBILENET_V3_LARGE_QUANTIZED, - FCN_RESNET50_QUANTIZED, - FCN_RESNET101_QUANTIZED, - SELFIE_SEGMENTATION, - Detection, - EFFICIENTNET_V2_S, - RF_DETR_NANO, - SSDLITE_320_MOBILENET_V3_LARGE, - useClassification, - useSemanticSegmentation, - useObjectDetection, -} from 'react-native-executorch'; -import { - AlphaType, - Canvas, - ColorType, - Image as SkiaImage, - Skia, - SkImage, -} from '@shopify/react-native-skia'; +import { createSynchronizable } from 'react-native-worklets'; import Svg, { Path, Polygon } from 'react-native-svg'; import { GeneratingContext } from '../../context'; import Spinner from '../../components/Spinner'; import ColorPalette from '../../colors'; -import { CLASS_COLORS, labelColor, labelColorBg } from './utils/colors'; +import ClassificationTask from './tasks/ClassificationTask'; +import ObjectDetectionTask from './tasks/ObjectDetectionTask'; +import SegmentationTask from './tasks/SegmentationTask'; type TaskId = 'classification' | 'objectDetection' | 'segmentation'; type ModelId = @@ -101,6 +78,8 @@ const TASKS: Task[] = [ }, ]; +// Module-level const so worklets in task components can always reference the same stable object. +// Never replaced — only mutated via setBlocking to avoid closure staleness. 
const frameKillSwitch = createSynchronizable(false); export default function VisionCameraScreen() { @@ -108,93 +87,18 @@ export default function VisionCameraScreen() { const [activeTask, setActiveTask] = useState('classification'); const [activeModel, setActiveModel] = useState('classification'); const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); - const [cameraPosition, setCameraPosition] = useState<'back' | 'front'>( + const [cameraPosition, setCameraPosition] = useState<'front' | 'back'>( 'back' ); - const { setGlobalGenerating } = useContext(GeneratingContext); - - const classification = useClassification({ - model: EFFICIENTNET_V2_S, - preventLoad: activeModel !== 'classification', - }); - const objectDetectionSsdlite = useObjectDetection({ - model: SSDLITE_320_MOBILENET_V3_LARGE, - preventLoad: activeModel !== 'objectDetection_ssdlite', - }); - const objectDetectionRfdetr = useObjectDetection({ - model: RF_DETR_NANO, - preventLoad: activeModel !== 'objectDetection_rfdetr', - }); - - const activeObjectDetection = - { - objectDetection_ssdlite: objectDetectionSsdlite, - objectDetection_rfdetr: objectDetectionRfdetr, - }[activeModel as 'objectDetection_ssdlite' | 'objectDetection_rfdetr'] ?? 
- null; - const segDeeplabResnet50 = useSemanticSegmentation({ - model: DEEPLAB_V3_RESNET50_QUANTIZED, - preventLoad: activeModel !== 'segmentation_deeplab_resnet50', - }); - const segDeeplabResnet101 = useSemanticSegmentation({ - model: DEEPLAB_V3_RESNET101_QUANTIZED, - preventLoad: activeModel !== 'segmentation_deeplab_resnet101', - }); - const segDeeplabMobilenet = useSemanticSegmentation({ - model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, - preventLoad: activeModel !== 'segmentation_deeplab_mobilenet', - }); - const segLraspp = useSemanticSegmentation({ - model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED, - preventLoad: activeModel !== 'segmentation_lraspp', - }); - const segFcnResnet50 = useSemanticSegmentation({ - model: FCN_RESNET50_QUANTIZED, - preventLoad: activeModel !== 'segmentation_fcn_resnet50', - }); - const segFcnResnet101 = useSemanticSegmentation({ - model: FCN_RESNET101_QUANTIZED, - preventLoad: activeModel !== 'segmentation_fcn_resnet101', - }); - const segSelfie = useSemanticSegmentation({ - model: SELFIE_SEGMENTATION, - preventLoad: activeModel !== 'segmentation_selfie', - }); - - const activeSegmentation = - { - segmentation_deeplab_resnet50: segDeeplabResnet50, - segmentation_deeplab_resnet101: segDeeplabResnet101, - segmentation_deeplab_mobilenet: segDeeplabMobilenet, - segmentation_lraspp: segLraspp, - segmentation_fcn_resnet50: segFcnResnet50, - segmentation_fcn_resnet101: segFcnResnet101, - segmentation_selfie: segSelfie, - }[ - activeModel as - | 'segmentation_deeplab_resnet50' - | 'segmentation_deeplab_resnet101' - | 'segmentation_deeplab_mobilenet' - | 'segmentation_lraspp' - | 'segmentation_fcn_resnet50' - | 'segmentation_fcn_resnet101' - | 'segmentation_selfie' - ] ?? null; - - const activeIsGenerating = - activeModel === 'classification' - ? classification.isGenerating - : activeModel.startsWith('objectDetection') - ? (activeObjectDetection?.isGenerating ?? false) - : (activeSegmentation?.isGenerating ?? 
false); - - useEffect(() => { - setGlobalGenerating(activeIsGenerating); - }, [activeIsGenerating, setGlobalGenerating]); - const [fps, setFps] = useState(0); const [frameMs, setFrameMs] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); + const [isReady, setIsReady] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + const [frameOutput, setFrameOutput] = useState | null>(null); + const { setGlobalGenerating } = useContext(GeneratingContext); + const isFocused = useIsFocused(); const cameraPermission = useCameraPermission(); const devices = useCameraDevices(); @@ -209,168 +113,25 @@ export default function VisionCameraScreen() { } }, [device]); - const [classResult, setClassResult] = useState({ label: '', score: 0 }); - const [detections, setDetections] = useState([]); - const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); - const [maskImage, setMaskImage] = useState(null); - - const updateClass = useCallback((r: { label: string; score: number }) => { - setClassResult(r); - const now = Date.now(); - const diff = now - lastFrameTimeRef.current; - if (diff > 0) { - setFps(Math.round(1000 / diff)); - setFrameMs(diff); - } - lastFrameTimeRef.current = now; - }, []); - - const updateFps = useCallback(() => { - const now = Date.now(); - const diff = now - lastFrameTimeRef.current; - if (diff > 0) { - setFps(Math.round(1000 / diff)); - setFrameMs(diff); - } - lastFrameTimeRef.current = now; - }, []); - - const updateDetections = useCallback( - (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { - setDetections(p.results); - setImageSize({ width: p.imageWidth, height: p.imageHeight }); - updateFps(); - }, - [updateFps] - ); - - const updateMask = useCallback( - (img: SkImage) => { - setMaskImage((prev) => { - prev?.dispose(); - return img; - }); - updateFps(); - }, - [updateFps] - ); - - const classRof = classification.runOnFrame; - const detRof = activeObjectDetection?.runOnFrame ?? 
null; - const segRof = activeSegmentation?.runOnFrame ?? null; - useEffect(() => { frameKillSwitch.setBlocking(true); - setMaskImage((prev) => { - prev?.dispose(); - return null; - }); const id = setTimeout(() => { frameKillSwitch.setBlocking(false); }, 300); return () => clearTimeout(id); }, [activeModel]); - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - dropFramesWhileBusy: true, - onFrame: useCallback( - (frame: Frame) => { - 'worklet'; - - if (frameKillSwitch.getDirty()) { - frame.dispose(); - return; - } - - try { - if (activeModel === 'classification') { - if (!classRof) return; - const result = classRof(frame); - if (result) { - let bestLabel = ''; - let bestScore = -1; - const entries = Object.entries(result); - for (let i = 0; i < entries.length; i++) { - const [label, score] = entries[i]!; - if ((score as number) > bestScore) { - bestScore = score as number; - bestLabel = label; - } - } - scheduleOnRN(updateClass, { label: bestLabel, score: bestScore }); - } - } else if (activeModel.startsWith('objectDetection')) { - if (!detRof) return; - const iw = frame.width > frame.height ? frame.height : frame.width; - const ih = frame.width > frame.height ? frame.width : frame.height; - const result = detRof(frame, 0.5); - if (result) { - scheduleOnRN(updateDetections, { - results: result, - imageWidth: iw, - imageHeight: ih, - }); - } - } else if (activeModel.startsWith('segmentation')) { - if (!segRof) return; - const result = segRof(frame, [], false); - if (result?.ARGMAX) { - const argmax: Int32Array = result.ARGMAX; - const side = Math.round(Math.sqrt(argmax.length)); - const pixels = new Uint8Array(side * side * 4); - for (let i = 0; i < argmax.length; i++) { - const color = CLASS_COLORS[argmax[i]!] ?? 
[0, 0, 0, 0]; - pixels[i * 4] = color[0]!; - pixels[i * 4 + 1] = color[1]!; - pixels[i * 4 + 2] = color[2]!; - pixels[i * 4 + 3] = color[3]!; - } - const skData = Skia.Data.fromBytes(pixels); - const img = Skia.Image.MakeImage( - { - width: side, - height: side, - alphaType: AlphaType.Unpremul, - colorType: ColorType.RGBA_8888, - }, - skData, - side * 4 - ); - if (img) scheduleOnRN(updateMask, img); - } - } - } catch { - // ignore - } finally { - frame.dispose(); - } - }, - [ - activeModel, - classRof, - detRof, - segRof, - updateClass, - updateDetections, - updateMask, - ] - ), - }); - - const activeIsReady = - activeModel === 'classification' - ? classification.isReady - : activeModel.startsWith('objectDetection') - ? (activeObjectDetection?.isReady ?? false) - : (activeSegmentation?.isReady ?? false); + const handleFpsChange = useCallback((newFps: number, newMs: number) => { + setFps(newFps); + setFrameMs(newMs); + }, []); - const activeDownloadProgress = - activeModel === 'classification' - ? classification.downloadProgress - : activeModel.startsWith('objectDetection') - ? (activeObjectDetection?.downloadProgress ?? 0) - : (activeSegmentation?.downloadProgress ?? 0); + const handleGeneratingChange = useCallback( + (generating: boolean) => { + setGlobalGenerating(generating); + }, + [setGlobalGenerating] + ); if (!cameraPermission.hasPermission) { return ( @@ -394,26 +155,23 @@ export default function VisionCameraScreen() { ); } - function coverFit(imgW: number, imgH: number) { - const scale = Math.max(canvasSize.width / imgW, canvasSize.height / imgH); - return { - scale, - offsetX: (canvasSize.width - imgW * scale) / 2, - offsetY: (canvasSize.height - imgH * scale) / 2, - }; - } - - const { - scale: detScale, - offsetX: detOX, - offsetY: detOY, - } = coverFit(imageSize.width, imageSize.height); - const activeTaskInfo = TASKS.find((t) => t.id === activeTask)!; const activeVariantLabel = activeTaskInfo.variants.find((v) => v.id === activeModel)?.label ?? 
activeTaskInfo.variants[0]!.label; + const taskProps = { + activeModel, + canvasSize, + cameraPosition, + frameKillSwitch, + onFrameOutputChange: setFrameOutput, + onReadyChange: setIsReady, + onProgressChange: setDownloadProgress, + onGeneratingChange: handleGeneratingChange, + onFpsChange: handleFpsChange, + }; + return ( @@ -421,17 +179,15 @@ export default function VisionCameraScreen() { + {/* Layout sentinel — measures the full-screen area for bbox/canvas sizing */} setCanvasSize({ @@ -439,75 +195,38 @@ export default function VisionCameraScreen() { height: e.nativeEvent.layout.height, }) } - > - {activeModel.startsWith('segmentation') && maskImage && ( - - - - )} - - {activeModel.startsWith('objectDetection') && ( - <> - {detections.map((det, i) => { - const left = det.bbox.x1 * detScale + detOX; - const top = det.bbox.y1 * detScale + detOY; - const w = (det.bbox.x2 - det.bbox.x1) * detScale; - const h = (det.bbox.y2 - det.bbox.y1) * detScale; - return ( - - - - {det.label} {(det.score * 100).toFixed(1)} - - - - ); - })} - - )} - + /> - {activeModel === 'classification' && classResult.label ? 
( - - {classResult.label} - - {(classResult.score * 100).toFixed(1)}% - - - ) : null} + {activeTask === 'classification' && } + {activeTask === 'objectDetection' && ( + + )} + {activeTask === 'segmentation' && ( + + )} - {!activeIsReady && ( + {!isReady && ( )} @@ -587,7 +306,6 @@ export default function VisionCameraScreen() { } > - {/* Camera body */} - {/* Rotate arrows — arc with arrowhead around the lens */} Date: Fri, 13 Mar 2026 12:02:57 +0100 Subject: [PATCH 58/71] refactor: remove comments --- apps/computer-vision/app/vision_camera/index.tsx | 1 - .../app/vision_camera/tasks/ClassificationTask.tsx | 1 - .../app/vision_camera/tasks/ObjectDetectionTask.tsx | 1 - .../app/vision_camera/tasks/SegmentationTask.tsx | 1 - apps/computer-vision/app/vision_camera/tasks/types.ts | 1 - apps/computer-vision/app/vision_camera/utils/colors.ts | 2 -- 6 files changed, 7 deletions(-) diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index c134455042..20254f1547 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -1,4 +1,3 @@ -// apps/computer-vision/app/vision_camera/index.tsx import React, { useCallback, useContext, diff --git a/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx b/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx index 7111caf9c2..c9b4a2bf21 100644 --- a/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx +++ b/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx @@ -1,4 +1,3 @@ -// apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx import React, { useCallback, useEffect, useRef, useState } from 'react'; import { StyleSheet, Text, View } from 'react-native'; import { Frame, useFrameOutput } from 'react-native-vision-camera'; diff --git a/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx 
b/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx index f9edf2f3d0..de9f77edb5 100644 --- a/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx +++ b/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx @@ -1,4 +1,3 @@ -// apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx import React, { useCallback, useEffect, useRef, useState } from 'react'; import { StyleSheet, Text, View } from 'react-native'; import { Frame, useFrameOutput } from 'react-native-vision-camera'; diff --git a/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx b/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx index 9f7c2a67d7..3064e309bf 100644 --- a/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx +++ b/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx @@ -1,4 +1,3 @@ -// apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx import React, { useCallback, useEffect, useRef, useState } from 'react'; import { StyleSheet, View } from 'react-native'; import { Frame, useFrameOutput } from 'react-native-vision-camera'; diff --git a/apps/computer-vision/app/vision_camera/tasks/types.ts b/apps/computer-vision/app/vision_camera/tasks/types.ts index fabc85c576..9727227f2f 100644 --- a/apps/computer-vision/app/vision_camera/tasks/types.ts +++ b/apps/computer-vision/app/vision_camera/tasks/types.ts @@ -1,4 +1,3 @@ -// apps/computer-vision/app/vision_camera/tasks/types.ts import { useFrameOutput } from 'react-native-vision-camera'; import { createSynchronizable } from 'react-native-worklets'; diff --git a/apps/computer-vision/app/vision_camera/utils/colors.ts b/apps/computer-vision/app/vision_camera/utils/colors.ts index 5d03ca65cc..c38493a3b0 100644 --- a/apps/computer-vision/app/vision_camera/utils/colors.ts +++ b/apps/computer-vision/app/vision_camera/utils/colors.ts @@ -1,5 +1,3 @@ -// apps/computer-vision/app/vision_camera/utils/colors.ts - export const 
CLASS_COLORS: number[][] = [ [0, 0, 0, 0], [51, 255, 87, 180], From 8e55c12990c63f6f90d16b642206ab4c3547ddee Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 14:57:49 +0100 Subject: [PATCH 59/71] fix: after rebase --- .../computer_vision/useClassification.ts | 34 +++++++++++++------ .../computer_vision/useImageEmbeddings.ts | 34 +++++++++++++------ .../computer_vision/ClassificationModule.ts | 11 ++---- .../computer_vision/ImageEmbeddingsModule.ts | 10 ++---- .../computer_vision/StyleTransferModule.ts | 8 ++--- 5 files changed, 57 insertions(+), 40 deletions(-) diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts index c014d6b0ed..2e3b4e22a0 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts @@ -3,6 +3,7 @@ import { ClassificationProps, ClassificationType, } from '../../types/classification'; +import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; /** @@ -16,17 +17,30 @@ export const useClassification = ({ model, preventLoad = false, }: ClassificationProps): ClassificationType => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - ClassificationModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + ClassificationModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); - const forward = (imageSource: string) => + const forward = (imageSource: string | PixelData) => 
runForward((inst) => inst.forward(imageSource)); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts index b4e79c9263..8c0e6209e5 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts @@ -3,6 +3,7 @@ import { ImageEmbeddingsProps, ImageEmbeddingsType, } from '../../types/imageEmbeddings'; +import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; /** @@ -16,17 +17,30 @@ export const useImageEmbeddings = ({ model, preventLoad = false, }: ImageEmbeddingsProps): ImageEmbeddingsType => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - ImageEmbeddingsModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + ImageEmbeddingsModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); - const forward = (imageSource: string) => + const forward = (imageSource: string | PixelData) => runForward((inst) => inst.forward(imageSource)); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts 
b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts index 61d5c48c90..45154ef996 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts @@ -14,11 +14,6 @@ import { VisionModule } from './VisionModule'; export class ClassificationModule extends VisionModule<{ [category: string]: number; }> { - private constructor(nativeModule: unknown) { - super(); - this.nativeModule = nativeModule; - } - /** * Creates a classification instance for a built-in model. * @@ -46,9 +41,9 @@ export class ClassificationModule extends VisionModule<{ ); } - return new ClassificationModule( - await global.loadClassification(paths[0]) - ); + const instance = new ClassificationModule(); + instance.nativeModule = await global.loadClassification(paths[0]); + return instance; } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts index c4cd57b889..e021182438 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts @@ -12,10 +12,6 @@ import { VisionModule } from './VisionModule'; * @category Typescript API */ export class ImageEmbeddingsModule extends VisionModule { - private constructor(nativeModule: unknown) { - super(); - this.nativeModule = nativeModule; - } /** * Creates an image embeddings instance for a built-in model. 
* @@ -43,9 +39,9 @@ export class ImageEmbeddingsModule extends VisionModule { ); } - return new ImageEmbeddingsModule( - await global.loadImageEmbeddings(paths[0]) - ); + const instance = new ImageEmbeddingsModule(); + instance.nativeModule = await global.loadImageEmbeddings(paths[0]); + return instance; } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 8dbc2ec8e6..9df1cbf31b 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -12,10 +12,6 @@ import { VisionModule } from './VisionModule'; * @category Typescript API */ export class StyleTransferModule extends VisionModule { - private constructor(nativeModule: unknown) { - super(); - this.nativeModule = nativeModule; - } /** * Creates a style transfer instance for a built-in model. 
* @@ -43,7 +39,9 @@ export class StyleTransferModule extends VisionModule { ); } - return new StyleTransferModule(await global.loadStyleTransfer(paths[0])); + const instance = new StyleTransferModule(); + instance.nativeModule = await global.loadStyleTransfer(paths[0]); + return instance; } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); From 61eaf793919c1e9b84a329dc7ea3fd6328f8c798 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 15:12:25 +0100 Subject: [PATCH 60/71] refactor: batch 1 suggestions --- .../rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp | 3 --- .../rnexecutorch/tests/integration/StyleTransferTest.cpp | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index f742f13f65..d2914469af 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -1,7 +1,4 @@ #include "ImageEmbeddings.h" - -#include - #include #include #include diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index c18131002a..9ff30dbe86 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -191,8 +191,8 @@ TEST(StyleTransferPixelTests, OutputDimensionsMatchInputSize) { // ============================================================================ TEST(StyleTransferThreadSafetyTests, TwoConcurrentGeneratesDoNotCrash) { StyleTransfer 
model(kValidStyleTransferModelPath, nullptr); - std::atomic successCount{0}; - std::atomic exceptionCount{0}; + std::atomic successCount{0}; + std::atomic exceptionCount{0}; auto task = [&]() { try { From 090e95b0594f75462d236abf323137f3a60c99a1 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 15:16:25 +0100 Subject: [PATCH 61/71] tests: apply merging suggestions --- .../tests/integration/StyleTransferTest.cpp | 19 +++------------- .../tests/unit/FrameProcessorTest.cpp | 22 ++++++------------- 2 files changed, 10 insertions(+), 31 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index 9ff30dbe86..4a2519f97f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -93,13 +93,7 @@ TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { // ============================================================================ // generateFromString saveToFile tests // ============================================================================ -TEST(StyleTransferSaveToFileTests, SaveToFileFalseReturnsPixelDataVariant) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generateFromString(kValidTestImagePath, false); - EXPECT_TRUE(std::holds_alternative(result)); -} - -TEST(StyleTransferSaveToFileTests, SaveToFileFalsePixelDataIsNonNull) { +TEST(StyleTransferSaveToFileTests, SaveToFileFalseReturnsValidPixelData) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); auto result = model.generateFromString(kValidTestImagePath, false); ASSERT_TRUE(std::holds_alternative(result)); @@ -158,15 +152,8 @@ TEST(StyleTransferPixelTests, ValidPixelsSaveToFileFalseHasPositiveDimensions) { EXPECT_GT(pr.height, 0); }
-TEST(StyleTransferPixelTests, ValidPixelsSaveToFileTrueReturnsString) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - std::vector buf; - auto view = makeRgbView(buf, 64, 64); - auto result = model.generateFromPixels(view, true); - EXPECT_TRUE(std::holds_alternative(result)); -} - -TEST(StyleTransferPixelTests, ValidPixelsSaveToFileTrueHasFileScheme) { +TEST(StyleTransferPixelTests, + ValidPixelsSaveToFileTrueReturnsFileSchemeString) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp index 6465db2a75..cfea1eb2a4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp @@ -18,28 +18,20 @@ static JSTensorViewIn makeValidView(std::vector &buf, int32_t h, // ============================================================================ // Valid input // ============================================================================ -TEST(PixelsToMatValidInput, ProducesCorrectRows) { +TEST(PixelsToMatValidInput, ProducesCorrectDimensions) { std::vector buf; auto view = makeValidView(buf, 48, 64); - EXPECT_EQ(pixelsToMat(view).rows, 48); -} - -TEST(PixelsToMatValidInput, ProducesCorrectCols) { - std::vector buf; - auto view = makeValidView(buf, 48, 64); - EXPECT_EQ(pixelsToMat(view).cols, 64); -} - -TEST(PixelsToMatValidInput, ProducesThreeChannelMat) { - std::vector buf; - auto view = makeValidView(buf, 4, 4); - EXPECT_EQ(pixelsToMat(view).channels(), 3); + auto mat = pixelsToMat(view); + EXPECT_EQ(mat.rows, 48); + EXPECT_EQ(mat.cols, 64); } TEST(PixelsToMatValidInput, MatTypeIsCV_8UC3) { std::vector buf; auto view = makeValidView(buf, 4, 4); - 
EXPECT_EQ(pixelsToMat(view).type(), CV_8UC3); + auto mat = pixelsToMat(view); + EXPECT_EQ(mat.channels(), 3); + EXPECT_EQ(mat.type(), CV_8UC3); } TEST(PixelsToMatValidInput, MatWrapsOriginalData) { From 3a5f564f1b819f5b9501bc036d1c58ecd9f52b61 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 13 Mar 2026 18:39:54 +0100 Subject: [PATCH 62/71] fix: move the vision camera components so they are not treated as part of navigation --- apps/computer-vision/app/vision_camera/index.tsx | 6 +++--- .../vision_camera/tasks/ClassificationTask.tsx | 0 .../vision_camera/tasks/ObjectDetectionTask.tsx | 0 .../vision_camera/tasks/SegmentationTask.tsx | 0 .../{app => components}/vision_camera/tasks/types.ts | 0 .../{app => components}/vision_camera/utils/colors.ts | 0 6 files changed, 3 insertions(+), 3 deletions(-) rename apps/computer-vision/{app => components}/vision_camera/tasks/ClassificationTask.tsx (100%) rename apps/computer-vision/{app => components}/vision_camera/tasks/ObjectDetectionTask.tsx (100%) rename apps/computer-vision/{app => components}/vision_camera/tasks/SegmentationTask.tsx (100%) rename apps/computer-vision/{app => components}/vision_camera/tasks/types.ts (100%) rename apps/computer-vision/{app => components}/vision_camera/utils/colors.ts (100%) diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 20254f1547..f6c1804ac7 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -28,9 +28,9 @@ import Svg, { Path, Polygon } from 'react-native-svg'; import { GeneratingContext } from '../../context'; import Spinner from '../../components/Spinner'; import ColorPalette from '../../colors'; -import ClassificationTask from './tasks/ClassificationTask'; -import ObjectDetectionTask from './tasks/ObjectDetectionTask'; -import SegmentationTask from './tasks/SegmentationTask'; +import ClassificationTask from 
'../../components/vision_camera/tasks/ClassificationTask'; +import ObjectDetectionTask from '../../components/vision_camera/tasks/ObjectDetectionTask'; +import SegmentationTask from '../../components/vision_camera/tasks/SegmentationTask'; type TaskId = 'classification' | 'objectDetection' | 'segmentation'; type ModelId = diff --git a/apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx similarity index 100% rename from apps/computer-vision/app/vision_camera/tasks/ClassificationTask.tsx rename to apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx diff --git a/apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx similarity index 100% rename from apps/computer-vision/app/vision_camera/tasks/ObjectDetectionTask.tsx rename to apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx diff --git a/apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx similarity index 100% rename from apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx rename to apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx diff --git a/apps/computer-vision/app/vision_camera/tasks/types.ts b/apps/computer-vision/components/vision_camera/tasks/types.ts similarity index 100% rename from apps/computer-vision/app/vision_camera/tasks/types.ts rename to apps/computer-vision/components/vision_camera/tasks/types.ts diff --git a/apps/computer-vision/app/vision_camera/utils/colors.ts b/apps/computer-vision/components/vision_camera/utils/colors.ts similarity index 100% rename from apps/computer-vision/app/vision_camera/utils/colors.ts rename to apps/computer-vision/components/vision_camera/utils/colors.ts From ac7554162018e924ab4eb62ebcfcc1fc5e0d6425 Mon Sep 17 
00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 10:36:59 +0100 Subject: [PATCH 63/71] fix: style transfer crashes app, rename `output` -> `outputType` --- .../02-computer-vision/useStyleTransfer.md | 4 +-- .../02-computer-vision/StyleTransferModule.md | 2 +- .../hooks/computer_vision/useStyleTransfer.ts | 29 ++++++++++++------- .../computer_vision/StyleTransferModule.ts | 4 +-- .../src/types/styleTransfer.ts | 6 ++-- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md index 636589d7c7..43cc9c2c25 100644 --- a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md +++ b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md @@ -57,12 +57,12 @@ You need more details? Check the following resources: To run the model, use the [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts two arguments: - `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). -- `output` (optional) — Controls the return format: +- `outputType` (optional) — Controls the return format: - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written. - `'url'` — Saves the result to a temp file and returns its URI as a `string`. :::info -When `output` is `'url'`, the generated image is stored in your application's temporary directory. +When `outputType` is `'url'`, the generated image is stored in your application's temporary directory. 
::: ## Example diff --git a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md index 7475398971..4c57716001 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md @@ -50,7 +50,7 @@ For more information on loading resources, take a look at [loading models](../.. To run the model, use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method. It accepts two arguments: - `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). -- `output` (optional) — Controls the return format: +- `outputType` (optional) — Controls the return format: - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written. - `'url'` — Saves the result to a temp file and returns its URI as a `string`. 
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts index 6ebd1ccd6c..dfa9095cc7 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts @@ -17,17 +17,25 @@ export const useStyleTransfer = ({ model, preventLoad = false, }: StyleTransferProps): StyleTransferType => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - StyleTransferModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + StyleTransferModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); - const forward = (imageSource: string | PixelData) => - runForward((inst) => inst.forward(imageSource)); + const forward = ( + imageSource: string | PixelData, + outputType?: O + ) => runForward((inst) => inst.forward(imageSource, outputType)); return { error, @@ -35,5 +43,6 @@ export const useStyleTransfer = ({ isGenerating, downloadProgress, forward, + runOnFrame, } as StyleTransferType; }; diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 9df1cbf31b..a7c9b38fcf 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -71,9 +71,9 @@ export class StyleTransferModule extends VisionModule { async forward( input: string | PixelData, - output?: O 
+ outputType?: O ): Promise { - return super.forward(input, output === 'url') as Promise< + return super.forward(input, outputType === 'url') as Promise< O extends 'url' ? string : PixelData >; } diff --git a/packages/react-native-executorch/src/types/styleTransfer.ts b/packages/react-native-executorch/src/types/styleTransfer.ts index f14412e4a8..a325f94d25 100644 --- a/packages/react-native-executorch/src/types/styleTransfer.ts +++ b/packages/react-native-executorch/src/types/styleTransfer.ts @@ -67,13 +67,13 @@ export interface StyleTransferType { * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. * * @param input - Image source (string or PixelData object) - * @param output - Output format: `'pixelData'` (default) returns raw RGBA pixel data; `'url'` saves the result to a temp file and returns its `file://` path. - * @returns A Promise resolving to `PixelData` when `output` is `'pixelData'` (default), or a `file://` URL string when `output` is `'url'`. + * @param outputType - Output format: `'pixelData'` (default) returns raw RGBA pixel data; `'url'` saves the result to a temp file and returns its `file://` path. + * @returns A Promise resolving to `PixelData` when `outputType` is `'pixelData'` (default), or a `file://` URL string when `outputType` is `'url'`. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. 
*/ forward( input: string | PixelData, - output?: O + outputType?: O ): Promise; /** From c075d8bb84b6612fb834be76a5be13d861249ff9 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 10:42:42 +0100 Subject: [PATCH 64/71] refactor: use ScalarType enum instead of magic number in jsi conversions --- .../common/rnexecutorch/host_objects/JsiConversions.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 4f55c9bcfb..dc715f0f2a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -578,7 +578,8 @@ getJsiValue(const models::style_transfer::PixelDataResult &result, sizesArray.setValueAtIndex(runtime, 2, jsi::Value(4)); obj.setProperty(runtime, "sizes", sizesArray); - obj.setProperty(runtime, "scalarType", jsi::Value(0)); + obj.setProperty(runtime, "scalarType", + jsi::Value(static_cast(ScalarType::Byte))); return obj; } From 33721b46e9a2690b0c5a5b6a03cab22accc8e6eb Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 10:44:37 +0100 Subject: [PATCH 65/71] docs: remove a comment about vision camera integration --- docs/docs/03-hooks/02-computer-vision/useClassification.md | 2 -- docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md | 2 -- docs/docs/03-hooks/02-computer-vision/useOCR.md | 2 -- docs/docs/03-hooks/02-computer-vision/useObjectDetection.md | 2 -- .../docs/03-hooks/02-computer-vision/useSemanticSegmentation.md | 2 -- docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md | 2 -- docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md | 2 -- 7 files changed, 14 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/useClassification.md
b/docs/docs/03-hooks/02-computer-vision/useClassification.md index d627d95364..e88cce1aff 100644 --- a/docs/docs/03-hooks/02-computer-vision/useClassification.md +++ b/docs/docs/03-hooks/02-computer-vision/useClassification.md @@ -60,8 +60,6 @@ Images from external sources are stored in your application's temporary director ## VisionCamera integration -For real-time classification on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `{ [category: string]: number }`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Example diff --git a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md index 8ce07e0ba4..a6ea5fa982 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md @@ -67,8 +67,6 @@ To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageEmb ## VisionCamera integration -For real-time embedding on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `Float32Array`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Example diff --git a/docs/docs/03-hooks/02-computer-vision/useOCR.md b/docs/docs/03-hooks/02-computer-vision/useOCR.md index f622414651..41491c7143 100644 --- a/docs/docs/03-hooks/02-computer-vision/useOCR.md +++ b/docs/docs/03-hooks/02-computer-vision/useOCR.md @@ -54,8 +54,6 @@ To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType. ## VisionCamera integration -For real-time text recognition on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `OCRDetection[]`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). 
## Detection object diff --git a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md index a910645d32..5fb2b2bb3a 100644 --- a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md +++ b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md @@ -109,8 +109,6 @@ function App() { ## VisionCamera integration -For real-time object detection on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `Detection[]`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Supported models diff --git a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md index 1ca0987361..dc654369c7 100644 --- a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md @@ -117,8 +117,6 @@ function App() { ## VisionCamera integration -For real-time segmentation on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns the same segmentation result object as `forward`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Supported models diff --git a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md index 43cc9c2c25..d08d7e8688 100644 --- a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md +++ b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md @@ -100,8 +100,6 @@ function App() { ## VisionCamera integration -For real-time style transfer on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and always returns `PixelData`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). 
## Supported models diff --git a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md index 53d0e8b7ff..80b142ac62 100644 --- a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md +++ b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md @@ -62,8 +62,6 @@ To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType. ## VisionCamera integration -For real-time text recognition on camera frames, use `runOnFrame`. It runs synchronously on the JS worklet thread and returns `OCRDetection[]`. - See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Detection object From 81b490a87abb07a086408a9460807232f8d0494d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 10:49:18 +0100 Subject: [PATCH 66/71] chore: use camelCase ids for model in vision camera demo --- .eslintrc.js | 2 +- .../app/vision_camera/index.tsx | 52 +++++++++---------- .../tasks/ObjectDetectionTask.tsx | 8 +-- .../vision_camera/tasks/SegmentationTask.tsx | 42 +++++++-------- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/.eslintrc.js b/.eslintrc.js index 26f3b92475..a9613d48ed 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -32,7 +32,7 @@ module.exports = { customWordListFile: path.resolve(__dirname, '.cspell-wordlist.txt'), }, ], - 'camelcase': ['error', { properties: 'never' }], + 'camelcase': 'error', }, plugins: ['prettier', 'markdown'], overrides: [ diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index f6c1804ac7..b2af60d504 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -35,15 +35,15 @@ import SegmentationTask from '../../components/vision_camera/tasks/SegmentationT type TaskId = 'classification' | 'objectDetection' | 'segmentation'; type ModelId = | 'classification' - | 'objectDetection_ssdlite' - | 
'objectDetection_rfdetr' - | 'segmentation_deeplab_resnet50' - | 'segmentation_deeplab_resnet101' - | 'segmentation_deeplab_mobilenet' - | 'segmentation_lraspp' - | 'segmentation_fcn_resnet50' - | 'segmentation_fcn_resnet101' - | 'segmentation_selfie'; + | 'objectDetectionSsdlite' + | 'objectDetectionRfdetr' + | 'segmentationDeeplabResnet50' + | 'segmentationDeeplabResnet101' + | 'segmentationDeeplabMobilenet' + | 'segmentationLraspp' + | 'segmentationFcnResnet50' + | 'segmentationFcnResnet101' + | 'segmentationSelfie'; type TaskVariant = { id: ModelId; label: string }; type Task = { id: TaskId; label: string; variants: TaskVariant[] }; @@ -58,21 +58,21 @@ const TASKS: Task[] = [ id: 'segmentation', label: 'Segment', variants: [ - { id: 'segmentation_deeplab_resnet50', label: 'DeepLab ResNet50' }, - { id: 'segmentation_deeplab_resnet101', label: 'DeepLab ResNet101' }, - { id: 'segmentation_deeplab_mobilenet', label: 'DeepLab MobileNet' }, - { id: 'segmentation_lraspp', label: 'LRASPP MobileNet' }, - { id: 'segmentation_fcn_resnet50', label: 'FCN ResNet50' }, - { id: 'segmentation_fcn_resnet101', label: 'FCN ResNet101' }, - { id: 'segmentation_selfie', label: 'Selfie' }, + { id: 'segmentationDeeplabResnet50', label: 'DeepLab ResNet50' }, + { id: 'segmentationDeeplabResnet101', label: 'DeepLab ResNet101' }, + { id: 'segmentationDeeplabMobilenet', label: 'DeepLab MobileNet' }, + { id: 'segmentationLraspp', label: 'LRASPP MobileNet' }, + { id: 'segmentationFcnResnet50', label: 'FCN ResNet50' }, + { id: 'segmentationFcnResnet101', label: 'FCN ResNet101' }, + { id: 'segmentationSelfie', label: 'Selfie' }, ], }, { id: 'objectDetection', label: 'Detect', variants: [ - { id: 'objectDetection_ssdlite', label: 'SSDLite MobileNet' }, - { id: 'objectDetection_rfdetr', label: 'RF-DETR Nano' }, + { id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' }, + { id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' }, ], }, ]; @@ -201,7 +201,7 @@ export default function 
VisionCameraScreen() { )} @@ -210,13 +210,13 @@ export default function VisionCameraScreen() { {...taskProps} activeModel={ activeModel as - | 'segmentation_deeplab_resnet50' - | 'segmentation_deeplab_resnet101' - | 'segmentation_deeplab_mobilenet' - | 'segmentation_lraspp' - | 'segmentation_fcn_resnet50' - | 'segmentation_fcn_resnet101' - | 'segmentation_selfie' + | 'segmentationDeeplabResnet50' + | 'segmentationDeeplabResnet101' + | 'segmentationDeeplabMobilenet' + | 'segmentationLraspp' + | 'segmentationFcnResnet50' + | 'segmentationFcnResnet101' + | 'segmentationSelfie' } /> )} diff --git a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx index de9f77edb5..a54d20c87e 100644 --- a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx +++ b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx @@ -11,7 +11,7 @@ import { import { labelColor, labelColorBg } from '../utils/colors'; import { TaskProps } from './types'; -type ObjModelId = 'objectDetection_ssdlite' | 'objectDetection_rfdetr'; +type ObjModelId = 'objectDetectionSsdlite' | 'objectDetectionRfdetr'; type Props = TaskProps & { activeModel: ObjModelId }; @@ -28,14 +28,14 @@ export default function ObjectDetectionTask({ }: Props) { const ssdlite = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE, - preventLoad: activeModel !== 'objectDetection_ssdlite', + preventLoad: activeModel !== 'objectDetectionSsdlite', }); const rfdetr = useObjectDetection({ model: RF_DETR_NANO, - preventLoad: activeModel !== 'objectDetection_rfdetr', + preventLoad: activeModel !== 'objectDetectionRfdetr', }); - const active = activeModel === 'objectDetection_ssdlite' ? ssdlite : rfdetr; + const active = activeModel === 'objectDetectionSsdlite' ? 
ssdlite : rfdetr; const [detections, setDetections] = useState([]); const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); diff --git a/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx index 3064e309bf..8226b0aae9 100644 --- a/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx +++ b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx @@ -24,13 +24,13 @@ import { CLASS_COLORS } from '../utils/colors'; import { TaskProps } from './types'; type SegModelId = - | 'segmentation_deeplab_resnet50' - | 'segmentation_deeplab_resnet101' - | 'segmentation_deeplab_mobilenet' - | 'segmentation_lraspp' - | 'segmentation_fcn_resnet50' - | 'segmentation_fcn_resnet101' - | 'segmentation_selfie'; + | 'segmentationDeeplabResnet50' + | 'segmentationDeeplabResnet101' + | 'segmentationDeeplabMobilenet' + | 'segmentationLraspp' + | 'segmentationFcnResnet50' + | 'segmentationFcnResnet101' + | 'segmentationSelfie'; type Props = TaskProps & { activeModel: SegModelId }; @@ -47,41 +47,41 @@ export default function SegmentationTask({ }: Props) { const segDeeplabResnet50 = useSemanticSegmentation({ model: DEEPLAB_V3_RESNET50_QUANTIZED, - preventLoad: activeModel !== 'segmentation_deeplab_resnet50', + preventLoad: activeModel !== 'segmentationDeeplabResnet50', }); const segDeeplabResnet101 = useSemanticSegmentation({ model: DEEPLAB_V3_RESNET101_QUANTIZED, - preventLoad: activeModel !== 'segmentation_deeplab_resnet101', + preventLoad: activeModel !== 'segmentationDeeplabResnet101', }); const segDeeplabMobilenet = useSemanticSegmentation({ model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, - preventLoad: activeModel !== 'segmentation_deeplab_mobilenet', + preventLoad: activeModel !== 'segmentationDeeplabMobilenet', }); const segLraspp = useSemanticSegmentation({ model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED, - preventLoad: activeModel !== 
'segmentation_lraspp', + preventLoad: activeModel !== 'segmentationLraspp', }); const segFcnResnet50 = useSemanticSegmentation({ model: FCN_RESNET50_QUANTIZED, - preventLoad: activeModel !== 'segmentation_fcn_resnet50', + preventLoad: activeModel !== 'segmentationFcnResnet50', }); const segFcnResnet101 = useSemanticSegmentation({ model: FCN_RESNET101_QUANTIZED, - preventLoad: activeModel !== 'segmentation_fcn_resnet101', + preventLoad: activeModel !== 'segmentationFcnResnet101', }); const segSelfie = useSemanticSegmentation({ model: SELFIE_SEGMENTATION, - preventLoad: activeModel !== 'segmentation_selfie', + preventLoad: activeModel !== 'segmentationSelfie', }); const active = { - segmentation_deeplab_resnet50: segDeeplabResnet50, - segmentation_deeplab_resnet101: segDeeplabResnet101, - segmentation_deeplab_mobilenet: segDeeplabMobilenet, - segmentation_lraspp: segLraspp, - segmentation_fcn_resnet50: segFcnResnet50, - segmentation_fcn_resnet101: segFcnResnet101, - segmentation_selfie: segSelfie, + segmentationDeeplabResnet50: segDeeplabResnet50, + segmentationDeeplabResnet101: segDeeplabResnet101, + segmentationDeeplabMobilenet: segDeeplabMobilenet, + segmentationLraspp: segLraspp, + segmentationFcnResnet50: segFcnResnet50, + segmentationFcnResnet101: segFcnResnet101, + segmentationSelfie: segSelfie, }[activeModel]; const [maskImage, setMaskImage] = useState(null); From 648a045b340f13407c57ca6efdcbaaafc594b248 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 13:42:03 +0100 Subject: [PATCH 67/71] Update packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp Co-authored-by: Jakub Chmura <92989966+chmjkb@users.noreply.github.com> --- .../common/rnexecutorch/utils/FrameProcessor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp index 
93f645b008..19df5ba34e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp @@ -69,7 +69,7 @@ cv::Mat pixelsToMat(const JSTensorViewIn &pixelData) { "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); } - uint8_t *dataPtr = static_cast(pixelData.dataPtr); + auto *dataPtr = static_cast(pixelData.dataPtr); return cv::Mat(height, width, CV_8UC3, dataPtr); } From 0a0c664d6979da553a1678bda7a0cfcac930deb2 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 13:43:15 +0100 Subject: [PATCH 68/71] feat: requested changes --- .../src/controllers/BaseOCRController.ts | 7 +++++-- .../computer_vision/ClassificationModule.ts | 10 +++++++--- .../computer_vision/ImageEmbeddingsModule.ts | 10 +++++++--- .../computer_vision/StyleTransferModule.ts | 8 +++++--- .../computer_vision/VisionLabeledModule.ts | 18 ++---------------- .../modules/computer_vision/VisionModule.ts | 11 ++++++----- .../src/types/classification.ts | 14 -------------- .../src/types/computerVision.ts | 17 +++++++++++++++++ .../src/types/imageEmbeddings.ts | 2 +- 9 files changed, 50 insertions(+), 47 deletions(-) diff --git a/packages/react-native-executorch/src/controllers/BaseOCRController.ts b/packages/react-native-executorch/src/controllers/BaseOCRController.ts index e16fd82477..5ef5f935bc 100644 --- a/packages/react-native-executorch/src/controllers/BaseOCRController.ts +++ b/packages/react-native-executorch/src/controllers/BaseOCRController.ts @@ -89,8 +89,11 @@ export abstract class BaseOCRController { }; get runOnFrame(): ((frame: Frame) => OCRDetection[]) | null { - if (!this.nativeModule?.generateFromFrame) { - return null; + if (!this.isReady) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling runOnFrame().' 
+ ); } const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts index 45154ef996..d9ef0d7f73 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts @@ -14,6 +14,10 @@ import { VisionModule } from './VisionModule'; export class ClassificationModule extends VisionModule<{ [category: string]: number; }> { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } /** * Creates a classification instance for a built-in model. * @@ -41,9 +45,9 @@ export class ClassificationModule extends VisionModule<{ ); } - const instance = new ClassificationModule(); - instance.nativeModule = await global.loadClassification(paths[0]); - return instance; + return new ClassificationModule( + await global.loadClassification(paths[0]) + ); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts index e021182438..c4cd57b889 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts @@ -12,6 +12,10 @@ import { VisionModule } from './VisionModule'; * @category Typescript API */ export class ImageEmbeddingsModule extends VisionModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } /** * Creates an image embeddings instance for a built-in model. 
* @@ -39,9 +43,9 @@ export class ImageEmbeddingsModule extends VisionModule { ); } - const instance = new ImageEmbeddingsModule(); - instance.nativeModule = await global.loadImageEmbeddings(paths[0]); - return instance; + return new ImageEmbeddingsModule( + await global.loadImageEmbeddings(paths[0]) + ); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index a7c9b38fcf..6519a29b91 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -12,6 +12,10 @@ import { VisionModule } from './VisionModule'; * @category Typescript API */ export class StyleTransferModule extends VisionModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } /** * Creates a style transfer instance for a built-in model. 
* @@ -39,9 +43,7 @@ export class StyleTransferModule extends VisionModule { ); } - const instance = new StyleTransferModule(); - instance.nativeModule = await global.loadStyleTransfer(paths[0]); - return instance; + return new StyleTransferModule(await global.loadStyleTransfer(paths[0])); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts index 1cda359db5..188b03c8c9 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts @@ -4,6 +4,8 @@ import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError } from '../../errors/errorUtils'; import { VisionModule } from './VisionModule'; +export { ResolveLabels } from '../../types/computerVision'; + /** * Fetches a model binary and returns its local path, throwing if the download * was interrupted (paused or cancelled). @@ -24,22 +26,6 @@ export async function fetchModelPath( return paths[0]; } -/** - * Given a model configs record (mapping model names to `{ labelMap }`) and a - * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to - * the label map for that model or `T` itself. - * - * @internal - */ -export type ResolveLabels< - T, - Configs extends Record, -> = T extends keyof Configs - ? Configs[T]['labelMap'] - : T extends LabelEnum - ? T - : never; - /** * Base class for computer vision modules that carry a type-safe label map * and support the full VisionModule API (string/PixelData forward + runOnFrame). 
diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts index 31d3baba7b..d2c78edf0d 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts @@ -59,8 +59,11 @@ export abstract class VisionModule extends BaseModule { * ``` */ get runOnFrame(): ((frame: Frame, ...args: any[]) => TOutput) | null { - if (!this.nativeModule?.generateFromFrame) { - return null; + if (!this.nativeModule) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling runOnFrame().' + ); } // Extract pure JSI function reference (runs on JS thread) @@ -73,9 +76,7 @@ export abstract class VisionModule extends BaseModule { let nativeBuffer: any = null; try { nativeBuffer = frame.getNativeBuffer(); - const frameData = { - nativeBuffer: nativeBuffer.pointer, - }; + const frameData = { nativeBuffer: nativeBuffer.pointer }; return nativeGenerateFromFrame(frameData, ...args); } finally { if (nativeBuffer?.release) { diff --git a/packages/react-native-executorch/src/types/classification.ts b/packages/react-native-executorch/src/types/classification.ts index d38c664493..994d72a05c 100644 --- a/packages/react-native-executorch/src/types/classification.ts +++ b/packages/react-native-executorch/src/types/classification.ts @@ -77,20 +77,6 @@ export interface ClassificationType { * * Available after model is loaded (`isReady: true`). 
* - * @example - * ```typescript - * const { runOnFrame, isReady } = useClassification({ model: MODEL }); - * - * const frameOutput = useFrameOutput({ - * onFrame(frame) { - * 'worklet'; - * if (!runOnFrame) return; - * const result = runOnFrame(frame); - * frame.dispose(); - * } - * }); - * ``` - * * @param frame - VisionCamera Frame object * @returns Object mapping class labels to confidence scores. */ diff --git a/packages/react-native-executorch/src/types/computerVision.ts b/packages/react-native-executorch/src/types/computerVision.ts index e69de29bb2..62357f1402 100644 --- a/packages/react-native-executorch/src/types/computerVision.ts +++ b/packages/react-native-executorch/src/types/computerVision.ts @@ -0,0 +1,17 @@ +import { LabelEnum } from './common'; + +/** + * Given a model configs record (mapping model names to `{ labelMap }`) and a + * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to + * the label map for that model or `T` itself. + * + * @internal + */ +export type ResolveLabels< + T, + Configs extends Record, +> = T extends keyof Configs + ? Configs[T]['labelMap'] + : T extends LabelEnum + ? T + : never; diff --git a/packages/react-native-executorch/src/types/imageEmbeddings.ts b/packages/react-native-executorch/src/types/imageEmbeddings.ts index 2963639c26..7130ac5b84 100644 --- a/packages/react-native-executorch/src/types/imageEmbeddings.ts +++ b/packages/react-native-executorch/src/types/imageEmbeddings.ts @@ -60,7 +60,7 @@ export interface ImageEmbeddingsType { * * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. * - * @param input - Image source (string or PixelData object) + * @param input - Image source (string or {@link PixelData} object) * @returns A Promise that resolves to a `Float32Array` containing the generated embedding vector. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. 
*/ From c5bf1fd9fd554d544d3aa5ee2267ec347375bfbc Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 14:01:47 +0100 Subject: [PATCH 69/71] tests: create typed tests for vision models concurrent generates --- .../tests/integration/ClassificationTest.cpp | 3 + .../tests/integration/ImageEmbeddingsTest.cpp | 3 + .../tests/integration/ObjectDetectionTest.cpp | 3 + .../tests/integration/StyleTransferTest.cpp | 58 +----------------- .../tests/integration/VisionModelTests.h | 61 +++++++++++++++++++ 5 files changed, 73 insertions(+), 55 deletions(-) create mode 100644 packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp index 5725778def..d164fcacb0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp @@ -1,4 +1,5 @@ #include "BaseModelTests.h" +#include "VisionModelTests.h" #include #include #include @@ -38,6 +39,8 @@ template <> struct ModelTraits { using ClassificationTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest, ClassificationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest, + ClassificationTypes); // ============================================================================ // Model-specific tests diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp index 87d37908b1..4982206614 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp @@ -1,4 +1,5 @@ #include "BaseModelTests.h" +#include "VisionModelTests.h" #include #include #include @@ -39,6 +40,8 @@ template <> struct ModelTraits { using ImageEmbeddingsTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, CommonModelTest, ImageEmbeddingsTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, VisionModelTest, + ImageEmbeddingsTypes); // ============================================================================ // Model-specific tests diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp index c983f2fc77..6222f6d682 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp @@ -1,4 +1,5 @@ #include "BaseModelTests.h" +#include "VisionModelTests.h" #include #include #include @@ -57,6 +58,8 @@ template <> struct ModelTraits { using ObjectDetectionTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, CommonModelTest, ObjectDetectionTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest, + ObjectDetectionTypes); // ============================================================================ // Model-specific tests diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index 4a2519f97f..a4511cad11 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -1,11 +1,10 @@ #include "BaseModelTests.h" -#include 
+#include "VisionModelTests.h" #include #include #include #include #include -#include #include using namespace rnexecutorch; @@ -48,6 +47,8 @@ template <> struct ModelTraits { using StyleTransferTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest, StyleTransferTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, VisionModelTest, + StyleTransferTypes); // ============================================================================ // generateFromString tests @@ -79,17 +80,6 @@ TEST(StyleTransferGenerateTests, ValidImageReturnsFilePath) { EXPECT_GT(pr.height, 0); } -TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, false)); - auto result1 = model.generateFromString(kValidTestImagePath, false); - auto result2 = model.generateFromString(kValidTestImagePath, false); - ASSERT_TRUE(std::holds_alternative(result1)); - ASSERT_TRUE(std::holds_alternative(result2)); - EXPECT_NE(std::get(result1).dataPtr, nullptr); - EXPECT_NE(std::get(result2).dataPtr, nullptr); -} - // ============================================================================ // generateFromString saveToFile tests // ============================================================================ @@ -173,48 +163,6 @@ TEST(StyleTransferPixelTests, OutputDimensionsMatchInputSize) { EXPECT_EQ(pr.height, 64); } -// ============================================================================ -// Thread safety tests -// ============================================================================ -TEST(StyleTransferThreadSafetyTests, TwoConcurrentGeneratesDoNotCrash) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - std::atomic successCount{0}; - std::atomic exceptionCount{0}; - - auto task = [&]() { - try { - (void)model.generateFromString(kValidTestImagePath, false); - successCount++; - } catch (const RnExecutorchError &) { 
- exceptionCount++; - } - }; - - std::thread a(task); - std::thread b(task); - a.join(); - b.join(); - - EXPECT_EQ(successCount + exceptionCount, 2); -} - -TEST(StyleTransferThreadSafetyTests, - GenerateAndUnloadConcurrentlyDoesNotCrash) { - StyleTransfer model(kValidStyleTransferModelPath, nullptr); - - std::thread a([&]() { - try { - (void)model.generateFromString(kValidTestImagePath, false); - } catch (const RnExecutorchError &) { - } - }); - std::thread b([&]() { model.unload(); }); - - a.join(); - b.join(); - // If we reach here without crashing, the mutex serialized correctly. -} - // ============================================================================ // Inherited BaseModel tests // ============================================================================ diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h new file mode 100644 index 0000000000..04e407787d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h @@ -0,0 +1,61 @@ +#pragma once + +#include "BaseModelTests.h" +#include +#include +#include +#include + +namespace model_tests { + +template class VisionModelTest : public ::testing::Test { +protected: + using Traits = ModelTraits; + using ModelType = typename Traits::ModelType; +}; + +TYPED_TEST_SUITE_P(VisionModelTest); + +TYPED_TEST_P(VisionModelTest, TwoConcurrentGeneratesDoNotCrash) { + SETUP_TRAITS(); + auto model = Traits::createValid(); + std::atomic successCount{0}; + std::atomic exceptionCount{0}; + + auto task = [&]() { + try { + Traits::callGenerate(model); + successCount++; + } catch (const rnexecutorch::RnExecutorchError &) { + exceptionCount++; + } + }; + + std::thread a(task); + std::thread b(task); + a.join(); + b.join(); + + EXPECT_EQ(successCount + exceptionCount, 2); +} + +TYPED_TEST_P(VisionModelTest, 
GenerateAndUnloadConcurrentlyDoesNotCrash) { + SETUP_TRAITS(); + auto model = Traits::createValid(); + + std::thread a([&]() { + try { + Traits::callGenerate(model); + } catch (const rnexecutorch::RnExecutorchError &) { + } + }); + std::thread b([&]() { model.unload(); }); + + a.join(); + b.join(); +} + +REGISTER_TYPED_TEST_SUITE_P(VisionModelTest, TwoConcurrentGeneratesDoNotCrash, + GenerateAndUnloadConcurrentlyDoesNotCrash); + +} // namespace model_tests From 836c7a27bf212355d0d7265b59cdbaedfb9c3446 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 16:48:49 +0100 Subject: [PATCH 70/71] refactor: follow declaration order in VisionModel class --- .../rnexecutorch/models/VisionModel.cpp | 17 +++++++++++ .../common/rnexecutorch/models/VisionModel.h | 30 +++++-------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index 89727361ca..0b6acbc383 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -7,6 +7,23 @@ namespace rnexecutorch::models { using namespace facebook; +VisionModel::VisionModel(const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) {} + +void VisionModel::unload() noexcept { + std::scoped_lock lock(inference_mutex_); + BaseModel::unload(); +} + +cv::Size VisionModel::modelInputSize() const { + if (modelInputShape_.size() < 2) { + return {0, 0}; + } + return cv::Size(modelInputShape_[modelInputShape_.size() - 1], + modelInputShape_[modelInputShape_.size() - 2]); +} + cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const { cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); diff --git 
a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 772ed40e61..6f9a9532f4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -41,15 +41,8 @@ namespace models { */ class VisionModel : public BaseModel { public: - /** - * @brief Construct a VisionModel with the same parameters as BaseModel - * - * VisionModel uses the same construction pattern as BaseModel, just adding - * thread-safety on top. - */ VisionModel(const std::string &modelSource, - std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) {} + std::shared_ptr callInvoker); virtual ~VisionModel() = default; @@ -62,12 +55,13 @@ class VisionModel : public BaseModel { * destroys module_ while generateFromFrame() is still executing on the * VisionCamera worklet thread. */ - void unload() noexcept { - std::scoped_lock lock(inference_mutex_); - BaseModel::unload(); - } + void unload() noexcept; protected: + /// Cached input tensor shape (getAllInputShapes()[0]). + /// Set once by each subclass constructor to avoid per-frame metadata lookups. + std::vector modelInputShape_; + /** * @brief Mutex to ensure thread-safe inference * @@ -96,18 +90,8 @@ class VisionModel : public BaseModel { */ virtual cv::Mat preprocess(const cv::Mat &image) const; - /// Cached input tensor shape (getAllInputShapes()[0]). - /// Set once by each subclass constructor to avoid per-frame metadata lookups. - std::vector modelInputShape_; - /// Convenience accessor: spatial dimensions of the model input. 
- cv::Size modelInputSize() const { - if (modelInputShape_.size() < 2) { - return {0, 0}; - } - return cv::Size(modelInputShape_[modelInputShape_.size() - 1], - modelInputShape_[modelInputShape_.size() - 2]); - } + cv::Size modelInputSize() const; /** * @brief Extract an RGB cv::Mat from a VisionCamera frame From 3ea1a50f5a1294b0b19f4149015651ae56556a0f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 16 Mar 2026 18:26:07 +0100 Subject: [PATCH 71/71] fix: replace magic number 4 with channels field in PixelDataResult --- .../common/rnexecutorch/host_objects/JsiConversions.h | 2 +- .../common/rnexecutorch/models/style_transfer/StyleTransfer.cpp | 2 +- .../common/rnexecutorch/models/style_transfer/Types.h | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index dc715f0f2a..a4e373c2b8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -575,7 +575,7 @@ getJsiValue(const models::style_transfer::PixelDataResult &result, auto sizesArray = jsi::Array(runtime, 3); sizesArray.setValueAtIndex(runtime, 0, jsi::Value(result.height)); sizesArray.setValueAtIndex(runtime, 1, jsi::Value(result.width)); - sizesArray.setValueAtIndex(runtime, 2, jsi::Value(4)); + sizesArray.setValueAtIndex(runtime, 2, jsi::Value(result.channels)); obj.setProperty(runtime, "sizes", sizesArray); obj.setProperty(runtime, "scalarType", diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 4aaa774cc4..ad94e76ba4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -61,7 +61,7 @@ PixelDataResult toPixelDataResult(const cv::Mat &bgrMat) { cv::cvtColor(bgrMat, rgba, cv::COLOR_BGR2RGBA); std::size_t dataSize = static_cast(size.width) * size.height * 4; auto pixelBuffer = std::make_shared(rgba.data, dataSize); - return PixelDataResult{pixelBuffer, size.width, size.height}; + return PixelDataResult{pixelBuffer, size.width, size.height, rgba.channels()}; } StyleTransferResult StyleTransfer::generateFromString(std::string imageSource, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h index 57e69eb730..27df4ec6c6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h @@ -11,6 +11,7 @@ struct PixelDataResult { std::shared_ptr dataPtr; int width; int height; + int channels; }; using StyleTransferResult = std::variant;