diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx
index 4ce2f3e5c2..cac2692f07 100644
--- a/apps/computer-vision/app/_layout.tsx
+++ b/apps/computer-vision/app/_layout.tsx
@@ -59,6 +59,15 @@ export default function _layout() {
headerTitleStyle: { color: ColorPalette.primary },
}}
>
+
-
Select a demo model
+ router.navigate('vision_camera/')}
+ >
+ Vision Camera
+
router.navigate('classification/')}
@@ -29,12 +35,6 @@ export default function Home() {
>
Object Detection
- router.navigate('object_detection_live/')}
- >
- Object Detection Live
-
router.navigate('ocr/')}
diff --git a/apps/computer-vision/app/object_detection_live/index.tsx b/apps/computer-vision/app/object_detection_live/index.tsx
deleted file mode 100644
index 3db2c53602..0000000000
--- a/apps/computer-vision/app/object_detection_live/index.tsx
+++ /dev/null
@@ -1,222 +0,0 @@
-import React, {
- useCallback,
- useContext,
- useEffect,
- useMemo,
- useRef,
- useState,
-} from 'react';
-import {
- StatusBar,
- StyleSheet,
- Text,
- TouchableOpacity,
- View,
-} from 'react-native';
-import { useSafeAreaInsets } from 'react-native-safe-area-context';
-
-import {
- Camera,
- getCameraFormat,
- Templates,
- useCameraDevices,
- useCameraPermission,
- useFrameOutput,
-} from 'react-native-vision-camera';
-import { scheduleOnRN } from 'react-native-worklets';
-import {
- Detection,
- SSDLITE_320_MOBILENET_V3_LARGE,
- useObjectDetection,
-} from 'react-native-executorch';
-import { GeneratingContext } from '../../context';
-import Spinner from '../../components/Spinner';
-import ColorPalette from '../../colors';
-
-export default function ObjectDetectionLiveScreen() {
- const insets = useSafeAreaInsets();
-
- const model = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE });
- const { setGlobalGenerating } = useContext(GeneratingContext);
-
- useEffect(() => {
- setGlobalGenerating(model.isGenerating);
- }, [model.isGenerating, setGlobalGenerating]);
- const [detectionCount, setDetectionCount] = useState(0);
- const [fps, setFps] = useState(0);
- const lastFrameTimeRef = useRef(Date.now());
-
- const cameraPermission = useCameraPermission();
- const devices = useCameraDevices();
- const device = devices.find((d) => d.position === 'back') ?? devices[0];
-
- const format = useMemo(() => {
- if (device == null) return undefined;
- try {
- return getCameraFormat(device, Templates.FrameProcessing);
- } catch {
- return undefined;
- }
- }, [device]);
-
- const updateStats = useCallback((results: Detection[]) => {
- setDetectionCount(results.length);
- const now = Date.now();
- const timeDiff = now - lastFrameTimeRef.current;
- if (timeDiff > 0) {
- setFps(Math.round(1000 / timeDiff));
- }
- lastFrameTimeRef.current = now;
- }, []);
-
- const frameOutput = useFrameOutput({
- pixelFormat: 'rgb',
- dropFramesWhileBusy: true,
- onFrame(frame) {
- 'worklet';
- if (!model.runOnFrame) {
- frame.dispose();
- return;
- }
- try {
- const result = model.runOnFrame(frame, 0.5);
- if (result) {
- scheduleOnRN(updateStats, result);
- }
- } catch {
- // ignore frame errors
- } finally {
- frame.dispose();
- }
- },
- });
-
- if (!model.isReady) {
- return (
-
- );
- }
-
- if (!cameraPermission.hasPermission) {
- return (
-
- Camera access needed
- cameraPermission.requestPermission()}
- style={styles.button}
- >
- Grant Permission
-
-
- );
- }
-
- if (device == null) {
- return (
-
- No camera device found
-
- );
- }
-
- return (
-
-
-
-
-
-
-
-
- {detectionCount}
- objects
-
-
-
- {fps}
- fps
-
-
-
-
- );
-}
-
-const styles = StyleSheet.create({
- container: {
- flex: 1,
- backgroundColor: 'black',
- },
- centered: {
- flex: 1,
- backgroundColor: 'black',
- justifyContent: 'center',
- alignItems: 'center',
- gap: 16,
- },
- message: {
- color: 'white',
- fontSize: 18,
- },
- button: {
- paddingHorizontal: 24,
- paddingVertical: 12,
- backgroundColor: ColorPalette.primary,
- borderRadius: 24,
- },
- buttonText: {
- color: 'white',
- fontSize: 15,
- fontWeight: '600',
- letterSpacing: 0.3,
- },
- bottomBarWrapper: {
- position: 'absolute',
- bottom: 0,
- left: 0,
- right: 0,
- alignItems: 'center',
- },
- bottomBar: {
- flexDirection: 'row',
- alignItems: 'center',
- backgroundColor: 'rgba(0, 0, 0, 0.55)',
- borderRadius: 24,
- paddingHorizontal: 28,
- paddingVertical: 10,
- gap: 24,
- },
- statItem: {
- alignItems: 'center',
- },
- statValue: {
- color: 'white',
- fontSize: 22,
- fontWeight: '700',
- letterSpacing: -0.5,
- },
- statLabel: {
- color: 'rgba(255,255,255,0.55)',
- fontSize: 11,
- fontWeight: '500',
- textTransform: 'uppercase',
- letterSpacing: 0.8,
- },
- statDivider: {
- width: 1,
- height: 32,
- backgroundColor: 'rgba(255,255,255,0.2)',
- },
-});
diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx
index dc6a0d4963..46ae3e814a 100644
--- a/apps/computer-vision/app/style_transfer/index.tsx
+++ b/apps/computer-vision/app/style_transfer/index.tsx
@@ -16,20 +16,24 @@ export default function StyleTransferScreen() {
useEffect(() => {
setGlobalGenerating(model.isGenerating);
}, [model.isGenerating, setGlobalGenerating]);
+
const [imageUri, setImageUri] = useState('');
+ const [styledUri, setStyledUri] = useState('');
+
const handleCameraPress = async (isCamera: boolean) => {
const image = await getImage(isCamera);
const uri = image?.uri;
if (typeof uri === 'string') {
- setImageUri(uri as string);
+ setImageUri(uri);
+ setStyledUri('');
}
};
const runForward = async () => {
if (imageUri) {
try {
- const output = await model.forward(imageUri);
- setImageUri(output);
+ const uri = await model.forward(imageUri, 'url');
+ setStyledUri(uri);
} catch (e) {
console.error(e);
}
@@ -52,9 +56,11 @@ export default function StyleTransferScreen() {
style={styles.image}
resizeMode="contain"
source={
- imageUri
- ? { uri: imageUri }
- : require('../../assets/icons/executorch_logo.png')
+ styledUri
+ ? { uri: styledUri }
+ : imageUri
+ ? { uri: imageUri }
+ : require('../../assets/icons/executorch_logo.png')
}
/>
diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx
new file mode 100644
index 0000000000..b2af60d504
--- /dev/null
+++ b/apps/computer-vision/app/vision_camera/index.tsx
@@ -0,0 +1,442 @@
+import React, {
+ useCallback,
+ useContext,
+ useEffect,
+ useMemo,
+ useState,
+} from 'react';
+import {
+ ScrollView,
+ StatusBar,
+ StyleSheet,
+ Text,
+ TouchableOpacity,
+ View,
+} from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+import { useIsFocused } from '@react-navigation/native';
+import {
+ Camera,
+ getCameraFormat,
+ Templates,
+ useCameraDevices,
+ useCameraPermission,
+ useFrameOutput,
+} from 'react-native-vision-camera';
+import { createSynchronizable } from 'react-native-worklets';
+import Svg, { Path, Polygon } from 'react-native-svg';
+import { GeneratingContext } from '../../context';
+import Spinner from '../../components/Spinner';
+import ColorPalette from '../../colors';
+import ClassificationTask from '../../components/vision_camera/tasks/ClassificationTask';
+import ObjectDetectionTask from '../../components/vision_camera/tasks/ObjectDetectionTask';
+import SegmentationTask from '../../components/vision_camera/tasks/SegmentationTask';
+
+type TaskId = 'classification' | 'objectDetection' | 'segmentation';
+type ModelId =
+ | 'classification'
+ | 'objectDetectionSsdlite'
+ | 'objectDetectionRfdetr'
+ | 'segmentationDeeplabResnet50'
+ | 'segmentationDeeplabResnet101'
+ | 'segmentationDeeplabMobilenet'
+ | 'segmentationLraspp'
+ | 'segmentationFcnResnet50'
+ | 'segmentationFcnResnet101'
+ | 'segmentationSelfie';
+
+type TaskVariant = { id: ModelId; label: string };
+type Task = { id: TaskId; label: string; variants: TaskVariant[] };
+
+const TASKS: Task[] = [
+ {
+ id: 'classification',
+ label: 'Classify',
+ variants: [{ id: 'classification', label: 'EfficientNet V2 S' }],
+ },
+ {
+ id: 'segmentation',
+ label: 'Segment',
+ variants: [
+ { id: 'segmentationDeeplabResnet50', label: 'DeepLab ResNet50' },
+ { id: 'segmentationDeeplabResnet101', label: 'DeepLab ResNet101' },
+ { id: 'segmentationDeeplabMobilenet', label: 'DeepLab MobileNet' },
+ { id: 'segmentationLraspp', label: 'LRASPP MobileNet' },
+ { id: 'segmentationFcnResnet50', label: 'FCN ResNet50' },
+ { id: 'segmentationFcnResnet101', label: 'FCN ResNet101' },
+ { id: 'segmentationSelfie', label: 'Selfie' },
+ ],
+ },
+ {
+ id: 'objectDetection',
+ label: 'Detect',
+ variants: [
+ { id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' },
+ { id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' },
+ ],
+ },
+];
+
+// Module-level const so worklets in task components can always reference the same stable object.
+// Never replaced — only mutated via setBlocking to avoid closure staleness.
+const frameKillSwitch = createSynchronizable(false);
+
+export default function VisionCameraScreen() {
+ const insets = useSafeAreaInsets();
+ const [activeTask, setActiveTask] = useState('classification');
+ const [activeModel, setActiveModel] = useState('classification');
+ const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 });
+ const [cameraPosition, setCameraPosition] = useState<'front' | 'back'>(
+ 'back'
+ );
+ const [fps, setFps] = useState(0);
+ const [frameMs, setFrameMs] = useState(0);
+ const [isReady, setIsReady] = useState(false);
+ const [downloadProgress, setDownloadProgress] = useState(0);
+ const [frameOutput, setFrameOutput] = useState | null>(null);
+ const { setGlobalGenerating } = useContext(GeneratingContext);
+
+ const isFocused = useIsFocused();
+ const cameraPermission = useCameraPermission();
+ const devices = useCameraDevices();
+ const device =
+ devices.find((d) => d.position === cameraPosition) ?? devices[0];
+ const format = useMemo(() => {
+ if (device == null) return undefined;
+ try {
+ return getCameraFormat(device, Templates.FrameProcessing);
+ } catch {
+ return undefined;
+ }
+ }, [device]);
+
+ useEffect(() => {
+ frameKillSwitch.setBlocking(true);
+ const id = setTimeout(() => {
+ frameKillSwitch.setBlocking(false);
+ }, 300);
+ return () => clearTimeout(id);
+ }, [activeModel]);
+
+ const handleFpsChange = useCallback((newFps: number, newMs: number) => {
+ setFps(newFps);
+ setFrameMs(newMs);
+ }, []);
+
+ const handleGeneratingChange = useCallback(
+ (generating: boolean) => {
+ setGlobalGenerating(generating);
+ },
+ [setGlobalGenerating]
+ );
+
+ if (!cameraPermission.hasPermission) {
+ return (
+
+ Camera access needed
+ cameraPermission.requestPermission()}
+ style={styles.button}
+ >
+ Grant Permission
+
+
+ );
+ }
+
+ if (device == null) {
+ return (
+
+ No camera device found
+
+ );
+ }
+
+ const activeTaskInfo = TASKS.find((t) => t.id === activeTask)!;
+ const activeVariantLabel =
+ activeTaskInfo.variants.find((v) => v.id === activeModel)?.label ??
+ activeTaskInfo.variants[0]!.label;
+
+ const taskProps = {
+ activeModel,
+ canvasSize,
+ cameraPosition,
+ frameKillSwitch,
+ onFrameOutputChange: setFrameOutput,
+ onReadyChange: setIsReady,
+ onProgressChange: setDownloadProgress,
+ onGeneratingChange: handleGeneratingChange,
+ onFpsChange: handleFpsChange,
+ };
+
+ return (
+
+
+
+
+
+ {/* Layout sentinel — measures the full-screen area for bbox/canvas sizing */}
+
+ setCanvasSize({
+ width: e.nativeEvent.layout.width,
+ height: e.nativeEvent.layout.height,
+ })
+ }
+ />
+
+ {activeTask === 'classification' && }
+ {activeTask === 'objectDetection' && (
+
+ )}
+ {activeTask === 'segmentation' && (
+
+ )}
+
+ {!isReady && (
+
+
+
+ )}
+
+
+
+ {activeVariantLabel}
+
+ {fps} FPS – {frameMs.toFixed(0)} ms
+
+
+
+
+ {TASKS.map((t) => (
+ {
+ setActiveTask(t.id);
+ setActiveModel(t.variants[0]!.id);
+ }}
+ >
+
+ {t.label}
+
+
+ ))}
+
+
+
+ {activeTaskInfo.variants.map((v) => (
+ setActiveModel(v.id)}
+ >
+
+ {v.label}
+
+
+ ))}
+
+
+
+
+
+ setCameraPosition((p) => (p === 'back' ? 'front' : 'back'))
+ }
+ >
+
+
+
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: { flex: 1, backgroundColor: 'black' },
+ centered: {
+ flex: 1,
+ backgroundColor: 'black',
+ justifyContent: 'center',
+ alignItems: 'center',
+ gap: 16,
+ },
+ message: { color: 'white', fontSize: 18 },
+ button: {
+ paddingHorizontal: 24,
+ paddingVertical: 12,
+ backgroundColor: ColorPalette.primary,
+ borderRadius: 24,
+ },
+ buttonText: { color: 'white', fontSize: 15, fontWeight: '600' },
+ loadingOverlay: {
+ ...StyleSheet.absoluteFillObject,
+ backgroundColor: 'rgba(0,0,0,0.6)',
+ justifyContent: 'center',
+ alignItems: 'center',
+ },
+ topOverlay: {
+ position: 'absolute',
+ top: 0,
+ left: 0,
+ right: 0,
+ alignItems: 'center',
+ gap: 8,
+ },
+ titleRow: {
+ alignItems: 'center',
+ paddingHorizontal: 16,
+ },
+ modelTitle: {
+ color: 'white',
+ fontSize: 22,
+ fontWeight: '700',
+ textShadowColor: 'rgba(0,0,0,0.7)',
+ textShadowOffset: { width: 0, height: 1 },
+ textShadowRadius: 4,
+ },
+ fpsText: {
+ color: 'rgba(255,255,255,0.85)',
+ fontSize: 14,
+ fontWeight: '500',
+ marginTop: 2,
+ textShadowColor: 'rgba(0,0,0,0.7)',
+ textShadowOffset: { width: 0, height: 1 },
+ textShadowRadius: 4,
+ },
+ tabsContent: {
+ paddingHorizontal: 12,
+ gap: 6,
+ },
+ tab: {
+ paddingHorizontal: 18,
+ paddingVertical: 7,
+ borderRadius: 20,
+ backgroundColor: 'rgba(0,0,0,0.45)',
+ borderWidth: 1,
+ borderColor: 'rgba(255,255,255,0.25)',
+ },
+ tabActive: {
+ backgroundColor: 'rgba(255,255,255,0.2)',
+ borderColor: 'white',
+ },
+ tabText: {
+ color: 'rgba(255,255,255,0.7)',
+ fontSize: 14,
+ fontWeight: '600',
+ },
+ tabTextActive: { color: 'white' },
+ chipsContent: {
+ paddingHorizontal: 12,
+ gap: 6,
+ },
+ variantChip: {
+ paddingHorizontal: 14,
+ paddingVertical: 5,
+ borderRadius: 16,
+ backgroundColor: 'rgba(0,0,0,0.35)',
+ borderWidth: 1,
+ borderColor: 'rgba(255,255,255,0.15)',
+ },
+ variantChipActive: {
+ backgroundColor: ColorPalette.primary,
+ borderColor: ColorPalette.primary,
+ },
+ variantChipText: {
+ color: 'rgba(255,255,255,0.6)',
+ fontSize: 12,
+ fontWeight: '500',
+ },
+ variantChipTextActive: { color: 'white' },
+ bottomOverlay: {
+ position: 'absolute',
+ bottom: 0,
+ left: 0,
+ right: 0,
+ alignItems: 'center',
+ },
+ flipButton: {
+ width: 56,
+ height: 56,
+ borderRadius: 28,
+ backgroundColor: 'rgba(255,255,255,0.2)',
+ justifyContent: 'center',
+ alignItems: 'center',
+ borderWidth: 1.5,
+ borderColor: 'rgba(255,255,255,0.4)',
+ },
+});
diff --git a/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx
new file mode 100644
index 0000000000..c9b4a2bf21
--- /dev/null
+++ b/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx
@@ -0,0 +1,120 @@
+import React, { useCallback, useEffect, useRef, useState } from 'react';
+import { StyleSheet, Text, View } from 'react-native';
+import { Frame, useFrameOutput } from 'react-native-vision-camera';
+import { scheduleOnRN } from 'react-native-worklets';
+import { EFFICIENTNET_V2_S, useClassification } from 'react-native-executorch';
+import { TaskProps } from './types';
+
+type Props = Omit;
+
+export default function ClassificationTask({
+ frameKillSwitch,
+ onFrameOutputChange,
+ onReadyChange,
+ onProgressChange,
+ onGeneratingChange,
+ onFpsChange,
+}: Props) {
+ const model = useClassification({ model: EFFICIENTNET_V2_S });
+ const [classResult, setClassResult] = useState({ label: '', score: 0 });
+ const lastFrameTimeRef = useRef(Date.now());
+
+ useEffect(() => {
+ onReadyChange(model.isReady);
+ }, [model.isReady, onReadyChange]);
+
+ useEffect(() => {
+ onProgressChange(model.downloadProgress);
+ }, [model.downloadProgress, onProgressChange]);
+
+ useEffect(() => {
+ onGeneratingChange(model.isGenerating);
+ }, [model.isGenerating, onGeneratingChange]);
+
+ const classRof = model.runOnFrame;
+
+ const updateClass = useCallback(
+ (r: { label: string; score: number }) => {
+ setClassResult(r);
+ const now = Date.now();
+ const diff = now - lastFrameTimeRef.current;
+ if (diff > 0) onFpsChange(Math.round(1000 / diff), diff);
+ lastFrameTimeRef.current = now;
+ },
+ [onFpsChange]
+ );
+
+ const frameOutput = useFrameOutput({
+ pixelFormat: 'rgb',
+ dropFramesWhileBusy: true,
+ onFrame: useCallback(
+ (frame: Frame) => {
+ 'worklet';
+ if (frameKillSwitch.getDirty()) {
+ frame.dispose();
+ return;
+ }
+ try {
+ if (!classRof) return;
+ const result = classRof(frame);
+ if (result) {
+ let bestLabel = '';
+ let bestScore = -1;
+ const entries = Object.entries(result);
+ for (let i = 0; i < entries.length; i++) {
+ const [label, score] = entries[i]!;
+ if ((score as number) > bestScore) {
+ bestScore = score as number;
+ bestLabel = label;
+ }
+ }
+ scheduleOnRN(updateClass, { label: bestLabel, score: bestScore });
+ }
+ } catch {
+ // ignore
+ } finally {
+ frame.dispose();
+ }
+ },
+ [classRof, frameKillSwitch, updateClass]
+ ),
+ });
+
+ useEffect(() => {
+ onFrameOutputChange(frameOutput);
+ }, [frameOutput, onFrameOutputChange]);
+
+ return classResult.label ? (
+
+ {classResult.label}
+ {(classResult.score * 100).toFixed(1)}%
+
+ ) : null;
+}
+
+const styles = StyleSheet.create({
+ overlay: {
+ ...StyleSheet.absoluteFillObject,
+ justifyContent: 'center',
+ alignItems: 'center',
+ },
+ label: {
+ color: 'white',
+ fontSize: 28,
+ fontWeight: '700',
+ textAlign: 'center',
+ textShadowColor: 'rgba(0,0,0,0.8)',
+ textShadowOffset: { width: 0, height: 1 },
+ textShadowRadius: 6,
+ paddingHorizontal: 24,
+ },
+ score: {
+ color: 'rgba(255,255,255,0.75)',
+ fontSize: 18,
+ fontWeight: '500',
+ marginTop: 4,
+ textShadowColor: 'rgba(0,0,0,0.8)',
+ textShadowOffset: { width: 0, height: 1 },
+ textShadowRadius: 6,
+ },
+});
diff --git a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx
new file mode 100644
index 0000000000..a54d20c87e
--- /dev/null
+++ b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx
@@ -0,0 +1,174 @@
+import React, { useCallback, useEffect, useRef, useState } from 'react';
+import { StyleSheet, Text, View } from 'react-native';
+import { Frame, useFrameOutput } from 'react-native-vision-camera';
+import { scheduleOnRN } from 'react-native-worklets';
+import {
+ Detection,
+ RF_DETR_NANO,
+ SSDLITE_320_MOBILENET_V3_LARGE,
+ useObjectDetection,
+} from 'react-native-executorch';
+import { labelColor, labelColorBg } from '../utils/colors';
+import { TaskProps } from './types';
+
+type ObjModelId = 'objectDetectionSsdlite' | 'objectDetectionRfdetr';
+
+type Props = TaskProps & { activeModel: ObjModelId };
+
+export default function ObjectDetectionTask({
+ activeModel,
+ canvasSize,
+ cameraPosition,
+ frameKillSwitch,
+ onFrameOutputChange,
+ onReadyChange,
+ onProgressChange,
+ onGeneratingChange,
+ onFpsChange,
+}: Props) {
+ const ssdlite = useObjectDetection({
+ model: SSDLITE_320_MOBILENET_V3_LARGE,
+ preventLoad: activeModel !== 'objectDetectionSsdlite',
+ });
+ const rfdetr = useObjectDetection({
+ model: RF_DETR_NANO,
+ preventLoad: activeModel !== 'objectDetectionRfdetr',
+ });
+
+ const active = activeModel === 'objectDetectionSsdlite' ? ssdlite : rfdetr;
+
+ const [detections, setDetections] = useState([]);
+ const [imageSize, setImageSize] = useState({ width: 1, height: 1 });
+ const lastFrameTimeRef = useRef(Date.now());
+
+ useEffect(() => {
+ onReadyChange(active.isReady);
+ }, [active.isReady, onReadyChange]);
+
+ useEffect(() => {
+ onProgressChange(active.downloadProgress);
+ }, [active.downloadProgress, onProgressChange]);
+
+ useEffect(() => {
+ onGeneratingChange(active.isGenerating);
+ }, [active.isGenerating, onGeneratingChange]);
+
+ const detRof = active.runOnFrame;
+
+ const updateDetections = useCallback(
+ (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => {
+ setDetections(p.results);
+ setImageSize({ width: p.imageWidth, height: p.imageHeight });
+ const now = Date.now();
+ const diff = now - lastFrameTimeRef.current;
+ if (diff > 0) onFpsChange(Math.round(1000 / diff), diff);
+ lastFrameTimeRef.current = now;
+ },
+ [onFpsChange]
+ );
+
+ const frameOutput = useFrameOutput({
+ pixelFormat: 'rgb',
+ dropFramesWhileBusy: true,
+ onFrame: useCallback(
+ (frame: Frame) => {
+ 'worklet';
+ if (frameKillSwitch.getDirty()) {
+ frame.dispose();
+ return;
+ }
+ try {
+ if (!detRof) return;
+ const iw = frame.width > frame.height ? frame.height : frame.width;
+ const ih = frame.width > frame.height ? frame.width : frame.height;
+ const result = detRof(frame, 0.5);
+ if (result) {
+ scheduleOnRN(updateDetections, {
+ results: result,
+ imageWidth: iw,
+ imageHeight: ih,
+ });
+ }
+ } catch {
+ // ignore
+ } finally {
+ frame.dispose();
+ }
+ },
+ [detRof, frameKillSwitch, updateDetections]
+ ),
+ });
+
+ useEffect(() => {
+ onFrameOutputChange(frameOutput);
+ }, [frameOutput, onFrameOutputChange]);
+
+ const scale = Math.max(
+ canvasSize.width / imageSize.width,
+ canvasSize.height / imageSize.height
+ );
+ const offsetX = (canvasSize.width - imageSize.width * scale) / 2;
+ const offsetY = (canvasSize.height - imageSize.height * scale) / 2;
+
+ return (
+
+ {detections.map((det, i) => {
+ const left = det.bbox.x1 * scale + offsetX;
+ const top = det.bbox.y1 * scale + offsetY;
+ const w = (det.bbox.x2 - det.bbox.x1) * scale;
+ const h = (det.bbox.y2 - det.bbox.y1) * scale;
+ return (
+
+
+
+ {det.label} {(det.score * 100).toFixed(1)}
+
+
+
+ );
+ })}
+
+ );
+}
+
+const styles = StyleSheet.create({
+ bbox: {
+ position: 'absolute',
+ borderWidth: 2,
+ borderColor: 'cyan',
+ borderRadius: 4,
+ },
+ bboxLabel: {
+ position: 'absolute',
+ top: -22,
+ left: -2,
+ paddingHorizontal: 6,
+ paddingVertical: 2,
+ borderRadius: 4,
+ },
+ bboxLabelText: { color: 'white', fontSize: 11, fontWeight: '600' },
+});
diff --git a/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx
new file mode 100644
index 0000000000..8226b0aae9
--- /dev/null
+++ b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx
@@ -0,0 +1,212 @@
+import React, { useCallback, useEffect, useRef, useState } from 'react';
+import { StyleSheet, View } from 'react-native';
+import { Frame, useFrameOutput } from 'react-native-vision-camera';
+import { scheduleOnRN } from 'react-native-worklets';
+import {
+ DEEPLAB_V3_RESNET50_QUANTIZED,
+ DEEPLAB_V3_RESNET101_QUANTIZED,
+ DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED,
+ FCN_RESNET50_QUANTIZED,
+ FCN_RESNET101_QUANTIZED,
+ LRASPP_MOBILENET_V3_LARGE_QUANTIZED,
+ SELFIE_SEGMENTATION,
+ useSemanticSegmentation,
+} from 'react-native-executorch';
+import {
+ AlphaType,
+ Canvas,
+ ColorType,
+ Image as SkiaImage,
+ Skia,
+ SkImage,
+} from '@shopify/react-native-skia';
+import { CLASS_COLORS } from '../utils/colors';
+import { TaskProps } from './types';
+
+type SegModelId =
+ | 'segmentationDeeplabResnet50'
+ | 'segmentationDeeplabResnet101'
+ | 'segmentationDeeplabMobilenet'
+ | 'segmentationLraspp'
+ | 'segmentationFcnResnet50'
+ | 'segmentationFcnResnet101'
+ | 'segmentationSelfie';
+
+type Props = TaskProps & { activeModel: SegModelId };
+
+export default function SegmentationTask({
+ activeModel,
+ canvasSize,
+ cameraPosition,
+ frameKillSwitch,
+ onFrameOutputChange,
+ onReadyChange,
+ onProgressChange,
+ onGeneratingChange,
+ onFpsChange,
+}: Props) {
+ const segDeeplabResnet50 = useSemanticSegmentation({
+ model: DEEPLAB_V3_RESNET50_QUANTIZED,
+ preventLoad: activeModel !== 'segmentationDeeplabResnet50',
+ });
+ const segDeeplabResnet101 = useSemanticSegmentation({
+ model: DEEPLAB_V3_RESNET101_QUANTIZED,
+ preventLoad: activeModel !== 'segmentationDeeplabResnet101',
+ });
+ const segDeeplabMobilenet = useSemanticSegmentation({
+ model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED,
+ preventLoad: activeModel !== 'segmentationDeeplabMobilenet',
+ });
+ const segLraspp = useSemanticSegmentation({
+ model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED,
+ preventLoad: activeModel !== 'segmentationLraspp',
+ });
+ const segFcnResnet50 = useSemanticSegmentation({
+ model: FCN_RESNET50_QUANTIZED,
+ preventLoad: activeModel !== 'segmentationFcnResnet50',
+ });
+ const segFcnResnet101 = useSemanticSegmentation({
+ model: FCN_RESNET101_QUANTIZED,
+ preventLoad: activeModel !== 'segmentationFcnResnet101',
+ });
+ const segSelfie = useSemanticSegmentation({
+ model: SELFIE_SEGMENTATION,
+ preventLoad: activeModel !== 'segmentationSelfie',
+ });
+
+ const active = {
+ segmentationDeeplabResnet50: segDeeplabResnet50,
+ segmentationDeeplabResnet101: segDeeplabResnet101,
+ segmentationDeeplabMobilenet: segDeeplabMobilenet,
+ segmentationLraspp: segLraspp,
+ segmentationFcnResnet50: segFcnResnet50,
+ segmentationFcnResnet101: segFcnResnet101,
+ segmentationSelfie: segSelfie,
+ }[activeModel];
+
+ const [maskImage, setMaskImage] = useState(null);
+ const lastFrameTimeRef = useRef(Date.now());
+
+ useEffect(() => {
+ onReadyChange(active.isReady);
+ }, [active.isReady, onReadyChange]);
+
+ useEffect(() => {
+ onProgressChange(active.downloadProgress);
+ }, [active.downloadProgress, onProgressChange]);
+
+ useEffect(() => {
+ onGeneratingChange(active.isGenerating);
+ }, [active.isGenerating, onGeneratingChange]);
+
+ // Clear stale mask when the segmentation model variant changes
+ useEffect(() => {
+ setMaskImage((prev) => {
+ prev?.dispose();
+ return null;
+ });
+ }, [activeModel]);
+
+ // Dispose native Skia image on unmount to prevent memory leaks
+ useEffect(() => {
+ return () => {
+ setMaskImage((prev) => {
+ prev?.dispose();
+ return null;
+ });
+ };
+ }, []);
+
+ const segRof = active.runOnFrame;
+
+ const updateMask = useCallback(
+ (img: SkImage) => {
+ setMaskImage((prev) => {
+ prev?.dispose();
+ return img;
+ });
+ const now = Date.now();
+ const diff = now - lastFrameTimeRef.current;
+ if (diff > 0) onFpsChange(Math.round(1000 / diff), diff);
+ lastFrameTimeRef.current = now;
+ },
+ [onFpsChange]
+ );
+
+ // CLASS_COLORS captured directly in closure — worklets cannot import modules
+ const colors = CLASS_COLORS;
+
+ const frameOutput = useFrameOutput({
+ pixelFormat: 'rgb',
+ dropFramesWhileBusy: true,
+ onFrame: useCallback(
+ (frame: Frame) => {
+ 'worklet';
+ if (frameKillSwitch.getDirty()) {
+ frame.dispose();
+ return;
+ }
+ try {
+ if (!segRof) return;
+ const result = segRof(frame, [], false);
+ if (result?.ARGMAX) {
+ const argmax: Int32Array = result.ARGMAX;
+ const side = Math.round(Math.sqrt(argmax.length));
+ const pixels = new Uint8Array(side * side * 4);
+ for (let i = 0; i < argmax.length; i++) {
+ const color = colors[argmax[i]!] ?? [0, 0, 0, 0];
+ pixels[i * 4] = color[0]!;
+ pixels[i * 4 + 1] = color[1]!;
+ pixels[i * 4 + 2] = color[2]!;
+ pixels[i * 4 + 3] = color[3]!;
+ }
+ const skData = Skia.Data.fromBytes(pixels);
+ const img = Skia.Image.MakeImage(
+ {
+ width: side,
+ height: side,
+ alphaType: AlphaType.Unpremul,
+ colorType: ColorType.RGBA_8888,
+ },
+ skData,
+ side * 4
+ );
+ if (img) scheduleOnRN(updateMask, img);
+ }
+ } catch {
+ // ignore
+ } finally {
+ frame.dispose();
+ }
+ },
+ [colors, frameKillSwitch, segRof, updateMask]
+ ),
+ });
+
+ useEffect(() => {
+ onFrameOutputChange(frameOutput);
+ }, [frameOutput, onFrameOutputChange]);
+
+ if (!maskImage) return null;
+
+ return (
+
+
+
+ );
+}
diff --git a/apps/computer-vision/components/vision_camera/tasks/types.ts b/apps/computer-vision/components/vision_camera/tasks/types.ts
new file mode 100644
index 0000000000..9727227f2f
--- /dev/null
+++ b/apps/computer-vision/components/vision_camera/tasks/types.ts
@@ -0,0 +1,14 @@
+import { useFrameOutput } from 'react-native-vision-camera';
+import { createSynchronizable } from 'react-native-worklets';
+
+export type TaskProps = {
+ activeModel: string;
+ canvasSize: { width: number; height: number };
+ cameraPosition: 'front' | 'back';
+ frameKillSwitch: ReturnType>;
+ onFrameOutputChange: (frameOutput: ReturnType) => void;
+ onReadyChange: (isReady: boolean) => void;
+ onProgressChange: (progress: number) => void;
+ onGeneratingChange: (isGenerating: boolean) => void;
+ onFpsChange: (fps: number, frameMs: number) => void;
+};
diff --git a/apps/computer-vision/components/vision_camera/utils/colors.ts b/apps/computer-vision/components/vision_camera/utils/colors.ts
new file mode 100644
index 0000000000..c38493a3b0
--- /dev/null
+++ b/apps/computer-vision/components/vision_camera/utils/colors.ts
@@ -0,0 +1,41 @@
+export const CLASS_COLORS: number[][] = [
+ [0, 0, 0, 0],
+ [51, 255, 87, 180],
+ [51, 87, 255, 180],
+ [255, 51, 246, 180],
+ [51, 255, 246, 180],
+ [243, 255, 51, 180],
+ [141, 51, 255, 180],
+ [255, 131, 51, 180],
+ [51, 255, 131, 180],
+ [131, 51, 255, 180],
+ [255, 255, 51, 180],
+ [51, 255, 255, 180],
+ [255, 51, 143, 180],
+ [127, 51, 255, 180],
+ [51, 255, 175, 180],
+ [255, 175, 51, 180],
+ [179, 255, 51, 180],
+ [255, 87, 51, 180],
+ [255, 51, 162, 180],
+ [51, 162, 255, 180],
+ [162, 51, 255, 180],
+];
+
+export function hashLabel(label: string): number {
+ let hash = 5381;
+ for (let i = 0; i < label.length; i++) {
+ hash = (hash + hash * 32 + label.charCodeAt(i)) % 1000003;
+ }
+ return 1 + (Math.abs(hash) % (CLASS_COLORS.length - 1));
+}
+
+export function labelColor(label: string): string {
+ const color = CLASS_COLORS[hashLabel(label)]!;
+ return `rgba(${color[0]},${color[1]},${color[2]},1)`;
+}
+
+export function labelColorBg(label: string): string {
+ const color = CLASS_COLORS[hashLabel(label)]!;
+ return `rgba(${color[0]},${color[1]},${color[2]},0.75)`;
+}
diff --git a/apps/computer-vision/package.json b/apps/computer-vision/package.json
index d7128125dd..578acf19b3 100644
--- a/apps/computer-vision/package.json
+++ b/apps/computer-vision/package.json
@@ -31,14 +31,14 @@
"react-native-gesture-handler": "~2.28.0",
"react-native-image-picker": "^7.2.2",
"react-native-loading-spinner-overlay": "^3.0.1",
- "react-native-nitro-image": "^0.12.0",
- "react-native-nitro-modules": "^0.33.9",
+ "react-native-nitro-image": "0.13.0",
+ "react-native-nitro-modules": "0.35.0",
"react-native-reanimated": "~4.2.2",
"react-native-safe-area-context": "~5.6.0",
"react-native-screens": "~4.16.0",
"react-native-svg": "15.15.3",
"react-native-svg-transformer": "^1.5.3",
- "react-native-vision-camera": "5.0.0-beta.2",
+ "react-native-vision-camera": "5.0.0-beta.6",
"react-native-worklets": "0.7.4"
},
"devDependencies": {
diff --git a/docs/docs/03-hooks/02-computer-vision/useClassification.md b/docs/docs/03-hooks/02-computer-vision/useClassification.md
index e9c2eebfab..e88cce1aff 100644
--- a/docs/docs/03-hooks/02-computer-vision/useClassification.md
+++ b/docs/docs/03-hooks/02-computer-vision/useClassification.md
@@ -52,12 +52,16 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an object containing categories with their probabilities.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object containing categories with their probabilities.
:::info
Images from external sources are stored in your application's temporary directory.
:::
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
+
## Example
```typescript
diff --git a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md
index caef87cdf2..a6ea5fa982 100644
--- a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md
+++ b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md
@@ -63,7 +63,11 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ImageEmbeddingsType.md#forward) method. It accepts one argument which is a URI/URL to an image you want to encode or base64 (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageEmbeddingsType.md#forward) method. It accepts one argument — the image to embed. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to a `Float32Array` representing the embedding.
+
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
## Example
diff --git a/docs/docs/03-hooks/02-computer-vision/useOCR.md b/docs/docs/03-hooks/02-computer-vision/useOCR.md
index 76e7ad6956..41491c7143 100644
--- a/docs/docs/03-hooks/02-computer-vision/useOCR.md
+++ b/docs/docs/03-hooks/02-computer-vision/useOCR.md
@@ -50,7 +50,11 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the text recognized within the box, and the confidence score. For more information, please refer to the reference or type definitions.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score.
+
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
## Detection object
diff --git a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md
index 69ac3c79ac..5fb2b2bb3a 100644
--- a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md
+++ b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md
@@ -66,7 +66,7 @@ You need more details? Check the following resources:
To run the model, use the [`forward`](../../06-api-reference/interfaces/ObjectDetectionType.md#forward) method. It accepts two arguments:
-- `imageSource` (required) - The image to process. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64).
+- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer).
- `detectionThreshold` (optional) - A number between 0 and 1 representing the minimum confidence score for a detection to be included in the results. Defaults to `0.7`.
`forward` returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing:
@@ -107,6 +107,10 @@ function App() {
}
```
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
+
## Supported models
| Model | Number of classes | Class list |
diff --git a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md
index 3a6fa46553..dc654369c7 100644
--- a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md
+++ b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md
@@ -66,7 +66,7 @@ You need more details? Check the following resources:
To run the model, use the [`forward`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) method. It accepts three arguments:
-- [`imageSource`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64).
+- [`input`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer).
- [`classesOfInterest`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter.
- [`resizeToInput`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`).
@@ -115,6 +115,10 @@ function App() {
}
```
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
+
## Supported models
| Model | Number of classes | Class list | Quantized |
diff --git a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md
index 471bde35e9..d08d7e8688 100644
--- a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md
+++ b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md
@@ -23,10 +23,13 @@ import {
const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY });
-const imageUri = 'file::///Users/.../cute_cat.png';
+const imageUri = 'file:///Users/.../cute_cat.png';
try {
- const generatedImageUrl = await model.forward(imageUri);
+ // Returns a file URI string
+ const uri = await model.forward(imageUri, 'url');
+ // Or returns raw PixelData (default)
+ const pixels = await model.forward(imageUri);
} catch (error) {
console.error(error);
}
@@ -51,30 +54,54 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise which can resolve either to an error or a URL to generated image.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts two arguments:
+
+- `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer).
+- `outputType` (optional) — Controls the return format:
+ - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written.
+ - `'url'` — Saves the result to a temp file and returns its URI as a `string`.
:::info
-Images from external sources and the generated image are stored in your application's temporary directory.
+When `outputType` is `'url'`, the generated image is stored in your application's temporary directory.
:::
## Example
```typescript
+import {
+ useStyleTransfer,
+ STYLE_TRANSFER_CANDY,
+} from 'react-native-executorch';
+
function App() {
const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY });
- // ...
- const imageUri = 'file::///Users/.../cute_cat.png';
-
- try {
- const generatedImageUrl = await model.forward(imageUri);
- } catch (error) {
- console.error(error);
- }
- // ...
+  // Returns a file URI — easy to pass to <Image />
+ const runWithUrl = async (imageUri: string) => {
+ try {
+ const uri = await model.forward(imageUri, 'url');
+ console.log('Styled image saved at:', uri);
+ } catch (error) {
+ console.error(error);
+ }
+ };
+
+ // Returns raw PixelData — useful for further processing or frame pipelines
+ const runWithPixelData = async (imageUri: string) => {
+ try {
+ const pixels = await model.forward(imageUri);
+ // pixels.dataPtr is a Uint8Array of RGB bytes
+ } catch (error) {
+ console.error(error);
+ }
+ };
}
```
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
+
## Supported models
- [Candy](https://github.com/pytorch/examples/tree/main/fast_neural_style)
diff --git a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md
index b9d29fc423..80b142ac62 100644
--- a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md
+++ b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md
@@ -58,7 +58,11 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the text recognized within the box, and the confidence score. For more information, please refer to the reference or type definitions.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score.
+
+## VisionCamera integration
+
+See the full guide: [VisionCamera Integration](./visioncamera-integration.md).
## Detection object
diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md
new file mode 100644
index 0000000000..79f6c5aad8
--- /dev/null
+++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md
@@ -0,0 +1,207 @@
+---
+title: VisionCamera Integration
+---
+
+React Native ExecuTorch vision models support real-time frame processing via [VisionCamera v5](https://react-native-vision-camera-v5-docs.vercel.app) using the `runOnFrame` worklet. This page explains how to set it up and what to watch out for.
+
+## Prerequisites
+
+Make sure you have the following packages installed:
+
+- [`react-native-vision-camera`](https://react-native-vision-camera-v5-docs.vercel.app) v5
+- [`react-native-worklets`](https://docs.swmansion.com/react-native-worklets/)
+
+## Which models support runOnFrame?
+
+The following hooks expose `runOnFrame`:
+
+- [`useClassification`](./useClassification.md)
+- [`useImageEmbeddings`](./useImageEmbeddings.md)
+- [`useOCR`](./useOCR.md)
+- [`useVerticalOCR`](./useVerticalOCR.md)
+- [`useObjectDetection`](./useObjectDetection.md)
+- [`useSemanticSegmentation`](./useSemanticSegmentation.md)
+- [`useStyleTransfer`](./useStyleTransfer.md)
+
+## runOnFrame vs forward
+
+| | `runOnFrame` | `forward` |
+| -------- | -------------------- | -------------------------- |
+| Thread | JS (worklet) | Background thread |
+| Input | VisionCamera `Frame` | `string` URI / `PixelData` |
+| Output | Model result (sync) | `Promise` |
+| Use case | Real-time camera | Single image |
+
+Use `runOnFrame` when you need to process every camera frame. Use `forward` for one-off image inference.
+
+## How it works
+
+VisionCamera v5 delivers frames via [`useFrameOutput`](https://react-native-vision-camera-v5-docs.vercel.app/docs/frame-output). Inside the `onFrame` worklet you call `runOnFrame(frame)` synchronously, then use `scheduleOnRN` from `react-native-worklets` to post the result back to React state on the main thread.
+
+:::warning
+You **must** set `pixelFormat: 'rgb'` in `useFrameOutput`. Our extraction pipeline expects RGB pixel data — any other format (e.g. the default `yuv`) will produce incorrect results.
+:::
+
+:::warning
+`runOnFrame` is synchronous and runs on the JS worklet thread. For models with longer inference times, use `dropFramesWhileBusy: true` to skip frames and avoid blocking the camera pipeline. For more control, see VisionCamera's [async frame processing guide](https://react-native-vision-camera-v5-docs.vercel.app/docs/async-frame-processing).
+:::
+
+:::note
+Always call `frame.dispose()` after processing to release the frame buffer. Wrap your inference in a `try/finally` to ensure it's always called even if `runOnFrame` throws.
+:::
+
+## Full example (Classification)
+
+```tsx
+import { useState, useCallback } from 'react';
+import { Text, StyleSheet } from 'react-native';
+import {
+ Camera,
+ Frame,
+ useCameraDevices,
+ useCameraPermission,
+ useFrameOutput,
+} from 'react-native-vision-camera';
+import { scheduleOnRN } from 'react-native-worklets';
+import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch';
+
+export default function App() {
+ const { hasPermission, requestPermission } = useCameraPermission();
+ const devices = useCameraDevices();
+ const device = devices.find((d) => d.position === 'back');
+ const model = useClassification({ model: EFFICIENTNET_V2_S });
+ const [topLabel, setTopLabel] = useState('');
+
+ // Extract runOnFrame so it can be captured by the useCallback dependency array
+ const runOnFrame = model.runOnFrame;
+
+ const frameOutput = useFrameOutput({
+ pixelFormat: 'rgb',
+ dropFramesWhileBusy: true,
+ onFrame: useCallback(
+ (frame: Frame) => {
+ 'worklet';
+ if (!runOnFrame) return;
+ try {
+ const scores = runOnFrame(frame);
+ if (scores) {
+ let best = '';
+ let bestScore = -1;
+ for (const [label, score] of Object.entries(scores)) {
+ if ((score as number) > bestScore) {
+ bestScore = score as number;
+ best = label;
+ }
+ }
+ scheduleOnRN(setTopLabel, best);
+ }
+ } finally {
+ frame.dispose();
+ }
+ },
+ [runOnFrame]
+ ),
+ });
+
+ if (!hasPermission) {
+ requestPermission();
+ return null;
+ }
+
+ if (!device) return null;
+
+ return (
+ <>
+      <Camera style={styles.camera} device={device} isActive frameOutput={frameOutput} />
+      <Text style={styles.label}>{topLabel}</Text>
+    </>
+ );
+}
+
+const styles = StyleSheet.create({
+ camera: { flex: 1 },
+ label: {
+ position: 'absolute',
+ bottom: 40,
+ alignSelf: 'center',
+ color: 'white',
+ fontSize: 20,
+ },
+});
+```
+
+## Using the Module API
+
+If you use the TypeScript Module API (e.g. `ClassificationModule`) directly instead of a hook, `runOnFrame` is a worklet function and **cannot** be passed directly to `useState` — React would invoke it as a state initializer. Use the functional updater form `() => module.runOnFrame`:
+
+```tsx
+import { useState, useEffect, useCallback } from 'react';
+import { Camera, useFrameOutput } from 'react-native-vision-camera';
+import { scheduleOnRN } from 'react-native-worklets';
+import {
+ ClassificationModule,
+ EFFICIENTNET_V2_S,
+} from 'react-native-executorch';
+
+export default function App() {
+ const [module] = useState(() => new ClassificationModule());
+  const [runOnFrame, setRunOnFrame] = useState<
+    typeof module.runOnFrame | null
+  >(null);
+
+ useEffect(() => {
+ module.load(EFFICIENTNET_V2_S).then(() => {
+ // () => module.runOnFrame is required — passing module.runOnFrame directly
+ // would cause React to call it as a state initializer function
+ setRunOnFrame(() => module.runOnFrame);
+ });
+ }, [module]);
+
+ const frameOutput = useFrameOutput({
+ pixelFormat: 'rgb',
+ dropFramesWhileBusy: true,
+ onFrame: useCallback(
+ (frame) => {
+ 'worklet';
+ if (!runOnFrame) return;
+ try {
+ const result = runOnFrame(frame);
+ if (result) scheduleOnRN(setResult, result);
+ } finally {
+ frame.dispose();
+ }
+ },
+ [runOnFrame]
+ ),
+ });
+
+  return <Camera /* device, style, isActive, ... */ frameOutput={frameOutput} />;
+}
+```
+
+## Common issues
+
+### Results look wrong or scrambled
+
+You forgot to set `pixelFormat: 'rgb'`. The default VisionCamera pixel format is `yuv` — our frame extraction works only with RGB data.
+
+### App freezes or camera drops frames
+
+Your model's inference time exceeds the frame interval. Enable `dropFramesWhileBusy: true` in `useFrameOutput`, or move inference off the worklet thread using VisionCamera's [async frame processing](https://react-native-vision-camera-v5-docs.vercel.app/docs/async-frame-processing).
+
+### Memory leak / crash after many frames
+
+You are not calling `frame.dispose()`. Always dispose the frame in a `finally` block.
+
+### `runOnFrame` is always null
+
+The model hasn't finished loading yet. Guard with `if (!runOnFrame) return` inside `onFrame`, or check `model.isReady` before enabling the camera.
+
+### TypeError: `module.runOnFrame` is not a function (Module API)
+
+You passed `module.runOnFrame` directly to `setState` instead of `() => module.runOnFrame`. React invoked it as a state initializer — see the [Module API section](#using-the-module-api) above.
diff --git a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md
index 4234fa865d..a9bf6ed77d 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md
@@ -47,7 +47,9 @@ For more information on loading resources, take a look at [loading models](../..
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an object containing categories with their probabilities.
+To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object containing categories with their probabilities.
+
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
## Managing memory
diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md
index 7388416334..47eceef00a 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md
@@ -48,4 +48,6 @@ For more information on loading resources, take a look at [loading models](../..
## Running the model
-[`forward`](../../06-api-reference/classes/ImageEmbeddingsModule.md#forward) accepts one argument: image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding.
+[`forward`](../../06-api-reference/classes/ImageEmbeddingsModule.md#forward) accepts one argument — the image to embed. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to a `Float32Array` representing the embedding.
+
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
diff --git a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md
index 1524859ed2..1391982173 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md
@@ -41,4 +41,6 @@ For more information on loading resources, take a look at [loading models](../..
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/OCRModule.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score.
+To run the model, use the [`forward`](../../06-api-reference/classes/OCRModule.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score.
+
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
diff --git a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md
index d942eded65..b56cb47713 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md
@@ -40,7 +40,14 @@ For more information on loading resources, take a look at [loading models](../..
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score.
+To run the model, use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method. It accepts two arguments:
+
+- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer).
+- `detectionThreshold` (optional) - A number between 0 and 1. Defaults to `0.7`.
+
+The method returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing the bounding box, label, and confidence score.
+
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
## Using a custom model
diff --git a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md
index bf88690bdb..4bf2129ac1 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md
@@ -86,7 +86,7 @@ For more information on loading resources, take a look at [loading models](../..
To run the model, use the [`forward`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) method. It accepts three arguments:
-- [`imageSource`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64).
+- [`input`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer).
- [`classesOfInterest`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless.
- [`resizeToInput`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions.
@@ -113,6 +113,8 @@ result.CAT; // Float32Array
result.DOG; // Float32Array
```
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
+
## Managing memory
The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/SemanticSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/SemanticSegmentationModule.md#delete) unless you create a new instance.
diff --git a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md
index 0e70a24796..4c57716001 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md
@@ -47,7 +47,14 @@ For more information on loading resources, take a look at [loading models](../..
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or a URL to generated image.
+To run the model, use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method. It accepts two arguments:
+
+- `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer).
+- `outputType` (optional) — Controls the return format:
+ - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written.
+ - `'url'` — Saves the result to a temp file and returns its URI as a `string`.
+
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
## Managing memory
diff --git a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md
index eb47efa857..cc1cdba51d 100644
--- a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md
+++ b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md
@@ -43,4 +43,6 @@ For more information on loading resources, take a look at [loading models](../..
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/VerticalOCRModule.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score.
+To run the model, use the [`forward`](../../06-api-reference/classes/VerticalOCRModule.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score.
+
+For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp
index bd29500b00..d7a763819a 100644
--- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp
@@ -225,7 +225,7 @@ readImageToTensor(const std::string &path,
if (tensorDims.size() < 2) {
char errorMessage[100];
std::snprintf(errorMessage, sizeof(errorMessage),
- "Unexpected tensor size, expected at least 2 dimentions "
+ "Unexpected tensor size, expected at least 2 dimensions "
"but got: %zu.",
tensorDims.size());
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h
index 4057950b23..1eb288ee30 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h
@@ -1,5 +1,8 @@
#pragma once
+#include
+#include
+
namespace rnexecutorch {
using executorch::aten::ScalarType;
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
index 96e3168ee7..a4e373c2b8 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
@@ -6,6 +6,7 @@
#include
#include
#include
+#include
#include
#include
@@ -17,8 +18,10 @@
#include
#include
#include
+#include
#include
#include
+#include
#include
using namespace rnexecutorch::models::speech_to_text;
@@ -557,4 +560,63 @@ inline jsi::Value getJsiValue(const TranscriptionResult &result,
return obj;
}
+inline jsi::Value
+getJsiValue(const models::style_transfer::PixelDataResult &result,
+ jsi::Runtime &runtime) {
+ jsi::Object obj(runtime);
+
+ auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr);
+ auto uint8ArrayCtor =
+ runtime.global().getPropertyAsFunction(runtime, "Uint8Array");
+ auto uint8Array =
+ uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer).getObject(runtime);
+ obj.setProperty(runtime, "dataPtr", uint8Array);
+
+ auto sizesArray = jsi::Array(runtime, 3);
+ sizesArray.setValueAtIndex(runtime, 0, jsi::Value(result.height));
+ sizesArray.setValueAtIndex(runtime, 1, jsi::Value(result.width));
+ sizesArray.setValueAtIndex(runtime, 2, jsi::Value(result.channels));
+ obj.setProperty(runtime, "sizes", sizesArray);
+
+ obj.setProperty(runtime, "scalarType",
+ jsi::Value(static_cast(ScalarType::Byte)));
+
+ return obj;
+}
+
+inline jsi::Value getJsiValue(
+ const rnexecutorch::models::semantic_segmentation::SegmentationResult
+ &result,
+ jsi::Runtime &runtime) {
+ jsi::Object dict(runtime);
+
+ auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, result.argmax);
+ auto int32ArrayCtor =
+ runtime.global().getPropertyAsFunction(runtime, "Int32Array");
+ auto int32Array = int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer)
+ .getObject(runtime);
+ dict.setProperty(runtime, "ARGMAX", int32Array);
+
+ for (auto &[classLabel, owningBuffer] : *result.classBuffers) {
+ auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer);
+ auto float32ArrayCtor =
+ runtime.global().getPropertyAsFunction(runtime, "Float32Array");
+ auto float32Array =
+ float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer)
+ .getObject(runtime);
+ dict.setProperty(runtime, jsi::String::createFromAscii(runtime, classLabel),
+ float32Array);
+ }
+
+ return dict;
+}
+
+inline jsi::Value
+getJsiValue(const models::style_transfer::StyleTransferResult &result,
+ jsi::Runtime &runtime) {
+ return std::visit(
+ [&runtime](const auto &value) { return getJsiValue(value, runtime); },
+ result);
+}
+
} // namespace rnexecutorch::jsi_conversion
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
index c13b8991dc..e4361273d5 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -190,13 +190,11 @@ template class ModelHostObject : public JsiHostObject {
addFunctions(JSI_EXPORT_FUNCTION(
ModelHostObject, synchronousHostFunction<&Model::streamStop>,
"streamStop"));
- addFunctions(JSI_EXPORT_FUNCTION(
- ModelHostObject, synchronousHostFunction<&Model::streamInsert>,
- "streamInsert"));
addFunctions(
JSI_EXPORT_FUNCTION(ModelHostObject,
promiseHostFunction<&Model::generateFromPhonemes>,
"generateFromPhonemes"));
+
addFunctions(
JSI_EXPORT_FUNCTION(ModelHostObject,
promiseHostFunction<&Model::streamFromPhonemes>,
@@ -328,8 +326,24 @@ template class ModelHostObject : public JsiHostObject {
return jsi_conversion::getJsiValue(std::move(result), runtime);
}
+ } catch (const RnExecutorchError &e) {
+ jsi::Object errorData(runtime);
+ errorData.setProperty(runtime, "code", e.getNumericCode());
+ errorData.setProperty(runtime, "message",
+ jsi::String::createFromUtf8(runtime, e.what()));
+ throw jsi::JSError(runtime, jsi::Value(runtime, std::move(errorData)));
+ } catch (const std::runtime_error &e) {
+ // This catch should be merged with the next one
+ // (std::runtime_error inherits from std::exception) HOWEVER react
+ // native has broken RTTI which breaks proper exception type
+ // checking. Remove when the following change is present in our
+ // version:
+ // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
+ throw jsi::JSError(runtime, e.what());
} catch (const std::exception &e) {
throw jsi::JSError(runtime, e.what());
+ } catch (...) {
+ throw jsi::JSError(runtime, "Unknown error in vision function");
}
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp
index b88310e124..0b6acbc383 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp
@@ -1,53 +1,53 @@
#include "VisionModel.h"
#include
#include
-#include
#include
namespace rnexecutorch::models {
using namespace facebook;
-cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime,
- const jsi::Value &frameData) const {
- auto frameObj = frameData.asObject(runtime);
- return ::rnexecutorch::utils::extractFrame(runtime, frameObj);
+VisionModel::VisionModel(const std::string &modelSource,
+ std::shared_ptr callInvoker)
+ : BaseModel(modelSource, callInvoker) {}
+
+void VisionModel::unload() noexcept {
+ std::scoped_lock lock(inference_mutex_);
+ BaseModel::unload();
}
-cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const {
- if (tensorView.sizes.size() != 3) {
- char errorMessage[100];
- std::snprintf(errorMessage, sizeof(errorMessage),
- "Invalid pixel data: sizes must have 3 elements "
- "[height, width, channels], got %zu",
- tensorView.sizes.size());
- throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
- errorMessage);
+cv::Size VisionModel::modelInputSize() const {
+ if (modelInputShape_.size() < 2) {
+ return {0, 0};
}
+ return cv::Size(modelInputShape_[modelInputShape_.size() - 1],
+ modelInputShape_[modelInputShape_.size() - 2]);
+}
- int32_t height = tensorView.sizes[0];
- int32_t width = tensorView.sizes[1];
- int32_t channels = tensorView.sizes[2];
-
- if (channels != 3) {
- char errorMessage[100];
- std::snprintf(errorMessage, sizeof(errorMessage),
- "Invalid pixel data: expected 3 channels (RGB), got %d",
- channels);
- throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
- errorMessage);
- }
+cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime,
+ const jsi::Value &frameData) const {
+ cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData);
+ cv::Mat rgb;
+#ifdef __APPLE__
+ cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB);
+#else
+ cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB);
+#endif
+ return rgb;
+}
- if (tensorView.scalarType != ScalarType::Byte) {
- throw RnExecutorchError(
- RnExecutorchErrorCode::InvalidUserInput,
- "Invalid pixel data: scalarType must be BYTE (Uint8Array)");
+cv::Mat VisionModel::preprocess(const cv::Mat &image) const {
+ const cv::Size targetSize = modelInputSize();
+ if (image.size() == targetSize) {
+ return image;
}
+ cv::Mat resized;
+ cv::resize(image, resized, targetSize);
+ return resized;
+}
- uint8_t *dataPtr = static_cast(tensorView.dataPtr);
- cv::Mat image(height, width, CV_8UC3, dataPtr);
-
- return image;
+cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const {
+ return ::rnexecutorch::utils::pixelsToMat(tensorView);
}
} // namespace rnexecutorch::models
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h
index 4828f26578..6f9a9532f4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h
@@ -23,8 +23,8 @@ namespace models {
* Usage:
* Subclasses should:
* 1. Inherit from VisionModel instead of BaseModel
- * 2. Implement preprocessFrame() with model-specific preprocessing
- * 3. Delegate to runInference() which handles locking internally
+ * 2. Optionally override preprocess() for model-specific preprocessing
+ * 3. Implement runInference() which acquires the lock internally
*
* Example:
* @code
@@ -33,7 +33,7 @@ namespace models {
* std::unordered_map
* generateFromFrame(jsi::Runtime& runtime, const jsi::Value& frameValue) {
* auto frameObject = frameValue.asObject(runtime);
- * cv::Mat frame = utils::extractFrame(runtime, frameObject);
+ * cv::Mat frame = extractFromFrame(runtime, frameObject);
* return runInference(frame);
* }
* };
@@ -41,19 +41,27 @@ namespace models {
*/
class VisionModel : public BaseModel {
public:
- /**
- * @brief Construct a VisionModel with the same parameters as BaseModel
- *
- * VisionModel uses the same construction pattern as BaseModel, just adding
- * thread-safety on top.
- */
VisionModel(const std::string &modelSource,
- std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker) {}
+ std::shared_ptr callInvoker);
virtual ~VisionModel() = default;
+ /**
+ * @brief Thread-safe unload that waits for any in-flight inference to
+ * complete
+ *
+ * Overrides BaseModel::unload() to acquire inference_mutex_ before
+ * resetting the module. This prevents a crash where BaseModel::unload()
+ * destroys module_ while generateFromFrame() is still executing on the
+ * VisionCamera worklet thread.
+ */
+ void unload() noexcept;
+
protected:
+ /// Cached input tensor shape (getAllInputShapes()[0]).
+ /// Set once by each subclass constructor to avoid per-frame metadata lookups.
+ std::vector modelInputShape_;
+
/**
* @brief Mutex to ensure thread-safe inference
*
@@ -70,44 +78,32 @@ class VisionModel : public BaseModel {
mutable std::mutex inference_mutex_;
/**
- * @brief Preprocess a camera frame for model input
+ * @brief Resize an RGB image to the model's expected input size
*
- * This method should implement model-specific preprocessing such as:
- * - Resizing to the model's expected input size
- * - Color space conversion (e.g., BGR to RGB)
- * - Normalization
- * - Any other model-specific transformations
+ * Resizes to modelInputSize() if needed. Subclasses may override for
+ * model-specific preprocessing (e.g., normalisation).
*
- * @param frame Input frame from camera (already extracted and rotated by
- * FrameExtractor)
- * @return Preprocessed cv::Mat ready for tensor conversion
+ * @param image Input image in RGB format
+ * @return cv::Mat resized to modelInputSize(), in RGB format
*
- * @note The input frame is already in RGB format and rotated 90° clockwise
- * @note This method is called under mutex protection in generateFromFrame()
+ * @note Called from runInference() under the inference mutex
*/
- virtual cv::Mat preprocessFrame(const cv::Mat &frame) const = 0;
+ virtual cv::Mat preprocess(const cv::Mat &image) const;
+
+ /// Convenience accessor: spatial dimensions of the model input.
+ cv::Size modelInputSize() const;
/**
- * @brief Extract and preprocess frame from VisionCamera in one call
+ * @brief Extract an RGB cv::Mat from a VisionCamera frame
*
- * This is a convenience method that combines frame extraction and
- * preprocessing. It handles both nativeBuffer (zero-copy) and ArrayBuffer
- * paths automatically.
+ * Calls frameToMat() then converts the raw 4-channel frame
+ * (BGRA on iOS, RGBA on Android) to RGB.
*
* @param runtime JSI runtime
* @param frameData JSI value containing frame data from VisionCamera
+ * @return cv::Mat in RGB format (3 channels)
*
- * @return Preprocessed cv::Mat ready for tensor conversion
- *
- * @throws std::runtime_error if frame extraction fails
- *
- * @note This method does NOT acquire the inference mutex - caller is
- * responsible
- * @note Typical usage:
- * @code
- * cv::Mat preprocessed = extractFromFrame(runtime, frameData);
- * auto tensor = image_processing::getTensorFromMatrix(dims, preprocessed);
- * @endcode
+ * @note Does NOT acquire the inference mutex — caller is responsible
*/
cv::Mat extractFromFrame(jsi::Runtime &runtime,
const jsi::Value &frameData) const;
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp
index 0fba071087..f713b59605 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp
@@ -12,31 +12,33 @@ namespace rnexecutorch::models::classification {
Classification::Classification(const std::string &modelSource,
std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker) {
+ : VisionModel(modelSource, callInvoker) {
auto inputShapes = getAllInputShapes();
if (inputShapes.size() == 0) {
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
"Model seems to not take any input tensors.");
}
- std::vector modelInputShape = inputShapes[0];
- if (modelInputShape.size() < 2) {
+ modelInputShape_ = inputShapes[0];
+ if (modelInputShape_.size() < 2) {
char errorMessage[100];
std::snprintf(errorMessage, sizeof(errorMessage),
- "Unexpected model input size, expected at least 2 dimentions "
+ "Unexpected model input size, expected at least 2 dimensions "
"but got: %zu.",
- modelInputShape.size());
+ modelInputShape_.size());
throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions,
errorMessage);
}
- modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
- modelInputShape[modelInputShape.size() - 2]);
}
std::unordered_map
-Classification::generate(std::string imageSource) {
+Classification::runInference(cv::Mat image) {
+ std::scoped_lock lock(inference_mutex_);
+
+ cv::Mat preprocessed = preprocess(image);
+
auto inputTensor =
- image_processing::readImageToTensor(imageSource, getAllInputShapes()[0])
- .first;
+ image_processing::getTensorFromMatrix(modelInputShape_, preprocessed);
+
auto forwardResult = BaseModel::forward(inputTensor);
if (!forwardResult.ok()) {
throw RnExecutorchError(forwardResult.error(),
@@ -46,6 +48,30 @@ Classification::generate(std::string imageSource) {
return postprocess(forwardResult->at(0).toTensor());
}
+std::unordered_map
+Classification::generateFromString(std::string imageSource) {
+ cv::Mat imageBGR = image_processing::readImage(imageSource);
+
+ cv::Mat imageRGB;
+ cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
+
+ return runInference(imageRGB);
+}
+
+std::unordered_map
+Classification::generateFromFrame(jsi::Runtime &runtime,
+ const jsi::Value &frameData) {
+ cv::Mat frame = extractFromFrame(runtime, frameData);
+ return runInference(frame);
+}
+
+std::unordered_map
+Classification::generateFromPixels(JSTensorViewIn pixelData) {
+ cv::Mat image = extractFromPixels(pixelData);
+
+ return runInference(image);
+}
+
std::unordered_map
Classification::postprocess(const Tensor &tensor) {
std::span resultData(
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h
index 1465fc5f9b..9f62864b9e 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h
@@ -3,28 +3,38 @@
#include
#include
+#include
#include
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include
+#include
namespace rnexecutorch {
namespace models::classification {
using executorch::aten::Tensor;
using executorch::extension::TensorPtr;
-class Classification : public BaseModel {
+class Classification : public VisionModel {
public:
Classification(const std::string &modelSource,
std::shared_ptr callInvoker);
+
[[nodiscard("Registered non-void function")]] std::unordered_map<
std::string_view, float>
- generate(std::string imageSource);
+ generateFromString(std::string imageSource);
+
+ [[nodiscard("Registered non-void function")]] std::unordered_map<
+ std::string_view, float>
+ generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData);
+
+ [[nodiscard("Registered non-void function")]] std::unordered_map<
+ std::string_view, float>
+ generateFromPixels(JSTensorViewIn pixelData);
private:
- std::unordered_map postprocess(const Tensor &tensor);
+ std::unordered_map runInference(cv::Mat image);
- cv::Size modelImageSize{0, 0};
+ std::unordered_map postprocess(const Tensor &tensor);
};
} // namespace models::classification
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp
index ec3129e760..d2914469af 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp
@@ -1,40 +1,40 @@
#include "ImageEmbeddings.h"
-
#include
#include
#include
#include
-#include
namespace rnexecutorch::models::embeddings {
ImageEmbeddings::ImageEmbeddings(
const std::string &modelSource,
std::shared_ptr callInvoker)
- : BaseEmbeddings(modelSource, callInvoker) {
+ : VisionModel(modelSource, callInvoker) {
auto inputTensors = getAllInputShapes();
if (inputTensors.size() == 0) {
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
"Model seems to not take any input tensors.");
}
- std::vector modelInputShape = inputTensors[0];
- if (modelInputShape.size() < 2) {
+ modelInputShape_ = inputTensors[0];
+ if (modelInputShape_.size() < 2) {
char errorMessage[100];
std::snprintf(errorMessage, sizeof(errorMessage),
- "Unexpected model input size, expected at least 2 dimentions "
+ "Unexpected model input size, expected at least 2 dimensions "
"but got: %zu.",
- modelInputShape.size());
+ modelInputShape_.size());
throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions,
errorMessage);
}
- modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
- modelInputShape[modelInputShape.size() - 2]);
}
std::shared_ptr
-ImageEmbeddings::generate(std::string imageSource) {
- auto [inputTensor, originalSize] =
- image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]);
+ImageEmbeddings::runInference(cv::Mat image) {
+ std::scoped_lock lock(inference_mutex_);
+
+ cv::Mat preprocessed = preprocess(image);
+
+ auto inputTensor =
+ image_processing::getTensorFromMatrix(modelInputShape_, preprocessed);
auto forwardResult = BaseModel::forward(inputTensor);
@@ -45,7 +45,33 @@ ImageEmbeddings::generate(std::string imageSource) {
"is correct.");
}
- return BaseEmbeddings::postprocess(forwardResult);
+ auto forwardResultTensor = forwardResult->at(0).toTensor();
+ return std::make_shared(
+ forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes());
+}
+
+std::shared_ptr
+ImageEmbeddings::generateFromString(std::string imageSource) {
+ cv::Mat imageBGR = image_processing::readImage(imageSource);
+
+ cv::Mat imageRGB;
+ cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
+
+ return runInference(imageRGB);
+}
+
+std::shared_ptr
+ImageEmbeddings::generateFromFrame(jsi::Runtime &runtime,
+ const jsi::Value &frameData) {
+ cv::Mat frame = extractFromFrame(runtime, frameData);
+ return runInference(frame);
+}
+
+std::shared_ptr
+ImageEmbeddings::generateFromPixels(JSTensorViewIn pixelData) {
+ cv::Mat image = extractFromPixels(pixelData);
+
+ return runInference(image);
}
} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h
index 7e114e939d..3a20301724 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h
@@ -2,26 +2,37 @@
#include
#include
+#include
#include
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include
+#include
+#include
namespace rnexecutorch {
namespace models::embeddings {
using executorch::extension::TensorPtr;
using executorch::runtime::EValue;
-class ImageEmbeddings final : public BaseEmbeddings {
+class ImageEmbeddings final : public VisionModel {
public:
ImageEmbeddings(const std::string &modelSource,
std::shared_ptr callInvoker);
+
+ [[nodiscard(
+ "Registered non-void function")]] std::shared_ptr
+ generateFromString(std::string imageSource);
+
+ [[nodiscard(
+ "Registered non-void function")]] std::shared_ptr
+ generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData);
+
[[nodiscard(
"Registered non-void function")]] std::shared_ptr
- generate(std::string imageSource);
+ generateFromPixels(JSTensorViewIn pixelData);
private:
- cv::Size modelImageSize{0, 0};
+ std::shared_ptr runInference(cv::Mat image);
};
} // namespace models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp
index 1dad1a61a5..d30fc8aeed 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp
@@ -5,7 +5,6 @@
#include
#include
#include
-#include
namespace rnexecutorch::models::object_detection {
@@ -20,18 +19,16 @@ ObjectDetection::ObjectDetection(
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
"Model seems to not take any input tensors.");
}
- std::vector modelInputShape = inputTensors[0];
- if (modelInputShape.size() < 2) {
+ modelInputShape_ = inputTensors[0];
+ if (modelInputShape_.size() < 2) {
char errorMessage[100];
std::snprintf(errorMessage, sizeof(errorMessage),
- "Unexpected model input size, expected at least 2 dimentions "
+ "Unexpected model input size, expected at least 2 dimensions "
"but got: %zu.",
- modelInputShape.size());
+ modelInputShape_.size());
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
errorMessage);
}
- modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
- modelInputShape[modelInputShape.size() - 2]);
if (normMean.size() == 3) {
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
} else if (!normMean.empty()) {
@@ -46,46 +43,13 @@ ObjectDetection::ObjectDetection(
}
}
-cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const {
- const std::vector tensorDims = getAllInputShapes()[0];
- cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1],
- tensorDims[tensorDims.size() - 2]);
-
- cv::Mat rgb;
-
- if (frame.channels() == 4) {
-#ifdef __APPLE__
- cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB);
-#else
- cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB);
-#endif
- } else if (frame.channels() == 3) {
- rgb = frame;
- } else {
- char errorMessage[100];
- std::snprintf(errorMessage, sizeof(errorMessage),
- "Unsupported frame format: %d channels", frame.channels());
- throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
- errorMessage);
- }
-
- // Only resize if dimensions don't match
- if (rgb.size() != tensorSize) {
- cv::Mat resized;
- cv::resize(rgb, resized, tensorSize);
- return resized;
- }
-
- return rgb;
-}
-
std::vector
ObjectDetection::postprocess(const std::vector &tensors,
cv::Size originalSize, double detectionThreshold) {
- float widthRatio =
- static_cast(originalSize.width) / modelImageSize.width;
+ const cv::Size inputSize = modelInputSize();
+ float widthRatio = static_cast(originalSize.width) / inputSize.width;
float heightRatio =
- static_cast(originalSize.height) / modelImageSize.height;
+ static_cast(originalSize.height) / inputSize.height;
std::vector detections;
auto bboxTensor = tensors.at(0).toTensor();
@@ -134,14 +98,14 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
std::scoped_lock lock(inference_mutex_);
cv::Size originalSize = image.size();
- cv::Mat preprocessed = preprocessFrame(image);
+ cv::Mat preprocessed = preprocess(image);
- const std::vector tensorDims = getAllInputShapes()[0];
auto inputTensor =
(normMean_ && normStd_)
- ? image_processing::getTensorFromMatrix(tensorDims, preprocessed,
- *normMean_, *normStd_)
- : image_processing::getTensorFromMatrix(tensorDims, preprocessed);
+ ? image_processing::getTensorFromMatrix(
+ modelInputShape_, preprocessed, *normMean_, *normStd_)
+ : image_processing::getTensorFromMatrix(modelInputShape_,
+ preprocessed);
auto forwardResult = BaseModel::forward(inputTensor);
if (!forwardResult.ok()) {
@@ -168,9 +132,7 @@ std::vector
ObjectDetection::generateFromFrame(jsi::Runtime &runtime,
const jsi::Value &frameData,
double detectionThreshold) {
- auto frameObj = frameData.asObject(runtime);
- cv::Mat frame = rnexecutorch::utils::extractFrame(runtime, frameObj);
-
+ cv::Mat frame = extractFromFrame(runtime, frameData);
return runInference(frame, detectionThreshold);
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h
index f1159d88e0..d94087688a 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h
@@ -77,7 +77,6 @@ class ObjectDetection : public VisionModel {
protected:
std::vector runInference(cv::Mat image,
double detectionThreshold);
- cv::Mat preprocessFrame(const cv::Mat &frame) const override;
private:
/**
@@ -100,9 +99,6 @@ class ObjectDetection : public VisionModel {
cv::Size originalSize,
double detectionThreshold);
- /// Expected input image dimensions derived from the model's input shape.
- cv::Size modelImageSize{0, 0};
-
/// Optional per-channel mean for input normalisation (set in constructor).
std::optional normMean_;
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp
index a521b4e8b0..3c64ba115f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp
@@ -4,6 +4,7 @@
#include
#include
#include
+#include
namespace rnexecutorch::models::ocr {
OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource,
@@ -12,12 +13,8 @@ OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource,
: detector(detectorSource, callInvoker),
recognitionHandler(recognizerSource, symbols, callInvoker) {}
-std::vector OCR::generate(std::string input) {
- cv::Mat image = image_processing::readImage(input);
- if (image.empty()) {
- throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
- "Failed to load image from path: " + input);
- }
+std::vector OCR::runInference(cv::Mat image) {
+ std::scoped_lock lock(inference_mutex_);
/*
1. Detection process returns the list of bounding boxes containing areas
@@ -43,12 +40,46 @@ std::vector OCR::generate(std::string input) {
return result;
}
+std::vector OCR::generateFromString(std::string input) {
+ cv::Mat image = image_processing::readImage(input);
+ if (image.empty()) {
+ throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
+ "Failed to load image from path: " + input);
+ }
+ return runInference(image);
+}
+
+std::vector
+OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) {
+ cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData);
+ cv::Mat bgr;
+#ifdef __APPLE__
+ cv::cvtColor(frame, bgr, cv::COLOR_BGRA2BGR);
+#elif defined(__ANDROID__)
+ cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR);
+#else
+ throw RnExecutorchError(
+ RnExecutorchErrorCode::PlatformNotSupported,
+ "generateFromFrame is not supported on this platform");
+#endif
+ return runInference(bgr);
+}
+
+std::vector
+OCR::generateFromPixels(JSTensorViewIn pixelData) {
+ cv::Mat image;
+ cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image,
+ cv::COLOR_RGB2BGR);
+ return runInference(image);
+}
+
std::size_t OCR::getMemoryLowerBound() const noexcept {
return detector.getMemoryLowerBound() +
recognitionHandler.getMemoryLowerBound();
}
void OCR::unload() noexcept {
+ std::scoped_lock lock(inference_mutex_);
detector.unload();
recognitionHandler.unload();
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h
index d84ba903f0..719cb957c4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h
@@ -1,9 +1,11 @@
#pragma once
+#include
#include
#include
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
+#include
#include
#include
#include
@@ -28,13 +30,20 @@ class OCR final {
const std::string &recognizerSource, const std::string &symbols,
std::shared_ptr callInvoker);
[[nodiscard("Registered non-void function")]] std::vector
- generate(std::string input);
+ generateFromString(std::string input);
+ [[nodiscard("Registered non-void function")]] std::vector
+ generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData);
+ [[nodiscard("Registered non-void function")]] std::vector
+ generateFromPixels(JSTensorViewIn pixelData);
std::size_t getMemoryLowerBound() const noexcept;
void unload() noexcept;
private:
+ std::vector runInference(cv::Mat image);
+
Detector detector;
RecognitionHandler recognitionHandler;
+ mutable std::mutex inference_mutex_;
};
} // namespace models::ocr
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp
index fc6a04ebcd..5ecf7493c4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp
@@ -1,8 +1,6 @@
#include "BaseSemanticSegmentation.h"
#include "jsi/jsi.h"
-#include
-
#include
#include
#include
@@ -15,7 +13,8 @@ BaseSemanticSegmentation::BaseSemanticSegmentation(
const std::string &modelSource, std::vector normMean,
std::vector normStd, std::vector allClasses,
std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker), allClasses_(std::move(allClasses)) {
+ : VisionModel(modelSource, callInvoker),
+ allClasses_(std::move(allClasses)) {
initModelImageSize();
if (normMean.size() == 3) {
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
@@ -37,46 +36,71 @@ void BaseSemanticSegmentation::initModelImageSize() {
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
"Model seems to not take any input tensors.");
}
- std::vector modelInputShape = inputShapes[0];
- if (modelInputShape.size() < 2) {
+ modelInputShape_ = inputShapes[0];
+ if (modelInputShape_.size() < 2) {
throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions,
"Unexpected model input size, expected at least 2 "
"dimensions but got: " +
- std::to_string(modelInputShape.size()) + ".");
+ std::to_string(modelInputShape_.size()) + ".");
}
- modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
- modelInputShape[modelInputShape.size() - 2]);
- numModelPixels = modelImageSize.area();
-}
-
-TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource,
- cv::Size &originalSize) {
- auto [inputTensor, origSize] = image_processing::readImageToTensor(
- imageSource, getAllInputShapes()[0], false, normMean_, normStd_);
- originalSize = origSize;
- return inputTensor;
+ numModelPixels = modelInputSize().area();
}
-std::shared_ptr BaseSemanticSegmentation::generate(
- std::string imageSource,
- std::set> classesOfInterest, bool resize) {
+semantic_segmentation::SegmentationResult
+BaseSemanticSegmentation::runInference(
+ cv::Mat image, cv::Size originalSize,
+ std::set> &classesOfInterest, bool resize) {
+ std::scoped_lock lock(inference_mutex_);
- cv::Size originalSize;
- auto inputTensor = preprocess(imageSource, originalSize);
+ cv::Mat preprocessed = VisionModel::preprocess(image);
+ auto inputTensor =
+ (normMean_ && normStd_)
+ ? image_processing::getTensorFromMatrix(
+ modelInputShape_, preprocessed, *normMean_, *normStd_)
+ : image_processing::getTensorFromMatrix(modelInputShape_,
+ preprocessed);
auto forwardResult = BaseModel::forward(inputTensor);
-
if (!forwardResult.ok()) {
throw RnExecutorchError(forwardResult.error(),
"The model's forward function did not succeed. "
"Ensure the model input is correct.");
}
- return postprocess(forwardResult->at(0).toTensor(), originalSize, allClasses_,
- classesOfInterest, resize);
+ return computeResult(forwardResult->at(0).toTensor(), originalSize,
+ allClasses_, classesOfInterest, resize);
+}
+
+semantic_segmentation::SegmentationResult
+BaseSemanticSegmentation::generateFromString(
+ std::string imageSource,
+ std::set> classesOfInterest, bool resize) {
+ cv::Mat imageBGR = image_processing::readImage(imageSource);
+ cv::Size originalSize = imageBGR.size();
+ cv::Mat imageRGB;
+ cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
+
+ return runInference(imageRGB, originalSize, classesOfInterest, resize);
+}
+
+semantic_segmentation::SegmentationResult
+BaseSemanticSegmentation::generateFromPixels(
+ JSTensorViewIn pixelData,
+ std::set> classesOfInterest, bool resize) {
+ cv::Mat image = extractFromPixels(pixelData);
+ return runInference(image, image.size(), classesOfInterest, resize);
}
-std::shared_ptr BaseSemanticSegmentation::postprocess(
+semantic_segmentation::SegmentationResult
+BaseSemanticSegmentation::generateFromFrame(
+ jsi::Runtime &runtime, const jsi::Value &frameData,
+ std::set> classesOfInterest, bool resize) {
+ cv::Mat frame = extractFromFrame(runtime, frameData);
+ return runInference(frame, frame.size(), classesOfInterest, resize);
+}
+
+semantic_segmentation::SegmentationResult
+BaseSemanticSegmentation::computeResult(
const Tensor &tensor, cv::Size originalSize,
std::vector &allClasses,
std::set> &classesOfInterest, bool resize) {
@@ -161,8 +185,8 @@ std::shared_ptr BaseSemanticSegmentation::postprocess(
}
// Filter classes of interest
- auto buffersToReturn = std::make_shared>>();
+ auto buffersToReturn = std::make_shared<
+ std::unordered_map>>();
for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) {
if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) {
(*buffersToReturn)[allClasses[cl]] = resultClasses[cl];
@@ -185,48 +209,7 @@ std::shared_ptr BaseSemanticSegmentation::postprocess(
}
}
- return populateDictionary(argmax, buffersToReturn);
-}
-
-std::shared_ptr BaseSemanticSegmentation::populateDictionary(
- std::shared_ptr argmax,
- std::shared_ptr>>
- classesToOutput) {
- auto promisePtr = std::make_shared>();
- std::future doneFuture = promisePtr->get_future();
-
- std::shared_ptr dictPtr = nullptr;
- callInvoker->invokeAsync(
- [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) {
- dictPtr = std::make_shared(runtime);
- auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax);
-
- auto int32ArrayCtor =
- runtime.global().getPropertyAsFunction(runtime, "Int32Array");
- auto int32Array =
- int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer)
- .getObject(runtime);
- dictPtr->setProperty(runtime, "ARGMAX", int32Array);
-
- for (auto &[classLabel, owningBuffer] : *classesToOutput) {
- auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer);
-
- auto float32ArrayCtor =
- runtime.global().getPropertyAsFunction(runtime, "Float32Array");
- auto float32Array =
- float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer)
- .getObject(runtime);
-
- dictPtr->setProperty(
- runtime, jsi::String::createFromAscii(runtime, classLabel.data()),
- float32Array);
- }
- promisePtr->set_value();
- });
-
- doneFuture.wait();
- return dictPtr;
+ return semantic_segmentation::SegmentationResult{argmax, buffersToReturn};
}
} // namespace rnexecutorch::models::semantic_segmentation
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h
index d39a7e5d4a..a30ae375bf 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h
@@ -1,23 +1,20 @@
#pragma once
-#include
-#include
#include
#include
#include
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include
-#include
+#include
+#include
namespace rnexecutorch {
namespace models::semantic_segmentation {
using namespace facebook;
using executorch::aten::Tensor;
-using executorch::extension::TensorPtr;
-class BaseSemanticSegmentation : public BaseModel {
+class BaseSemanticSegmentation : public VisionModel {
public:
BaseSemanticSegmentation(const std::string &modelSource,
std::vector normMean,
@@ -25,33 +22,42 @@ class BaseSemanticSegmentation : public BaseModel {
std::vector allClasses,
std::shared_ptr callInvoker);
- [[nodiscard("Registered non-void function")]] std::shared_ptr
- generate(std::string imageSource,
- std::set> classesOfInterest, bool resize);
+ [[nodiscard("Registered non-void function")]]
+ semantic_segmentation::SegmentationResult
+ generateFromString(std::string imageSource,
+ std::set> classesOfInterest,
+ bool resize);
+
+ [[nodiscard("Registered non-void function")]]
+ semantic_segmentation::SegmentationResult
+ generateFromPixels(JSTensorViewIn pixelData,
+ std::set> classesOfInterest,
+ bool resize);
+
+ [[nodiscard("Registered non-void function")]]
+ semantic_segmentation::SegmentationResult
+ generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData,
+ std::set> classesOfInterest,
+ bool resize);
protected:
- virtual TensorPtr preprocess(const std::string &imageSource,
- cv::Size &originalSize);
- virtual std::shared_ptr
- postprocess(const Tensor &tensor, cv::Size originalSize,
- std::vector &allClasses,
- std::set> &classesOfInterest,
- bool resize);
-
- cv::Size modelImageSize;
+ virtual semantic_segmentation::SegmentationResult
+ computeResult(const Tensor &tensor, cv::Size originalSize,
+ std::vector &allClasses,
+ std::set> &classesOfInterest,
+ bool resize);
std::size_t numModelPixels;
std::optional normMean_;
std::optional normStd_;
std::vector allClasses_;
- std::shared_ptr populateDictionary(
- std::shared_ptr argmax,
- std::shared_ptr>>
- classesToOutput);
-
private:
void initModelImageSize();
+
+ semantic_segmentation::SegmentationResult
+ runInference(cv::Mat image, cv::Size originalSize,
+ std::set> &classesOfInterest,
+ bool resize);
};
} // namespace models::semantic_segmentation
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h
new file mode 100644
index 0000000000..b305b96a70
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace rnexecutorch::models::semantic_segmentation {
+
+struct SegmentationResult {
+ std::shared_ptr argmax;
+ std::shared_ptr<
+ std::unordered_map>>
+ classBuffers;
+};
+
+} // namespace rnexecutorch::models::semantic_segmentation
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp
index 3b9c0187b9..ad94e76ba4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp
@@ -13,37 +13,31 @@ using executorch::extension::TensorPtr;
StyleTransfer::StyleTransfer(const std::string &modelSource,
std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker) {
+ : VisionModel(modelSource, callInvoker) {
auto inputShapes = getAllInputShapes();
if (inputShapes.size() == 0) {
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
"Model seems to not take any input tensors");
}
- std::vector modelInputShape = inputShapes[0];
- if (modelInputShape.size() < 2) {
+ modelInputShape_ = inputShapes[0];
+ if (modelInputShape_.size() < 2) {
char errorMessage[100];
std::snprintf(errorMessage, sizeof(errorMessage),
- "Unexpected model input size, expected at least 2 dimentions "
+ "Unexpected model input size, expected at least 2 dimensions "
"but got: %zu.",
- modelInputShape.size());
+ modelInputShape_.size());
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
errorMessage);
}
- modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
- modelInputShape[modelInputShape.size() - 2]);
}
-std::string StyleTransfer::postprocess(const Tensor &tensor,
- cv::Size originalSize) {
- cv::Mat mat = image_processing::getMatrixFromTensor(modelImageSize, tensor);
- cv::resize(mat, mat, originalSize);
+cv::Mat StyleTransfer::runInference(cv::Mat image, cv::Size outputSize) {
+ std::scoped_lock lock(inference_mutex_);
- return image_processing::saveToTempFile(mat);
-}
+ cv::Mat preprocessed = preprocess(image);
-std::string StyleTransfer::generate(std::string imageSource) {
- auto [inputTensor, originalSize] =
- image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]);
+ auto inputTensor =
+ image_processing::getTensorFromMatrix(modelInputShape_, preprocessed);
auto forwardResult = BaseModel::forward(inputTensor);
if (!forwardResult.ok()) {
@@ -52,7 +46,55 @@ std::string StyleTransfer::generate(std::string imageSource) {
"Ensure the model input is correct.");
}
- return postprocess(forwardResult->at(0).toTensor(), originalSize);
+ cv::Mat mat = image_processing::getMatrixFromTensor(
+ modelInputSize(), forwardResult->at(0).toTensor());
+ if (mat.size() != outputSize) {
+ cv::resize(mat, mat, outputSize);
+ }
+ return mat;
+}
+
+PixelDataResult toPixelDataResult(const cv::Mat &bgrMat) {
+ cv::Size size = bgrMat.size();
+ // Convert BGR -> RGBA so JS can pass the buffer directly to Skia
+ cv::Mat rgba;
+ cv::cvtColor(bgrMat, rgba, cv::COLOR_BGR2RGBA);
+ std::size_t dataSize = static_cast(size.width) * size.height * 4;
+ auto pixelBuffer = std::make_shared(rgba.data, dataSize);
+ return PixelDataResult{pixelBuffer, size.width, size.height, rgba.channels()};
+}
+
+StyleTransferResult StyleTransfer::generateFromString(std::string imageSource,
+ bool saveToFile) {
+ cv::Mat imageBGR = image_processing::readImage(imageSource);
+ cv::Size originalSize = imageBGR.size();
+
+ cv::Mat imageRGB;
+ cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
+
+ cv::Mat result = runInference(imageRGB, originalSize);
+ if (saveToFile) {
+ return image_processing::saveToTempFile(result);
+ }
+ return toPixelDataResult(result);
+}
+
+PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime,
+ const jsi::Value &frameData) {
+ cv::Mat frame = extractFromFrame(runtime, frameData);
+
+ return toPixelDataResult(runInference(frame, modelInputSize()));
+}
+
+StyleTransferResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData,
+ bool saveToFile) {
+ cv::Mat image = extractFromPixels(pixelData);
+
+ cv::Mat result = runInference(image, image.size());
+ if (saveToFile) {
+ return image_processing::saveToTempFile(result);
+ }
+ return toPixelDataResult(result);
}
} // namespace rnexecutorch::models::style_transfer
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h
index 73744c4d82..c15095bf5b 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h
@@ -9,25 +9,30 @@
#include
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include
+#include
+#include
+#include
namespace rnexecutorch {
namespace models::style_transfer {
using namespace facebook;
-using executorch::aten::Tensor;
-using executorch::extension::TensorPtr;
-class StyleTransfer : public BaseModel {
+class StyleTransfer : public VisionModel {
public:
StyleTransfer(const std::string &modelSource,
std::shared_ptr callInvoker);
- [[nodiscard("Registered non-void function")]] std::string
- generate(std::string imageSource);
-private:
- std::string postprocess(const Tensor &tensor, cv::Size originalSize);
+ [[nodiscard("Registered non-void function")]] StyleTransferResult
+ generateFromString(std::string imageSource, bool saveToFile);
+
+ [[nodiscard("Registered non-void function")]] PixelDataResult
+ generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData);
- cv::Size modelImageSize{0, 0};
+ [[nodiscard("Registered non-void function")]] StyleTransferResult
+ generateFromPixels(JSTensorViewIn pixelData, bool saveToFile);
+
+private:
+ cv::Mat runInference(cv::Mat image, cv::Size outputSize);
};
} // namespace models::style_transfer
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h
new file mode 100644
index 0000000000..27df4ec6c6
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace rnexecutorch::models::style_transfer {
+
+struct PixelDataResult {
+ std::shared_ptr dataPtr;
+ int width;
+ int height;
+ int channels;
+};
+
+using StyleTransferResult = std::variant;
+
+} // namespace rnexecutorch::models::style_transfer
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp
index 0f75d20152..fef78d7953 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp
@@ -1,10 +1,11 @@
#include "VerticalOCR.h"
-#include
#include
+#include
#include
#include
#include
#include
+#include
#include
namespace rnexecutorch::models::ocr {
@@ -16,12 +17,9 @@ VerticalOCR::VerticalOCR(const std::string &detectorSource,
converter(symbols), independentCharacters(independentChars),
callInvoker(invoker) {}
-std::vector VerticalOCR::generate(std::string input) {
- cv::Mat image = image_processing::readImage(input);
- if (image.empty()) {
- throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
- "Failed to load image from path: " + input);
- }
+std::vector VerticalOCR::runInference(cv::Mat image) {
+ std::scoped_lock lock(inference_mutex_);
+
// 1. Large Detector
std::vector largeBoxes =
detector.generate(image, constants::kLargeDetectorWidth);
@@ -44,6 +42,41 @@ std::vector VerticalOCR::generate(std::string input) {
return predictions;
}
+std::vector
+VerticalOCR::generateFromString(std::string input) {
+ cv::Mat image = image_processing::readImage(input);
+ if (image.empty()) {
+ throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
+ "Failed to load image from path: " + input);
+ }
+ return runInference(image);
+}
+
+std::vector
+VerticalOCR::generateFromFrame(jsi::Runtime &runtime,
+ const jsi::Value &frameData) {
+ cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData);
+ cv::Mat bgr;
+#ifdef __APPLE__
+ cv::cvtColor(frame, bgr, cv::COLOR_BGRA2BGR);
+#elif defined(__ANDROID__)
+ cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR);
+#else
+ throw RnExecutorchError(
+ RnExecutorchErrorCode::PlatformNotSupported,
+ "generateFromFrame is not supported on this platform");
+#endif
+ return runInference(bgr);
+}
+
+std::vector
+VerticalOCR::generateFromPixels(JSTensorViewIn pixelData) {
+ cv::Mat image;
+ cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image,
+ cv::COLOR_RGB2BGR);
+ return runInference(image);
+}
+
std::size_t VerticalOCR::getMemoryLowerBound() const noexcept {
return detector.getMemoryLowerBound() + recognizer.getMemoryLowerBound();
}
@@ -176,6 +209,7 @@ types::OCRDetection VerticalOCR::_processSingleTextBox(
}
void VerticalOCR::unload() noexcept {
+ std::scoped_lock lock(inference_mutex_);
detector.unload();
recognizer.unload();
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h
index e97fb90348..4016e28138 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h
@@ -1,12 +1,14 @@
#pragma once
#include
+#include
#include
#include
#include
#include
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
+#include
#include
#include
#include
@@ -48,11 +50,17 @@ class VerticalOCR final {
bool indpendentCharacters,
std::shared_ptr callInvoker);
[[nodiscard("Registered non-void function")]] std::vector
- generate(std::string input);
+ generateFromString(std::string input);
+ [[nodiscard("Registered non-void function")]] std::vector
+ generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData);
+ [[nodiscard("Registered non-void function")]] std::vector
+ generateFromPixels(JSTensorViewIn pixelData);
std::size_t getMemoryLowerBound() const noexcept;
void unload() noexcept;
private:
+ std::vector runInference(cv::Mat image);
+
std::pair _handleIndependentCharacters(
const types::DetectorBBox &box, const cv::Mat &originalImage,
const std::vector &characterBoxes,
@@ -75,6 +83,7 @@ class VerticalOCR final {
CTCLabelConverter converter;
bool independentCharacters;
std::shared_ptr callInvoker;
+ mutable std::mutex inference_mutex_;
};
} // namespace models::ocr
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
index 426aafc1f3..bd359dcb8f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
@@ -157,13 +157,33 @@ add_rn_test(ImageProcessingTest unit/ImageProcessingTest.cpp
LIBS opencv_deps
)
+add_rn_test(FrameProcessorTests unit/FrameProcessorTest.cpp
+ SOURCES
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
+ ${IMAGE_UTILS_SOURCES}
+ LIBS opencv_deps android
+)
+
add_rn_test(BaseModelTests integration/BaseModelTest.cpp)
+add_rn_test(VisionModelTests integration/VisionModelTest.cpp
+ SOURCES
+ ${RNEXECUTORCH_DIR}/models/VisionModel.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
+ ${IMAGE_UTILS_SOURCES}
+ LIBS opencv_deps android
+)
+
add_rn_test(ClassificationTests integration/ClassificationTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/classification/Classification.cpp
+ ${RNEXECUTORCH_DIR}/models/VisionModel.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${IMAGE_UTILS_SOURCES}
- LIBS opencv_deps
+ LIBS opencv_deps android
)
add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp
@@ -181,8 +201,11 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp
${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
+ ${RNEXECUTORCH_DIR}/models/VisionModel.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${IMAGE_UTILS_SOURCES}
- LIBS opencv_deps
+ LIBS opencv_deps android
)
add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp
@@ -196,8 +219,11 @@ add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp
add_rn_test(StyleTransferTests integration/StyleTransferTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/style_transfer/StyleTransfer.cpp
+ ${RNEXECUTORCH_DIR}/models/VisionModel.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${IMAGE_UTILS_SOURCES}
- LIBS opencv_deps
+ LIBS opencv_deps android
)
add_rn_test(VADTests integration/VoiceActivityDetectionTest.cpp
@@ -273,8 +299,10 @@ add_rn_test(OCRTests integration/OCRTest.cpp
${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp
${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp
${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${IMAGE_UTILS_SOURCES}
- LIBS opencv_deps
+ LIBS opencv_deps android
)
add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp
@@ -287,6 +315,8 @@ add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp
${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp
${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp
${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
+ ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${IMAGE_UTILS_SOURCES}
- LIBS opencv_deps
+ LIBS opencv_deps android
)
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp
index 10aa663a4a..d164fcacb0 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp
@@ -1,6 +1,9 @@
#include "BaseModelTests.h"
+#include "VisionModelTests.h"
+#include
#include
#include
+#include
#include
#include
@@ -28,7 +31,7 @@ template <> struct ModelTraits {
}
static void callGenerate(ModelType &model) {
- (void)model.generate(kValidTestImagePath);
+ (void)model.generateFromString(kValidTestImagePath);
}
};
} // namespace model_tests
@@ -36,43 +39,45 @@ template <> struct ModelTraits {
using ClassificationTypes = ::testing::Types;
INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest,
ClassificationTypes);
+INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest,
+ ClassificationTypes);
// ============================================================================
// Model-specific tests
// ============================================================================
TEST(ClassificationGenerateTests, InvalidImagePathThrows) {
Classification model(kValidClassificationModelPath, nullptr);
- EXPECT_THROW((void)model.generate("nonexistent_image.jpg"),
+ EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"),
RnExecutorchError);
}
TEST(ClassificationGenerateTests, EmptyImagePathThrows) {
Classification model(kValidClassificationModelPath, nullptr);
- EXPECT_THROW((void)model.generate(""), RnExecutorchError);
+ EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError);
}
TEST(ClassificationGenerateTests, MalformedURIThrows) {
Classification model(kValidClassificationModelPath, nullptr);
- EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"),
+ EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"),
RnExecutorchError);
}
TEST(ClassificationGenerateTests, ValidImageReturnsResults) {
Classification model(kValidClassificationModelPath, nullptr);
- auto results = model.generate(kValidTestImagePath);
+ auto results = model.generateFromString(kValidTestImagePath);
EXPECT_FALSE(results.empty());
}
TEST(ClassificationGenerateTests, ResultsHaveCorrectSize) {
Classification model(kValidClassificationModelPath, nullptr);
- auto results = model.generate(kValidTestImagePath);
+ auto results = model.generateFromString(kValidTestImagePath);
auto expectedNumClasses = constants::kImagenet1kV1Labels.size();
EXPECT_EQ(results.size(), expectedNumClasses);
}
TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) {
Classification model(kValidClassificationModelPath, nullptr);
- auto results = model.generate(kValidTestImagePath);
+ auto results = model.generateFromString(kValidTestImagePath);
float sum = 0.0f;
for (const auto &[label, prob] : results) {
@@ -85,7 +90,7 @@ TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) {
TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) {
Classification model(kValidClassificationModelPath, nullptr);
- auto results = model.generate(kValidTestImagePath);
+ auto results = model.generateFromString(kValidTestImagePath);
float maxProb = 0.0f;
for (const auto &[label, prob] : results) {
@@ -115,3 +120,15 @@ TEST(ClassificationInheritedTests, GetMethodMetaWorks) {
auto result = model.getMethodMeta("forward");
EXPECT_TRUE(result.ok());
}
+
+// ============================================================================
+// generateFromPixels smoke test
+// ============================================================================
+TEST(ClassificationPixelTests, ValidPixelsReturnsResults) {
+ Classification model(kValidClassificationModelPath, nullptr);
+ std::vector buf(64 * 64 * 3, 128);
+ JSTensorViewIn view{
+ buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte};
+ auto results = model.generateFromPixels(view);
+ EXPECT_FALSE(results.empty());
+}
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp
index 3a23746957..4982206614 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp
@@ -1,7 +1,10 @@
#include "BaseModelTests.h"
+#include "VisionModelTests.h"
#include
+#include
#include
#include
+#include
#include