diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 4ce2f3e5c2..cac2692f07 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -59,6 +59,15 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} > + - Select a demo model + router.navigate('vision_camera/')} + > + Vision Camera + router.navigate('classification/')} @@ -29,12 +35,6 @@ export default function Home() { > Object Detection - router.navigate('object_detection_live/')} - > - Object Detection Live - router.navigate('ocr/')} diff --git a/apps/computer-vision/app/object_detection_live/index.tsx b/apps/computer-vision/app/object_detection_live/index.tsx deleted file mode 100644 index 3db2c53602..0000000000 --- a/apps/computer-vision/app/object_detection_live/index.tsx +++ /dev/null @@ -1,222 +0,0 @@ -import React, { - useCallback, - useContext, - useEffect, - useMemo, - useRef, - useState, -} from 'react'; -import { - StatusBar, - StyleSheet, - Text, - TouchableOpacity, - View, -} from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; - -import { - Camera, - getCameraFormat, - Templates, - useCameraDevices, - useCameraPermission, - useFrameOutput, -} from 'react-native-vision-camera'; -import { scheduleOnRN } from 'react-native-worklets'; -import { - Detection, - SSDLITE_320_MOBILENET_V3_LARGE, - useObjectDetection, -} from 'react-native-executorch'; -import { GeneratingContext } from '../../context'; -import Spinner from '../../components/Spinner'; -import ColorPalette from '../../colors'; - -export default function ObjectDetectionLiveScreen() { - const insets = useSafeAreaInsets(); - - const model = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE }); - const { setGlobalGenerating } = useContext(GeneratingContext); - - useEffect(() => { - setGlobalGenerating(model.isGenerating); - }, [model.isGenerating, setGlobalGenerating]); - const 
[detectionCount, setDetectionCount] = useState(0); - const [fps, setFps] = useState(0); - const lastFrameTimeRef = useRef(Date.now()); - - const cameraPermission = useCameraPermission(); - const devices = useCameraDevices(); - const device = devices.find((d) => d.position === 'back') ?? devices[0]; - - const format = useMemo(() => { - if (device == null) return undefined; - try { - return getCameraFormat(device, Templates.FrameProcessing); - } catch { - return undefined; - } - }, [device]); - - const updateStats = useCallback((results: Detection[]) => { - setDetectionCount(results.length); - const now = Date.now(); - const timeDiff = now - lastFrameTimeRef.current; - if (timeDiff > 0) { - setFps(Math.round(1000 / timeDiff)); - } - lastFrameTimeRef.current = now; - }, []); - - const frameOutput = useFrameOutput({ - pixelFormat: 'rgb', - dropFramesWhileBusy: true, - onFrame(frame) { - 'worklet'; - if (!model.runOnFrame) { - frame.dispose(); - return; - } - try { - const result = model.runOnFrame(frame, 0.5); - if (result) { - scheduleOnRN(updateStats, result); - } - } catch { - // ignore frame errors - } finally { - frame.dispose(); - } - }, - }); - - if (!model.isReady) { - return ( - - ); - } - - if (!cameraPermission.hasPermission) { - return ( - - Camera access needed - cameraPermission.requestPermission()} - style={styles.button} - > - Grant Permission - - - ); - } - - if (device == null) { - return ( - - No camera device found - - ); - } - - return ( - - - - - - - - - {detectionCount} - objects - - - - {fps} - fps - - - - - ); -} - -const styles = StyleSheet.create({ - container: { - flex: 1, - backgroundColor: 'black', - }, - centered: { - flex: 1, - backgroundColor: 'black', - justifyContent: 'center', - alignItems: 'center', - gap: 16, - }, - message: { - color: 'white', - fontSize: 18, - }, - button: { - paddingHorizontal: 24, - paddingVertical: 12, - backgroundColor: ColorPalette.primary, - borderRadius: 24, - }, - buttonText: { - color: 'white', - 
fontSize: 15, - fontWeight: '600', - letterSpacing: 0.3, - }, - bottomBarWrapper: { - position: 'absolute', - bottom: 0, - left: 0, - right: 0, - alignItems: 'center', - }, - bottomBar: { - flexDirection: 'row', - alignItems: 'center', - backgroundColor: 'rgba(0, 0, 0, 0.55)', - borderRadius: 24, - paddingHorizontal: 28, - paddingVertical: 10, - gap: 24, - }, - statItem: { - alignItems: 'center', - }, - statValue: { - color: 'white', - fontSize: 22, - fontWeight: '700', - letterSpacing: -0.5, - }, - statLabel: { - color: 'rgba(255,255,255,0.55)', - fontSize: 11, - fontWeight: '500', - textTransform: 'uppercase', - letterSpacing: 0.8, - }, - statDivider: { - width: 1, - height: 32, - backgroundColor: 'rgba(255,255,255,0.2)', - }, -}); diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index dc6a0d4963..46ae3e814a 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -16,20 +16,24 @@ export default function StyleTransferScreen() { useEffect(() => { setGlobalGenerating(model.isGenerating); }, [model.isGenerating, setGlobalGenerating]); + const [imageUri, setImageUri] = useState(''); + const [styledUri, setStyledUri] = useState(''); + const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); const uri = image?.uri; if (typeof uri === 'string') { - setImageUri(uri as string); + setImageUri(uri); + setStyledUri(''); } }; const runForward = async () => { if (imageUri) { try { - const output = await model.forward(imageUri); - setImageUri(output); + const uri = await model.forward(imageUri, 'url'); + setStyledUri(uri); } catch (e) { console.error(e); } @@ -52,9 +56,11 @@ export default function StyleTransferScreen() { style={styles.image} resizeMode="contain" source={ - imageUri - ? { uri: imageUri } - : require('../../assets/icons/executorch_logo.png') + styledUri + ? { uri: styledUri } + : imageUri + ? 
{ uri: imageUri } + : require('../../assets/icons/executorch_logo.png') } /> diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx new file mode 100644 index 0000000000..b2af60d504 --- /dev/null +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -0,0 +1,442 @@ +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useState, +} from 'react'; +import { + ScrollView, + StatusBar, + StyleSheet, + Text, + TouchableOpacity, + View, +} from 'react-native'; +import { useSafeAreaInsets } from 'react-native-safe-area-context'; +import { useIsFocused } from '@react-navigation/native'; +import { + Camera, + getCameraFormat, + Templates, + useCameraDevices, + useCameraPermission, + useFrameOutput, +} from 'react-native-vision-camera'; +import { createSynchronizable } from 'react-native-worklets'; +import Svg, { Path, Polygon } from 'react-native-svg'; +import { GeneratingContext } from '../../context'; +import Spinner from '../../components/Spinner'; +import ColorPalette from '../../colors'; +import ClassificationTask from '../../components/vision_camera/tasks/ClassificationTask'; +import ObjectDetectionTask from '../../components/vision_camera/tasks/ObjectDetectionTask'; +import SegmentationTask from '../../components/vision_camera/tasks/SegmentationTask'; + +type TaskId = 'classification' | 'objectDetection' | 'segmentation'; +type ModelId = + | 'classification' + | 'objectDetectionSsdlite' + | 'objectDetectionRfdetr' + | 'segmentationDeeplabResnet50' + | 'segmentationDeeplabResnet101' + | 'segmentationDeeplabMobilenet' + | 'segmentationLraspp' + | 'segmentationFcnResnet50' + | 'segmentationFcnResnet101' + | 'segmentationSelfie'; + +type TaskVariant = { id: ModelId; label: string }; +type Task = { id: TaskId; label: string; variants: TaskVariant[] }; + +const TASKS: Task[] = [ + { + id: 'classification', + label: 'Classify', + variants: [{ id: 'classification', label: 'EfficientNet V2 S' }], + 
}, + { + id: 'segmentation', + label: 'Segment', + variants: [ + { id: 'segmentationDeeplabResnet50', label: 'DeepLab ResNet50' }, + { id: 'segmentationDeeplabResnet101', label: 'DeepLab ResNet101' }, + { id: 'segmentationDeeplabMobilenet', label: 'DeepLab MobileNet' }, + { id: 'segmentationLraspp', label: 'LRASPP MobileNet' }, + { id: 'segmentationFcnResnet50', label: 'FCN ResNet50' }, + { id: 'segmentationFcnResnet101', label: 'FCN ResNet101' }, + { id: 'segmentationSelfie', label: 'Selfie' }, + ], + }, + { + id: 'objectDetection', + label: 'Detect', + variants: [ + { id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' }, + { id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' }, + ], + }, +]; + +// Module-level const so worklets in task components can always reference the same stable object. +// Never replaced — only mutated via setBlocking to avoid closure staleness. +const frameKillSwitch = createSynchronizable(false); + +export default function VisionCameraScreen() { + const insets = useSafeAreaInsets(); + const [activeTask, setActiveTask] = useState('classification'); + const [activeModel, setActiveModel] = useState('classification'); + const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); + const [cameraPosition, setCameraPosition] = useState<'front' | 'back'>( + 'back' + ); + const [fps, setFps] = useState(0); + const [frameMs, setFrameMs] = useState(0); + const [isReady, setIsReady] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + const [frameOutput, setFrameOutput] = useState | null>(null); + const { setGlobalGenerating } = useContext(GeneratingContext); + + const isFocused = useIsFocused(); + const cameraPermission = useCameraPermission(); + const devices = useCameraDevices(); + const device = + devices.find((d) => d.position === cameraPosition) ?? 
devices[0]; + const format = useMemo(() => { + if (device == null) return undefined; + try { + return getCameraFormat(device, Templates.FrameProcessing); + } catch { + return undefined; + } + }, [device]); + + useEffect(() => { + frameKillSwitch.setBlocking(true); + const id = setTimeout(() => { + frameKillSwitch.setBlocking(false); + }, 300); + return () => clearTimeout(id); + }, [activeModel]); + + const handleFpsChange = useCallback((newFps: number, newMs: number) => { + setFps(newFps); + setFrameMs(newMs); + }, []); + + const handleGeneratingChange = useCallback( + (generating: boolean) => { + setGlobalGenerating(generating); + }, + [setGlobalGenerating] + ); + + if (!cameraPermission.hasPermission) { + return ( + + Camera access needed + cameraPermission.requestPermission()} + style={styles.button} + > + Grant Permission + + + ); + } + + if (device == null) { + return ( + + No camera device found + + ); + } + + const activeTaskInfo = TASKS.find((t) => t.id === activeTask)!; + const activeVariantLabel = + activeTaskInfo.variants.find((v) => v.id === activeModel)?.label ?? 
+ activeTaskInfo.variants[0]!.label; + + const taskProps = { + activeModel, + canvasSize, + cameraPosition, + frameKillSwitch, + onFrameOutputChange: setFrameOutput, + onReadyChange: setIsReady, + onProgressChange: setDownloadProgress, + onGeneratingChange: handleGeneratingChange, + onFpsChange: handleFpsChange, + }; + + return ( + + + + + + {/* Layout sentinel — measures the full-screen area for bbox/canvas sizing */} + + setCanvasSize({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + /> + + {activeTask === 'classification' && } + {activeTask === 'objectDetection' && ( + + )} + {activeTask === 'segmentation' && ( + + )} + + {!isReady && ( + + + + )} + + + + {activeVariantLabel} + + {fps} FPS – {frameMs.toFixed(0)} ms + + + + + {TASKS.map((t) => ( + { + setActiveTask(t.id); + setActiveModel(t.variants[0]!.id); + }} + > + + {t.label} + + + ))} + + + + {activeTaskInfo.variants.map((v) => ( + setActiveModel(v.id)} + > + + {v.label} + + + ))} + + + + + + setCameraPosition((p) => (p === 'back' ? 
'front' : 'back')) + } + > + + + + + + + + + ); +} + +const styles = StyleSheet.create({ + container: { flex: 1, backgroundColor: 'black' }, + centered: { + flex: 1, + backgroundColor: 'black', + justifyContent: 'center', + alignItems: 'center', + gap: 16, + }, + message: { color: 'white', fontSize: 18 }, + button: { + paddingHorizontal: 24, + paddingVertical: 12, + backgroundColor: ColorPalette.primary, + borderRadius: 24, + }, + buttonText: { color: 'white', fontSize: 15, fontWeight: '600' }, + loadingOverlay: { + ...StyleSheet.absoluteFillObject, + backgroundColor: 'rgba(0,0,0,0.6)', + justifyContent: 'center', + alignItems: 'center', + }, + topOverlay: { + position: 'absolute', + top: 0, + left: 0, + right: 0, + alignItems: 'center', + gap: 8, + }, + titleRow: { + alignItems: 'center', + paddingHorizontal: 16, + }, + modelTitle: { + color: 'white', + fontSize: 22, + fontWeight: '700', + textShadowColor: 'rgba(0,0,0,0.7)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 4, + }, + fpsText: { + color: 'rgba(255,255,255,0.85)', + fontSize: 14, + fontWeight: '500', + marginTop: 2, + textShadowColor: 'rgba(0,0,0,0.7)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 4, + }, + tabsContent: { + paddingHorizontal: 12, + gap: 6, + }, + tab: { + paddingHorizontal: 18, + paddingVertical: 7, + borderRadius: 20, + backgroundColor: 'rgba(0,0,0,0.45)', + borderWidth: 1, + borderColor: 'rgba(255,255,255,0.25)', + }, + tabActive: { + backgroundColor: 'rgba(255,255,255,0.2)', + borderColor: 'white', + }, + tabText: { + color: 'rgba(255,255,255,0.7)', + fontSize: 14, + fontWeight: '600', + }, + tabTextActive: { color: 'white' }, + chipsContent: { + paddingHorizontal: 12, + gap: 6, + }, + variantChip: { + paddingHorizontal: 14, + paddingVertical: 5, + borderRadius: 16, + backgroundColor: 'rgba(0,0,0,0.35)', + borderWidth: 1, + borderColor: 'rgba(255,255,255,0.15)', + }, + variantChipActive: { + backgroundColor: ColorPalette.primary, + borderColor: 
ColorPalette.primary, + }, + variantChipText: { + color: 'rgba(255,255,255,0.6)', + fontSize: 12, + fontWeight: '500', + }, + variantChipTextActive: { color: 'white' }, + bottomOverlay: { + position: 'absolute', + bottom: 0, + left: 0, + right: 0, + alignItems: 'center', + }, + flipButton: { + width: 56, + height: 56, + borderRadius: 28, + backgroundColor: 'rgba(255,255,255,0.2)', + justifyContent: 'center', + alignItems: 'center', + borderWidth: 1.5, + borderColor: 'rgba(255,255,255,0.4)', + }, +}); diff --git a/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx new file mode 100644 index 0000000000..c9b4a2bf21 --- /dev/null +++ b/apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx @@ -0,0 +1,120 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { StyleSheet, Text, View } from 'react-native'; +import { Frame, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { EFFICIENTNET_V2_S, useClassification } from 'react-native-executorch'; +import { TaskProps } from './types'; + +type Props = Omit; + +export default function ClassificationTask({ + frameKillSwitch, + onFrameOutputChange, + onReadyChange, + onProgressChange, + onGeneratingChange, + onFpsChange, +}: Props) { + const model = useClassification({ model: EFFICIENTNET_V2_S }); + const [classResult, setClassResult] = useState({ label: '', score: 0 }); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onReadyChange(model.isReady); + }, [model.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(model.downloadProgress); + }, [model.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(model.isGenerating); + }, [model.isGenerating, onGeneratingChange]); + + const classRof = model.runOnFrame; + + const updateClass = useCallback( + (r: { 
label: string; score: number }) => { + setClassResult(r); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!classRof) return; + const result = classRof(frame); + if (result) { + let bestLabel = ''; + let bestScore = -1; + const entries = Object.entries(result); + for (let i = 0; i < entries.length; i++) { + const [label, score] = entries[i]!; + if ((score as number) > bestScore) { + bestScore = score as number; + bestLabel = label; + } + } + scheduleOnRN(updateClass, { label: bestLabel, score: bestScore }); + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [classRof, frameKillSwitch, updateClass] + ), + }); + + useEffect(() => { + onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + return classResult.label ? 
( + + {classResult.label} + {(classResult.score * 100).toFixed(1)}% + + ) : null; +} + +const styles = StyleSheet.create({ + overlay: { + ...StyleSheet.absoluteFillObject, + justifyContent: 'center', + alignItems: 'center', + }, + label: { + color: 'white', + fontSize: 28, + fontWeight: '700', + textAlign: 'center', + textShadowColor: 'rgba(0,0,0,0.8)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 6, + paddingHorizontal: 24, + }, + score: { + color: 'rgba(255,255,255,0.75)', + fontSize: 18, + fontWeight: '500', + marginTop: 4, + textShadowColor: 'rgba(0,0,0,0.8)', + textShadowOffset: { width: 0, height: 1 }, + textShadowRadius: 6, + }, +}); diff --git a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx new file mode 100644 index 0000000000..a54d20c87e --- /dev/null +++ b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx @@ -0,0 +1,174 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { StyleSheet, Text, View } from 'react-native'; +import { Frame, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + Detection, + RF_DETR_NANO, + SSDLITE_320_MOBILENET_V3_LARGE, + useObjectDetection, +} from 'react-native-executorch'; +import { labelColor, labelColorBg } from '../utils/colors'; +import { TaskProps } from './types'; + +type ObjModelId = 'objectDetectionSsdlite' | 'objectDetectionRfdetr'; + +type Props = TaskProps & { activeModel: ObjModelId }; + +export default function ObjectDetectionTask({ + activeModel, + canvasSize, + cameraPosition, + frameKillSwitch, + onFrameOutputChange, + onReadyChange, + onProgressChange, + onGeneratingChange, + onFpsChange, +}: Props) { + const ssdlite = useObjectDetection({ + model: SSDLITE_320_MOBILENET_V3_LARGE, + preventLoad: activeModel !== 'objectDetectionSsdlite', + }); + const rfdetr = 
useObjectDetection({ + model: RF_DETR_NANO, + preventLoad: activeModel !== 'objectDetectionRfdetr', + }); + + const active = activeModel === 'objectDetectionSsdlite' ? ssdlite : rfdetr; + + const [detections, setDetections] = useState([]); + const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onReadyChange(active.isReady); + }, [active.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(active.downloadProgress); + }, [active.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(active.isGenerating); + }, [active.isGenerating, onGeneratingChange]); + + const detRof = active.runOnFrame; + + const updateDetections = useCallback( + (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { + setDetections(p.results); + setImageSize({ width: p.imageWidth, height: p.imageHeight }); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!detRof) return; + const iw = frame.width > frame.height ? frame.height : frame.width; + const ih = frame.width > frame.height ? 
frame.width : frame.height; + const result = detRof(frame, 0.5); + if (result) { + scheduleOnRN(updateDetections, { + results: result, + imageWidth: iw, + imageHeight: ih, + }); + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [detRof, frameKillSwitch, updateDetections] + ), + }); + + useEffect(() => { + onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + const scale = Math.max( + canvasSize.width / imageSize.width, + canvasSize.height / imageSize.height + ); + const offsetX = (canvasSize.width - imageSize.width * scale) / 2; + const offsetY = (canvasSize.height - imageSize.height * scale) / 2; + + return ( + + {detections.map((det, i) => { + const left = det.bbox.x1 * scale + offsetX; + const top = det.bbox.y1 * scale + offsetY; + const w = (det.bbox.x2 - det.bbox.x1) * scale; + const h = (det.bbox.y2 - det.bbox.y1) * scale; + return ( + + + + {det.label} {(det.score * 100).toFixed(1)} + + + + ); + })} + + ); +} + +const styles = StyleSheet.create({ + bbox: { + position: 'absolute', + borderWidth: 2, + borderColor: 'cyan', + borderRadius: 4, + }, + bboxLabel: { + position: 'absolute', + top: -22, + left: -2, + paddingHorizontal: 6, + paddingVertical: 2, + borderRadius: 4, + }, + bboxLabelText: { color: 'white', fontSize: 11, fontWeight: '600' }, +}); diff --git a/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx new file mode 100644 index 0000000000..8226b0aae9 --- /dev/null +++ b/apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx @@ -0,0 +1,212 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { StyleSheet, View } from 'react-native'; +import { Frame, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + DEEPLAB_V3_RESNET50_QUANTIZED, + DEEPLAB_V3_RESNET101_QUANTIZED, + 
DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + FCN_RESNET50_QUANTIZED, + FCN_RESNET101_QUANTIZED, + LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + SELFIE_SEGMENTATION, + useSemanticSegmentation, +} from 'react-native-executorch'; +import { + AlphaType, + Canvas, + ColorType, + Image as SkiaImage, + Skia, + SkImage, +} from '@shopify/react-native-skia'; +import { CLASS_COLORS } from '../utils/colors'; +import { TaskProps } from './types'; + +type SegModelId = + | 'segmentationDeeplabResnet50' + | 'segmentationDeeplabResnet101' + | 'segmentationDeeplabMobilenet' + | 'segmentationLraspp' + | 'segmentationFcnResnet50' + | 'segmentationFcnResnet101' + | 'segmentationSelfie'; + +type Props = TaskProps & { activeModel: SegModelId }; + +export default function SegmentationTask({ + activeModel, + canvasSize, + cameraPosition, + frameKillSwitch, + onFrameOutputChange, + onReadyChange, + onProgressChange, + onGeneratingChange, + onFpsChange, +}: Props) { + const segDeeplabResnet50 = useSemanticSegmentation({ + model: DEEPLAB_V3_RESNET50_QUANTIZED, + preventLoad: activeModel !== 'segmentationDeeplabResnet50', + }); + const segDeeplabResnet101 = useSemanticSegmentation({ + model: DEEPLAB_V3_RESNET101_QUANTIZED, + preventLoad: activeModel !== 'segmentationDeeplabResnet101', + }); + const segDeeplabMobilenet = useSemanticSegmentation({ + model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + preventLoad: activeModel !== 'segmentationDeeplabMobilenet', + }); + const segLraspp = useSemanticSegmentation({ + model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + preventLoad: activeModel !== 'segmentationLraspp', + }); + const segFcnResnet50 = useSemanticSegmentation({ + model: FCN_RESNET50_QUANTIZED, + preventLoad: activeModel !== 'segmentationFcnResnet50', + }); + const segFcnResnet101 = useSemanticSegmentation({ + model: FCN_RESNET101_QUANTIZED, + preventLoad: activeModel !== 'segmentationFcnResnet101', + }); + const segSelfie = useSemanticSegmentation({ + model: SELFIE_SEGMENTATION, + preventLoad: 
activeModel !== 'segmentationSelfie', + }); + + const active = { + segmentationDeeplabResnet50: segDeeplabResnet50, + segmentationDeeplabResnet101: segDeeplabResnet101, + segmentationDeeplabMobilenet: segDeeplabMobilenet, + segmentationLraspp: segLraspp, + segmentationFcnResnet50: segFcnResnet50, + segmentationFcnResnet101: segFcnResnet101, + segmentationSelfie: segSelfie, + }[activeModel]; + + const [maskImage, setMaskImage] = useState(null); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onReadyChange(active.isReady); + }, [active.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(active.downloadProgress); + }, [active.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(active.isGenerating); + }, [active.isGenerating, onGeneratingChange]); + + // Clear stale mask when the segmentation model variant changes + useEffect(() => { + setMaskImage((prev) => { + prev?.dispose(); + return null; + }); + }, [activeModel]); + + // Dispose native Skia image on unmount to prevent memory leaks + useEffect(() => { + return () => { + setMaskImage((prev) => { + prev?.dispose(); + return null; + }); + }; + }, []); + + const segRof = active.runOnFrame; + + const updateMask = useCallback( + (img: SkImage) => { + setMaskImage((prev) => { + prev?.dispose(); + return img; + }); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + // CLASS_COLORS captured directly in closure — worklets cannot import modules + const colors = CLASS_COLORS; + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!segRof) return; + const result = segRof(frame, [], false); + if (result?.ARGMAX) { + const argmax: Int32Array = 
result.ARGMAX; + const side = Math.round(Math.sqrt(argmax.length)); + const pixels = new Uint8Array(side * side * 4); + for (let i = 0; i < argmax.length; i++) { + const color = colors[argmax[i]!] ?? [0, 0, 0, 0]; + pixels[i * 4] = color[0]!; + pixels[i * 4 + 1] = color[1]!; + pixels[i * 4 + 2] = color[2]!; + pixels[i * 4 + 3] = color[3]!; + } + const skData = Skia.Data.fromBytes(pixels); + const img = Skia.Image.MakeImage( + { + width: side, + height: side, + alphaType: AlphaType.Unpremul, + colorType: ColorType.RGBA_8888, + }, + skData, + side * 4 + ); + if (img) scheduleOnRN(updateMask, img); + } + } catch { + // ignore + } finally { + frame.dispose(); + } + }, + [colors, frameKillSwitch, segRof, updateMask] + ), + }); + + useEffect(() => { + onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + if (!maskImage) return null; + + return ( + + + + + + ); +} diff --git a/apps/computer-vision/components/vision_camera/tasks/types.ts b/apps/computer-vision/components/vision_camera/tasks/types.ts new file mode 100644 index 0000000000..9727227f2f --- /dev/null +++ b/apps/computer-vision/components/vision_camera/tasks/types.ts @@ -0,0 +1,14 @@ +import { useFrameOutput } from 'react-native-vision-camera'; +import { createSynchronizable } from 'react-native-worklets'; + +export type TaskProps = { + activeModel: string; + canvasSize: { width: number; height: number }; + cameraPosition: 'front' | 'back'; + frameKillSwitch: ReturnType>; + onFrameOutputChange: (frameOutput: ReturnType) => void; + onReadyChange: (isReady: boolean) => void; + onProgressChange: (progress: number) => void; + onGeneratingChange: (isGenerating: boolean) => void; + onFpsChange: (fps: number, frameMs: number) => void; +}; diff --git a/apps/computer-vision/components/vision_camera/utils/colors.ts b/apps/computer-vision/components/vision_camera/utils/colors.ts new file mode 100644 index 0000000000..c38493a3b0 --- /dev/null +++ 
b/apps/computer-vision/components/vision_camera/utils/colors.ts @@ -0,0 +1,41 @@ +export const CLASS_COLORS: number[][] = [ + [0, 0, 0, 0], + [51, 255, 87, 180], + [51, 87, 255, 180], + [255, 51, 246, 180], + [51, 255, 246, 180], + [243, 255, 51, 180], + [141, 51, 255, 180], + [255, 131, 51, 180], + [51, 255, 131, 180], + [131, 51, 255, 180], + [255, 255, 51, 180], + [51, 255, 255, 180], + [255, 51, 143, 180], + [127, 51, 255, 180], + [51, 255, 175, 180], + [255, 175, 51, 180], + [179, 255, 51, 180], + [255, 87, 51, 180], + [255, 51, 162, 180], + [51, 162, 255, 180], + [162, 51, 255, 180], +]; + +export function hashLabel(label: string): number { + let hash = 5381; + for (let i = 0; i < label.length; i++) { + hash = (hash + hash * 32 + label.charCodeAt(i)) % 1000003; + } + return 1 + (Math.abs(hash) % (CLASS_COLORS.length - 1)); +} + +export function labelColor(label: string): string { + const color = CLASS_COLORS[hashLabel(label)]!; + return `rgba(${color[0]},${color[1]},${color[2]},1)`; +} + +export function labelColorBg(label: string): string { + const color = CLASS_COLORS[hashLabel(label)]!; + return `rgba(${color[0]},${color[1]},${color[2]},0.75)`; +} diff --git a/apps/computer-vision/package.json b/apps/computer-vision/package.json index d7128125dd..578acf19b3 100644 --- a/apps/computer-vision/package.json +++ b/apps/computer-vision/package.json @@ -31,14 +31,14 @@ "react-native-gesture-handler": "~2.28.0", "react-native-image-picker": "^7.2.2", "react-native-loading-spinner-overlay": "^3.0.1", - "react-native-nitro-image": "^0.12.0", - "react-native-nitro-modules": "^0.33.9", + "react-native-nitro-image": "0.13.0", + "react-native-nitro-modules": "0.35.0", "react-native-reanimated": "~4.2.2", "react-native-safe-area-context": "~5.6.0", "react-native-screens": "~4.16.0", "react-native-svg": "15.15.3", "react-native-svg-transformer": "^1.5.3", - "react-native-vision-camera": "5.0.0-beta.2", + "react-native-vision-camera": "5.0.0-beta.6", 
"react-native-worklets": "0.7.4" }, "devDependencies": { diff --git a/docs/docs/03-hooks/02-computer-vision/useClassification.md b/docs/docs/03-hooks/02-computer-vision/useClassification.md index e9c2eebfab..e88cce1aff 100644 --- a/docs/docs/03-hooks/02-computer-vision/useClassification.md +++ b/docs/docs/03-hooks/02-computer-vision/useClassification.md @@ -52,12 +52,16 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an object containing categories with their probabilities. +To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object containing categories with their probabilities. :::info Images from external sources are stored in your application's temporary directory. ::: +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). + ## Example ```typescript diff --git a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md index caef87cdf2..a6ea5fa982 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageEmbeddings.md @@ -63,7 +63,11 @@ You need more details? 
Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ImageEmbeddingsType.md#forward) method. It accepts one argument which is a URI/URL to an image you want to encode or base64 (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding. +To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageEmbeddingsType.md#forward) method. It accepts one argument — the image to embed. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to a `Float32Array` representing the embedding. + +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Example diff --git a/docs/docs/03-hooks/02-computer-vision/useOCR.md b/docs/docs/03-hooks/02-computer-vision/useOCR.md index 76e7ad6956..41491c7143 100644 --- a/docs/docs/03-hooks/02-computer-vision/useOCR.md +++ b/docs/docs/03-hooks/02-computer-vision/useOCR.md @@ -50,7 +50,11 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the text recognized within the box, and the confidence score. For more information, please refer to the reference or type definitions. +To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. 
It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Detection object diff --git a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md index 69ac3c79ac..5fb2b2bb3a 100644 --- a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md +++ b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md @@ -66,7 +66,7 @@ You need more details? Check the following resources: To run the model, use the [`forward`](../../06-api-reference/interfaces/ObjectDetectionType.md#forward) method. It accepts two arguments: -- `imageSource` (required) - The image to process. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). - `detectionThreshold` (optional) - A number between 0 and 1 representing the minimum confidence score for a detection to be included in the results. Defaults to `0.7`. `forward` returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing: @@ -107,6 +107,10 @@ function App() { } ``` +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). 
+ ## Supported models | Model | Number of classes | Class list | diff --git a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md index 3a6fa46553..dc654369c7 100644 --- a/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useSemanticSegmentation.md @@ -66,7 +66,7 @@ You need more details? Check the following resources: To run the model, use the [`forward`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) method. It accepts three arguments: -- [`imageSource`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`input`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). - [`classesOfInterest`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter. - [`resizeToInput`](../../06-api-reference/interfaces/SemanticSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). @@ -115,6 +115,10 @@ function App() { } ``` +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). 
+ ## Supported models | Model | Number of classes | Class list | Quantized | diff --git a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md index 471bde35e9..d08d7e8688 100644 --- a/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md +++ b/docs/docs/03-hooks/02-computer-vision/useStyleTransfer.md @@ -23,10 +23,13 @@ import { const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); -const imageUri = 'file::///Users/.../cute_cat.png'; +const imageUri = 'file:///Users/.../cute_cat.png'; try { - const generatedImageUrl = await model.forward(imageUri); + // Returns a file URI string + const uri = await model.forward(imageUri, 'url'); + // Or returns raw PixelData (default) + const pixels = await model.forward(imageUri); } catch (error) { console.error(error); } @@ -51,30 +54,54 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise which can resolve either to an error or a URL to generated image. +To run the model, use the [`forward`](../../06-api-reference/interfaces/StyleTransferType.md#forward) method. It accepts two arguments: + +- `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `outputType` (optional) — Controls the return format: + - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written. + - `'url'` — Saves the result to a temp file and returns its URI as a `string`. 
:::info -Images from external sources and the generated image are stored in your application's temporary directory. +When `outputType` is `'url'`, the generated image is stored in your application's temporary directory. ::: ## Example ```typescript +import { + useStyleTransfer, + STYLE_TRANSFER_CANDY, +} from 'react-native-executorch'; + function App() { const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); - // ... - const imageUri = 'file::///Users/.../cute_cat.png'; - - try { - const generatedImageUrl = await model.forward(imageUri); - } catch (error) { - console.error(error); - } - // ... + // Returns a file URI — easy to pass to + const runWithUrl = async (imageUri: string) => { + try { + const uri = await model.forward(imageUri, 'url'); + console.log('Styled image saved at:', uri); + } catch (error) { + console.error(error); + } + }; + + // Returns raw PixelData — useful for further processing or frame pipelines + const runWithPixelData = async (imageUri: string) => { + try { + const pixels = await model.forward(imageUri); + // pixels.dataPtr is a Uint8Array of RGB bytes + } catch (error) { + console.error(error); + } + }; } ``` +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). + ## Supported models - [Candy](https://github.com/pytorch/examples/tree/main/fast_neural_style) diff --git a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md index b9d29fc423..80b142ac62 100644 --- a/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md +++ b/docs/docs/03-hooks/02-computer-vision/useVerticalOCR.md @@ -58,7 +58,11 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument, which is the image. 
The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the text recognized within the box, and the confidence score. For more information, please refer to the reference or type definitions. +To run the model, use the [`forward`](../../06-api-reference/interfaces/OCRType.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). ## Detection object diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md new file mode 100644 index 0000000000..79f6c5aad8 --- /dev/null +++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md @@ -0,0 +1,207 @@ +--- +title: VisionCamera Integration +--- + +React Native ExecuTorch vision models support real-time frame processing via [VisionCamera v5](https://react-native-vision-camera-v5-docs.vercel.app) using the `runOnFrame` worklet. This page explains how to set it up and what to watch out for. + +## Prerequisites + +Make sure you have the following packages installed: + +- [`react-native-vision-camera`](https://react-native-vision-camera-v5-docs.vercel.app) v5 +- [`react-native-worklets`](https://docs.swmansion.com/react-native-worklets/) + +## Which models support runOnFrame? 
+ +The following hooks expose `runOnFrame`: + +- [`useClassification`](./useClassification.md) +- [`useImageEmbeddings`](./useImageEmbeddings.md) +- [`useOCR`](./useOCR.md) +- [`useVerticalOCR`](./useVerticalOCR.md) +- [`useObjectDetection`](./useObjectDetection.md) +- [`useSemanticSegmentation`](./useSemanticSegmentation.md) +- [`useStyleTransfer`](./useStyleTransfer.md) + +## runOnFrame vs forward + +| | `runOnFrame` | `forward` | +| -------- | -------------------- | -------------------------- | +| Thread | JS (worklet) | Background thread | +| Input | VisionCamera `Frame` | `string` URI / `PixelData` | +| Output | Model result (sync) | `Promise` | +| Use case | Real-time camera | Single image | + +Use `runOnFrame` when you need to process every camera frame. Use `forward` for one-off image inference. + +## How it works + +VisionCamera v5 delivers frames via [`useFrameOutput`](https://react-native-vision-camera-v5-docs.vercel.app/docs/frame-output). Inside the `onFrame` worklet you call `runOnFrame(frame)` synchronously, then use `scheduleOnRN` from `react-native-worklets` to post the result back to React state on the main thread. + +:::warning +You **must** set `pixelFormat: 'rgb'` in `useFrameOutput`. Our extraction pipeline expects RGB pixel data — any other format (e.g. the default `yuv`) will produce incorrect results. +::: + +:::warning +`runOnFrame` is synchronous and runs on the JS worklet thread. For models with longer inference times, use `dropFramesWhileBusy: true` to skip frames and avoid blocking the camera pipeline. For more control, see VisionCamera's [async frame processing guide](https://react-native-vision-camera-v5-docs.vercel.app/docs/async-frame-processing). +::: + +:::note +Always call `frame.dispose()` after processing to release the frame buffer. Wrap your inference in a `try/finally` to ensure it's always called even if `runOnFrame` throws. 
+::: + +## Full example (Classification) + +```tsx +import { useState, useCallback } from 'react'; +import { Text, StyleSheet } from 'react-native'; +import { + Camera, + Frame, + useCameraDevices, + useCameraPermission, + useFrameOutput, +} from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; + +export default function App() { + const { hasPermission, requestPermission } = useCameraPermission(); + const devices = useCameraDevices(); + const device = devices.find((d) => d.position === 'back'); + const model = useClassification({ model: EFFICIENTNET_V2_S }); + const [topLabel, setTopLabel] = useState(''); + + // Extract runOnFrame so it can be captured by the useCallback dependency array + const runOnFrame = model.runOnFrame; + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (!runOnFrame) return; + try { + const scores = runOnFrame(frame); + if (scores) { + let best = ''; + let bestScore = -1; + for (const [label, score] of Object.entries(scores)) { + if ((score as number) > bestScore) { + bestScore = score as number; + best = label; + } + } + scheduleOnRN(setTopLabel, best); + } + } finally { + frame.dispose(); + } + }, + [runOnFrame] + ), + }); + + if (!hasPermission) { + requestPermission(); + return null; + } + + if (!device) return null; + + return ( + <> + + {topLabel} + + ); +} + +const styles = StyleSheet.create({ + camera: { flex: 1 }, + label: { + position: 'absolute', + bottom: 40, + alignSelf: 'center', + color: 'white', + fontSize: 20, + }, +}); +``` + +## Using the Module API + +If you use the TypeScript Module API (e.g. `ClassificationModule`) directly instead of a hook, `runOnFrame` is a worklet function and **cannot** be passed directly to `useState` — React would invoke it as a state initializer. 
Use the functional updater form `() => module.runOnFrame`: + +```tsx +import { useState, useEffect, useCallback } from 'react'; +import { Camera, useFrameOutput } from 'react-native-vision-camera'; +import { scheduleOnRN } from 'react-native-worklets'; +import { + ClassificationModule, + EFFICIENTNET_V2_S, +} from 'react-native-executorch'; + +export default function App() { + const [module] = useState(() => new ClassificationModule()); + const [runOnFrame, setRunOnFrame] = useState( + null + ); + const [result, setResult] = useState(null); + + useEffect(() => { + module.load(EFFICIENTNET_V2_S).then(() => { + // () => module.runOnFrame is required — passing module.runOnFrame directly + // would cause React to call it as a state initializer function + setRunOnFrame(() => module.runOnFrame); + }); + }, [module]); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + onFrame: useCallback( + (frame) => { + 'worklet'; + if (!runOnFrame) return; + try { + const result = runOnFrame(frame); + if (result) scheduleOnRN(setResult, result); + } finally { + frame.dispose(); + } + }, + [runOnFrame] + ), + }); + + return ; +} +``` + +## Common issues + +### Results look wrong or scrambled + +You forgot to set `pixelFormat: 'rgb'`. The default VisionCamera pixel format is `yuv` — our frame extraction works only with RGB data. + +### App freezes or camera drops frames + +Your model's inference time exceeds the frame interval. Enable `dropFramesWhileBusy: true` in `useFrameOutput`, or move inference off the worklet thread using VisionCamera's [async frame processing](https://react-native-vision-camera-v5-docs.vercel.app/docs/async-frame-processing). + +### Memory leak / crash after many frames + +You are not calling `frame.dispose()`. Always dispose the frame in a `finally` block. + +### `runOnFrame` is always null + +The model hasn't finished loading yet. Guard with `if (!runOnFrame) return` inside `onFrame`, or check `model.isReady` before enabling the camera. 
+ +### TypeError: `module.runOnFrame` is not a function (Module API) + +You passed `module.runOnFrame` directly to `setState` instead of `() => module.runOnFrame`. React invoked it as a state initializer — see the [Module API section](#using-the-module-api) above. diff --git a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md index 4234fa865d..a9bf6ed77d 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md @@ -47,7 +47,9 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an object containing categories with their probabilities. +To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object containing categories with their probabilities. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. 
## Managing memory diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md index 7388416334..47eceef00a 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md @@ -48,4 +48,6 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -[`forward`](../../06-api-reference/classes/ImageEmbeddingsModule.md#forward) accepts one argument: image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding. +[`forward`](../../06-api-reference/classes/ImageEmbeddingsModule.md#forward) accepts one argument — the image to embed. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to a `Float32Array` representing the embedding. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. diff --git a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md index 1524859ed2..1391982173 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md @@ -41,4 +41,6 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/OCRModule.md#forward) method. It accepts one argument, which is the image. 
The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +To run the model, use the [`forward`](../../06-api-reference/classes/OCRModule.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. diff --git a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md index d942eded65..b56cb47713 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md @@ -40,7 +40,14 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects. 
Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +To run the model, use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method. It accepts two arguments: + +- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `detectionThreshold` (optional) - A number between 0 and 1. Defaults to `0.7`. + +The method returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing the bounding box, label, and confidence score. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. ## Using a custom model diff --git a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md index bf88690bdb..4bf2129ac1 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md @@ -86,7 +86,7 @@ For more information on loading resources, take a look at [loading models](../.. To run the model, use the [`forward`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) method. It accepts three arguments: -- [`imageSource`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`input`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (required) - The image to segment. 
Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). - [`classesOfInterest`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless. - [`resizeToInput`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. @@ -113,6 +113,8 @@ result.CAT; // Float32Array result.DOG; // Float32Array ``` +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. + ## Managing memory The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/SemanticSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/SemanticSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/SemanticSegmentationModule.md#delete) unless you create a new instance. 
diff --git a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md index 0e70a24796..4c57716001 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md @@ -47,7 +47,14 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or a URL to generated image. +To run the model, use the [`forward`](../../06-api-reference/classes/StyleTransferModule.md#forward) method. It accepts two arguments: + +- `input` (required) — The image to stylize. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `outputType` (optional) — Controls the return format: + - `'pixelData'` (default) — Returns a `PixelData` object with raw RGB pixels. No file is written. + - `'url'` — Saves the result to a temp file and returns its URI as a `string`. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. 
## Managing memory diff --git a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md index eb47efa857..cc1cdba51d 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md @@ -43,4 +43,6 @@ For more information on loading resources, take a look at [loading models](../.. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/VerticalOCRModule.md#forward) method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +To run the model, use the [`forward`](../../06-api-reference/classes/VerticalOCRModule.md#forward) method. It accepts one argument — the image to recognize. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an array of [`OCRDetection`](../../06-api-reference/interfaces/OCRDetection.md) objects, each containing the bounding box, recognized text, and confidence score. + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. 
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp index bd29500b00..d7a763819a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -225,7 +225,7 @@ readImageToTensor(const std::string &path, if (tensorDims.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected tensor size, expected at least 2 dimentions " + "Unexpected tensor size, expected at least 2 dimensions " "but got: %zu.", tensorDims.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h index 4057950b23..1eb288ee30 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + namespace rnexecutorch { using executorch::aten::ScalarType; diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 96e3168ee7..a4e373c2b8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -17,8 +18,10 @@ #include #include #include +#include #include #include +#include #include using namespace rnexecutorch::models::speech_to_text; @@ -557,4 +560,63 @@ inline jsi::Value getJsiValue(const 
TranscriptionResult &result, return obj; } +inline jsi::Value +getJsiValue(const models::style_transfer::PixelDataResult &result, + jsi::Runtime &runtime) { + jsi::Object obj(runtime); + + auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr); + auto uint8ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Uint8Array"); + auto uint8Array = + uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer).getObject(runtime); + obj.setProperty(runtime, "dataPtr", uint8Array); + + auto sizesArray = jsi::Array(runtime, 3); + sizesArray.setValueAtIndex(runtime, 0, jsi::Value(result.height)); + sizesArray.setValueAtIndex(runtime, 1, jsi::Value(result.width)); + sizesArray.setValueAtIndex(runtime, 2, jsi::Value(result.channels)); + obj.setProperty(runtime, "sizes", sizesArray); + + obj.setProperty(runtime, "scalarType", + jsi::Value(static_cast(ScalarType::Byte))); + + return obj; +} + +inline jsi::Value getJsiValue( + const rnexecutorch::models::semantic_segmentation::SegmentationResult + &result, + jsi::Runtime &runtime) { + jsi::Object dict(runtime); + + auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, result.argmax); + auto int32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) + .getObject(runtime); + dict.setProperty(runtime, "ARGMAX", int32Array); + + for (auto &[classLabel, owningBuffer] : *result.classBuffers) { + auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) + .getObject(runtime); + dict.setProperty(runtime, jsi::String::createFromAscii(runtime, classLabel), + float32Array); + } + + return dict; +} + +inline jsi::Value +getJsiValue(const models::style_transfer::StyleTransferResult &result, + jsi::Runtime &runtime) { + return std::visit( + 
[&runtime](const auto &value) { return getJsiValue(value, runtime); }, + result); +} + } // namespace rnexecutorch::jsi_conversion diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index c13b8991dc..e4361273d5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -190,13 +190,11 @@ template class ModelHostObject : public JsiHostObject { addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::streamStop>, "streamStop")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamInsert>, - "streamInsert")); addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::generateFromPhonemes>, "generateFromPhonemes")); + addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::streamFromPhonemes>, @@ -328,8 +326,24 @@ template class ModelHostObject : public JsiHostObject { return jsi_conversion::getJsiValue(std::move(result), runtime); } + } catch (const RnExecutorchError &e) { + jsi::Object errorData(runtime); + errorData.setProperty(runtime, "code", e.getNumericCode()); + errorData.setProperty(runtime, "message", + jsi::String::createFromUtf8(runtime, e.what())); + throw jsi::JSError(runtime, jsi::Value(runtime, std::move(errorData))); + } catch (const std::runtime_error &e) { + // This catch should be merged with the next one + // (std::runtime_error inherits from std::exception) HOWEVER react + // native has broken RTTI which breaks proper exception type + // checking. 
Remove when the following change is present in our + // version: + // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e + throw jsi::JSError(runtime, e.what()); } catch (const std::exception &e) { throw jsi::JSError(runtime, e.what()); + } catch (...) { + throw jsi::JSError(runtime, "Unknown error in vision function"); } } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index b88310e124..0b6acbc383 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -1,53 +1,53 @@ #include "VisionModel.h" #include #include -#include #include namespace rnexecutorch::models { using namespace facebook; -cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData) const { - auto frameObj = frameData.asObject(runtime); - return ::rnexecutorch::utils::extractFrame(runtime, frameObj); +VisionModel::VisionModel(const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) {} + +void VisionModel::unload() noexcept { + std::scoped_lock lock(inference_mutex_); + BaseModel::unload(); } -cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { - if (tensorView.sizes.size() != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: sizes must have 3 elements " - "[height, width, channels], got %zu", - tensorView.sizes.size()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); +cv::Size VisionModel::modelInputSize() const { + if (modelInputShape_.size() < 2) { + return {0, 0}; } + return cv::Size(modelInputShape_[modelInputShape_.size() - 1], + modelInputShape_[modelInputShape_.size() - 2]); +} - int32_t height = tensorView.sizes[0]; - int32_t 
width = tensorView.sizes[1]; - int32_t channels = tensorView.sizes[2]; - - if (channels != 3) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: expected 3 channels (RGB), got %d", - channels); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } +cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) const { + cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); + cv::Mat rgb; +#ifdef __APPLE__ + cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); +#else + cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); +#endif + return rgb; +} - if (tensorView.scalarType != ScalarType::Byte) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); +cv::Mat VisionModel::preprocess(const cv::Mat &image) const { + const cv::Size targetSize = modelInputSize(); + if (image.size() == targetSize) { + return image; } + cv::Mat resized; + cv::resize(image, resized, targetSize); + return resized; +} - uint8_t *dataPtr = static_cast(tensorView.dataPtr); - cv::Mat image(height, width, CV_8UC3, dataPtr); - - return image; +cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { + return ::rnexecutorch::utils::pixelsToMat(tensorView); } } // namespace rnexecutorch::models diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index 4828f26578..6f9a9532f4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -23,8 +23,8 @@ namespace models { * Usage: * Subclasses should: * 1. Inherit from VisionModel instead of BaseModel - * 2. Implement preprocessFrame() with model-specific preprocessing - * 3. 
Delegate to runInference() which handles locking internally + * 2. Optionally override preprocess() for model-specific preprocessing + * 3. Implement runInference() which acquires the lock internally * * Example: * @code @@ -33,7 +33,7 @@ namespace models { * std::unordered_map * generateFromFrame(jsi::Runtime& runtime, const jsi::Value& frameValue) { * auto frameObject = frameValue.asObject(runtime); - * cv::Mat frame = utils::extractFrame(runtime, frameObject); + * cv::Mat frame = extractFromFrame(runtime, frameObject); * return runInference(frame); * } * }; @@ -41,19 +41,27 @@ namespace models { */ class VisionModel : public BaseModel { public: - /** - * @brief Construct a VisionModel with the same parameters as BaseModel - * - * VisionModel uses the same construction pattern as BaseModel, just adding - * thread-safety on top. - */ VisionModel(const std::string &modelSource, - std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) {} + std::shared_ptr callInvoker); virtual ~VisionModel() = default; + /** + * @brief Thread-safe unload that waits for any in-flight inference to + * complete + * + * Overrides BaseModel::unload() to acquire inference_mutex_ before + * resetting the module. This prevents a crash where BaseModel::unload() + * destroys module_ while generateFromFrame() is still executing on the + * VisionCamera worklet thread. + */ + void unload() noexcept; + protected: + /// Cached input tensor shape (getAllInputShapes()[0]). + /// Set once by each subclass constructor to avoid per-frame metadata lookups. 
+ std::vector modelInputShape_; + /** * @brief Mutex to ensure thread-safe inference * @@ -70,44 +78,32 @@ class VisionModel : public BaseModel { mutable std::mutex inference_mutex_; /** - * @brief Preprocess a camera frame for model input + * @brief Resize an RGB image to the model's expected input size * - * This method should implement model-specific preprocessing such as: - * - Resizing to the model's expected input size - * - Color space conversion (e.g., BGR to RGB) - * - Normalization - * - Any other model-specific transformations + * Resizes to modelInputSize() if needed. Subclasses may override for + * model-specific preprocessing (e.g., normalisation). * - * @param frame Input frame from camera (already extracted and rotated by - * FrameExtractor) - * @return Preprocessed cv::Mat ready for tensor conversion + * @param image Input image in RGB format + * @return cv::Mat resized to modelInputSize(), in RGB format * - * @note The input frame is already in RGB format and rotated 90° clockwise - * @note This method is called under mutex protection in generateFromFrame() + * @note Called from runInference() under the inference mutex */ - virtual cv::Mat preprocessFrame(const cv::Mat &frame) const = 0; + virtual cv::Mat preprocess(const cv::Mat &image) const; + + /// Convenience accessor: spatial dimensions of the model input. + cv::Size modelInputSize() const; /** - * @brief Extract and preprocess frame from VisionCamera in one call + * @brief Extract an RGB cv::Mat from a VisionCamera frame * - * This is a convenience method that combines frame extraction and - * preprocessing. It handles both nativeBuffer (zero-copy) and ArrayBuffer - * paths automatically. + * Calls frameToMat() then converts the raw 4-channel frame + * (BGRA on iOS, RGBA on Android) to RGB. 
* * @param runtime JSI runtime * @param frameData JSI value containing frame data from VisionCamera + * @return cv::Mat in RGB format (3 channels) * - * @return Preprocessed cv::Mat ready for tensor conversion - * - * @throws std::runtime_error if frame extraction fails - * - * @note This method does NOT acquire the inference mutex - caller is - * responsible - * @note Typical usage: - * @code - * cv::Mat preprocessed = extractFromFrame(runtime, frameData); - * auto tensor = image_processing::getTensorFromMatrix(dims, preprocessed); - * @endcode + * @note Does NOT acquire the inference mutex — caller is responsible */ cv::Mat extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 0fba071087..f713b59605 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -12,31 +12,33 @@ namespace rnexecutorch::models::classification { Classification::Classification(const std::string &modelSource, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { + modelInputShape_ = inputShapes[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + 
modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); } std::unordered_map -Classification::generate(std::string imageSource) { +Classification::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); + + cv::Mat preprocessed = preprocess(image); + auto inputTensor = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]) - .first; + image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); + auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { throw RnExecutorchError(forwardResult.error(), @@ -46,6 +48,30 @@ Classification::generate(std::string imageSource) { return postprocess(forwardResult->at(0).toTensor()); } +std::unordered_map +Classification::generateFromString(std::string imageSource) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + return runInference(imageRGB); +} + +std::unordered_map +Classification::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + cv::Mat frame = extractFromFrame(runtime, frameData); + return runInference(frame); +} + +std::unordered_map +Classification::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image = extractFromPixels(pixelData); + + return runInference(image); +} + std::unordered_map Classification::postprocess(const Tensor &tensor) { std::span resultData( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h index 1465fc5f9b..9f62864b9e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h @@ -3,28 +3,38 @@ #include #include +#include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include namespace rnexecutorch { namespace models::classification { using executorch::aten::Tensor; using executorch::extension::TensorPtr; -class Classification : public BaseModel { +class Classification : public VisionModel { public: Classification(const std::string &modelSource, std::shared_ptr callInvoker); + [[nodiscard("Registered non-void function")]] std::unordered_map< std::string_view, float> - generate(std::string imageSource); + generateFromString(std::string imageSource); + + [[nodiscard("Registered non-void function")]] std::unordered_map< + std::string_view, float> + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + + [[nodiscard("Registered non-void function")]] std::unordered_map< + std::string_view, float> + generateFromPixels(JSTensorViewIn pixelData); private: - std::unordered_map postprocess(const Tensor &tensor); + std::unordered_map runInference(cv::Mat image); - cv::Size modelImageSize{0, 0}; + std::unordered_map postprocess(const Tensor &tensor); }; } // namespace models::classification diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index ec3129e760..d2914469af 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -1,40 +1,40 @@ #include "ImageEmbeddings.h" - #include #include #include #include -#include namespace rnexecutorch::models::embeddings { ImageEmbeddings::ImageEmbeddings( const std::string &modelSource, std::shared_ptr callInvoker) - : BaseEmbeddings(modelSource, 
callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputTensors = getAllInputShapes(); if (inputTensors.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputTensors[0]; - if (modelInputShape.size() < 2) { + modelInputShape_ = inputTensors[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); } std::shared_ptr -ImageEmbeddings::generate(std::string imageSource) { - auto [inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); +ImageEmbeddings::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); + + cv::Mat preprocessed = preprocess(image); + + auto inputTensor = + image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); @@ -45,7 +45,33 @@ ImageEmbeddings::generate(std::string imageSource) { "is correct."); } - return BaseEmbeddings::postprocess(forwardResult); + auto forwardResultTensor = forwardResult->at(0).toTensor(); + return std::make_shared( + forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes()); +} + +std::shared_ptr +ImageEmbeddings::generateFromString(std::string imageSource) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + return runInference(imageRGB); +} + +std::shared_ptr +ImageEmbeddings::generateFromFrame(jsi::Runtime 
&runtime, + const jsi::Value &frameData) { + cv::Mat frame = extractFromFrame(runtime, frameData); + return runInference(frame); +} + +std::shared_ptr +ImageEmbeddings::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image = extractFromPixels(pixelData); + + return runInference(image); } } // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 7e114e939d..3a20301724 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -2,26 +2,37 @@ #include #include +#include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include +#include namespace rnexecutorch { namespace models::embeddings { using executorch::extension::TensorPtr; using executorch::runtime::EValue; -class ImageEmbeddings final : public BaseEmbeddings { +class ImageEmbeddings final : public VisionModel { public: ImageEmbeddings(const std::string &modelSource, std::shared_ptr callInvoker); + + [[nodiscard( + "Registered non-void function")]] std::shared_ptr + generateFromString(std::string imageSource); + + [[nodiscard( + "Registered non-void function")]] std::shared_ptr + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + [[nodiscard( "Registered non-void function")]] std::shared_ptr - generate(std::string imageSource); + generateFromPixels(JSTensorViewIn pixelData); private: - cv::Size modelImageSize{0, 0}; + std::shared_ptr runInference(cv::Mat image); }; } // namespace models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp 
b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 1dad1a61a5..d30fc8aeed 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -5,7 +5,6 @@ #include #include #include -#include namespace rnexecutorch::models::object_detection { @@ -20,18 +19,16 @@ ObjectDetection::ObjectDetection( throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = inputTensors[0]; - if (modelInputShape.size() < 2) { + modelInputShape_ = inputTensors[0]; + if (modelInputShape_.size() < 2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { @@ -46,46 +43,13 @@ ObjectDetection::ObjectDetection( } } -cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { - const std::vector tensorDims = getAllInputShapes()[0]; - cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], - tensorDims[tensorDims.size() - 2]); - - cv::Mat rgb; - - if (frame.channels() == 4) { -#ifdef __APPLE__ - cv::cvtColor(frame, rgb, cv::COLOR_BGRA2RGB); -#else - cv::cvtColor(frame, rgb, cv::COLOR_RGBA2RGB); -#endif - } else if (frame.channels() == 3) { - rgb = frame; - } else { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - 
"Unsupported frame format: %d channels", frame.channels()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); - } - - // Only resize if dimensions don't match - if (rgb.size() != tensorSize) { - cv::Mat resized; - cv::resize(rgb, resized, tensorSize); - return resized; - } - - return rgb; -} - std::vector ObjectDetection::postprocess(const std::vector &tensors, cv::Size originalSize, double detectionThreshold) { - float widthRatio = - static_cast(originalSize.width) / modelImageSize.width; + const cv::Size inputSize = modelInputSize(); + float widthRatio = static_cast(originalSize.width) / inputSize.width; float heightRatio = - static_cast(originalSize.height) / modelImageSize.height; + static_cast(originalSize.height) / inputSize.height; std::vector detections; auto bboxTensor = tensors.at(0).toTensor(); @@ -134,14 +98,14 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) { std::scoped_lock lock(inference_mutex_); cv::Size originalSize = image.size(); - cv::Mat preprocessed = preprocessFrame(image); + cv::Mat preprocessed = preprocess(image); - const std::vector tensorDims = getAllInputShapes()[0]; auto inputTensor = (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix(tensorDims, preprocessed, - *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(tensorDims, preprocessed); + ? 
image_processing::getTensorFromMatrix( + modelInputShape_, preprocessed, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, + preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { @@ -168,9 +132,7 @@ std::vector ObjectDetection::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold) { - auto frameObj = frameData.asObject(runtime); - cv::Mat frame = rnexecutorch::utils::extractFrame(runtime, frameObj); - + cv::Mat frame = extractFromFrame(runtime, frameData); return runInference(frame, detectionThreshold); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index f1159d88e0..d94087688a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -77,7 +77,6 @@ class ObjectDetection : public VisionModel { protected: std::vector runInference(cv::Mat image, double detectionThreshold); - cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: /** @@ -100,9 +99,6 @@ class ObjectDetection : public VisionModel { cv::Size originalSize, double detectionThreshold); - /// Expected input image dimensions derived from the model's input shape. - cv::Size modelImageSize{0, 0}; - /// Optional per-channel mean for input normalisation (set in constructor). 
std::optional normMean_; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp index a521b4e8b0..3c64ba115f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace rnexecutorch::models::ocr { OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource, @@ -12,12 +13,8 @@ OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource, : detector(detectorSource, callInvoker), recognitionHandler(recognizerSource, symbols, callInvoker) {} -std::vector OCR::generate(std::string input) { - cv::Mat image = image_processing::readImage(input); - if (image.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, - "Failed to load image from path: " + input); - } +std::vector OCR::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); /* 1. 
Detection process returns the list of bounding boxes containing areas @@ -43,12 +40,46 @@ std::vector OCR::generate(std::string input) { return result; } +std::vector OCR::generateFromString(std::string input) { + cv::Mat image = image_processing::readImage(input); + if (image.empty()) { + throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, + "Failed to load image from path: " + input); + } + return runInference(image); +} + +std::vector +OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { + cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); + cv::Mat bgr; +#ifdef __APPLE__ + cv::cvtColor(frame, bgr, cv::COLOR_BGRA2BGR); +#elif defined(__ANDROID__) + cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR); +#else + throw RnExecutorchError( + RnExecutorchErrorCode::PlatformNotSupported, + "generateFromFrame is not supported on this platform"); +#endif + return runInference(bgr); +} + +std::vector +OCR::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image; + cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, + cv::COLOR_RGB2BGR); + return runInference(image); +} + std::size_t OCR::getMemoryLowerBound() const noexcept { return detector.getMemoryLowerBound() + recognitionHandler.getMemoryLowerBound(); } void OCR::unload() noexcept { + std::scoped_lock lock(inference_mutex_); detector.unload(); recognitionHandler.unload(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h index d84ba903f0..719cb957c4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include #include #include #include @@ -28,13 +30,20 @@ class OCR final { const std::string &recognizerSource, const 
std::string &symbols, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector - generate(std::string input); + generateFromString(std::string input); + [[nodiscard("Registered non-void function")]] std::vector + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + [[nodiscard("Registered non-void function")]] std::vector + generateFromPixels(JSTensorViewIn pixelData); std::size_t getMemoryLowerBound() const noexcept; void unload() noexcept; private: + std::vector runInference(cv::Mat image); + Detector detector; RecognitionHandler recognitionHandler; + mutable std::mutex inference_mutex_; }; } // namespace models::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index fc6a04ebcd..5ecf7493c4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -1,8 +1,6 @@ #include "BaseSemanticSegmentation.h" #include "jsi/jsi.h" -#include - #include #include #include @@ -15,7 +13,8 @@ BaseSemanticSegmentation::BaseSemanticSegmentation( const std::string &modelSource, std::vector normMean, std::vector normStd, std::vector allClasses, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker), allClasses_(std::move(allClasses)) { + : VisionModel(modelSource, callInvoker), + allClasses_(std::move(allClasses)) { initModelImageSize(); if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); @@ -37,46 +36,71 @@ void BaseSemanticSegmentation::initModelImageSize() { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } - std::vector modelInputShape = 
inputShapes[0]; - if (modelInputShape.size() < 2) { + modelInputShape_ = inputShapes[0]; + if (modelInputShape_.size() < 2) { throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, "Unexpected model input size, expected at least 2 " "dimensions but got: " + - std::to_string(modelInputShape.size()) + "."); + std::to_string(modelInputShape_.size()) + "."); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); - numModelPixels = modelImageSize.area(); -} - -TensorPtr BaseSemanticSegmentation::preprocess(const std::string &imageSource, - cv::Size &originalSize) { - auto [inputTensor, origSize] = image_processing::readImageToTensor( - imageSource, getAllInputShapes()[0], false, normMean_, normStd_); - originalSize = origSize; - return inputTensor; + numModelPixels = modelInputSize().area(); } -std::shared_ptr BaseSemanticSegmentation::generate( - std::string imageSource, - std::set> classesOfInterest, bool resize) { +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::runInference( + cv::Mat image, cv::Size originalSize, + std::set> &classesOfInterest, bool resize) { + std::scoped_lock lock(inference_mutex_); - cv::Size originalSize; - auto inputTensor = preprocess(imageSource, originalSize); + cv::Mat preprocessed = VisionModel::preprocess(image); + auto inputTensor = + (normMean_ && normStd_) + ? image_processing::getTensorFromMatrix( + modelInputShape_, preprocessed, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, + preprocessed); auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { throw RnExecutorchError(forwardResult.error(), "The model's forward function did not succeed. 
" "Ensure the model input is correct."); } - return postprocess(forwardResult->at(0).toTensor(), originalSize, allClasses_, - classesOfInterest, resize); + return computeResult(forwardResult->at(0).toTensor(), originalSize, + allClasses_, classesOfInterest, resize); +} + +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::generateFromString( + std::string imageSource, + std::set> classesOfInterest, bool resize) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + cv::Size originalSize = imageBGR.size(); + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + return runInference(imageRGB, originalSize, classesOfInterest, resize); +} + +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::generateFromPixels( + JSTensorViewIn pixelData, + std::set> classesOfInterest, bool resize) { + cv::Mat image = extractFromPixels(pixelData); + return runInference(image, image.size(), classesOfInterest, resize); } -std::shared_ptr BaseSemanticSegmentation::postprocess( +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::generateFromFrame( + jsi::Runtime &runtime, const jsi::Value &frameData, + std::set> classesOfInterest, bool resize) { + cv::Mat frame = extractFromFrame(runtime, frameData); + return runInference(frame, frame.size(), classesOfInterest, resize); +} + +semantic_segmentation::SegmentationResult +BaseSemanticSegmentation::computeResult( const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, bool resize) { @@ -161,8 +185,8 @@ std::shared_ptr BaseSemanticSegmentation::postprocess( } // Filter classes of interest - auto buffersToReturn = std::make_shared>>(); + auto buffersToReturn = std::make_shared< + std::unordered_map>>(); for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) { (*buffersToReturn)[allClasses[cl]] = resultClasses[cl]; @@ -185,48 +209,7 @@ 
std::shared_ptr BaseSemanticSegmentation::postprocess( } } - return populateDictionary(argmax, buffersToReturn); -} - -std::shared_ptr BaseSemanticSegmentation::populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput) { - auto promisePtr = std::make_shared>(); - std::future doneFuture = promisePtr->get_future(); - - std::shared_ptr dictPtr = nullptr; - callInvoker->invokeAsync( - [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { - dictPtr = std::make_shared(runtime); - auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); - - auto int32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Int32Array"); - auto int32Array = - int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) - .getObject(runtime); - dictPtr->setProperty(runtime, "ARGMAX", int32Array); - - for (auto &[classLabel, owningBuffer] : *classesToOutput) { - auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); - - auto float32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Float32Array"); - auto float32Array = - float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) - .getObject(runtime); - - dictPtr->setProperty( - runtime, jsi::String::createFromAscii(runtime, classLabel.data()), - float32Array); - } - promisePtr->set_value(); - }); - - doneFuture.wait(); - return dictPtr; + return semantic_segmentation::SegmentationResult{argmax, buffersToReturn}; } } // namespace rnexecutorch::models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index d39a7e5d4a..a30ae375bf 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h 
@@ -1,23 +1,20 @@ #pragma once -#include -#include #include #include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include -#include +#include +#include namespace rnexecutorch { namespace models::semantic_segmentation { using namespace facebook; using executorch::aten::Tensor; -using executorch::extension::TensorPtr; -class BaseSemanticSegmentation : public BaseModel { +class BaseSemanticSegmentation : public VisionModel { public: BaseSemanticSegmentation(const std::string &modelSource, std::vector normMean, @@ -25,33 +22,42 @@ class BaseSemanticSegmentation : public BaseModel { std::vector allClasses, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] std::shared_ptr - generate(std::string imageSource, - std::set> classesOfInterest, bool resize); + [[nodiscard("Registered non-void function")]] + semantic_segmentation::SegmentationResult + generateFromString(std::string imageSource, + std::set> classesOfInterest, + bool resize); + + [[nodiscard("Registered non-void function")]] + semantic_segmentation::SegmentationResult + generateFromPixels(JSTensorViewIn pixelData, + std::set> classesOfInterest, + bool resize); + + [[nodiscard("Registered non-void function")]] + semantic_segmentation::SegmentationResult + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, + std::set> classesOfInterest, + bool resize); protected: - virtual TensorPtr preprocess(const std::string &imageSource, - cv::Size &originalSize); - virtual std::shared_ptr - postprocess(const Tensor &tensor, cv::Size originalSize, - std::vector &allClasses, - std::set> &classesOfInterest, - bool resize); - - cv::Size modelImageSize; + virtual semantic_segmentation::SegmentationResult + computeResult(const Tensor &tensor, cv::Size originalSize, + std::vector &allClasses, + std::set> &classesOfInterest, + bool resize); std::size_t numModelPixels; std::optional normMean_; std::optional normStd_; std::vector allClasses_; - std::shared_ptr 
populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput); - private: void initModelImageSize(); + + semantic_segmentation::SegmentationResult + runInference(cv::Mat image, cv::Size originalSize, + std::set> &classesOfInterest, + bool resize); }; } // namespace models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h new file mode 100644 index 0000000000..b305b96a70 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch::models::semantic_segmentation { + +struct SegmentationResult { + std::shared_ptr argmax; + std::shared_ptr< + std::unordered_map>> + classBuffers; +}; + +} // namespace rnexecutorch::models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 3b9c0187b9..ad94e76ba4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -13,37 +13,31 @@ using executorch::extension::TensorPtr; StyleTransfer::StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors"); } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { + modelInputShape_ = inputShapes[0]; + if (modelInputShape_.size() < 
2) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " + "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", - modelInputShape.size()); + modelInputShape_.size()); throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); } -std::string StyleTransfer::postprocess(const Tensor &tensor, - cv::Size originalSize) { - cv::Mat mat = image_processing::getMatrixFromTensor(modelImageSize, tensor); - cv::resize(mat, mat, originalSize); +cv::Mat StyleTransfer::runInference(cv::Mat image, cv::Size outputSize) { + std::scoped_lock lock(inference_mutex_); - return image_processing::saveToTempFile(mat); -} + cv::Mat preprocessed = preprocess(image); -std::string StyleTransfer::generate(std::string imageSource) { - auto [inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); + auto inputTensor = + image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { @@ -52,7 +46,55 @@ std::string StyleTransfer::generate(std::string imageSource) { "Ensure the model input is correct."); } - return postprocess(forwardResult->at(0).toTensor(), originalSize); + cv::Mat mat = image_processing::getMatrixFromTensor( + modelInputSize(), forwardResult->at(0).toTensor()); + if (mat.size() != outputSize) { + cv::resize(mat, mat, outputSize); + } + return mat; +} + +PixelDataResult toPixelDataResult(const cv::Mat &bgrMat) { + cv::Size size = bgrMat.size(); + // Convert BGR -> RGBA so JS can pass the buffer directly to Skia + cv::Mat rgba; + cv::cvtColor(bgrMat, rgba, cv::COLOR_BGR2RGBA); + std::size_t dataSize = static_cast(size.width) * size.height * 4; + auto pixelBuffer = std::make_shared(rgba.data, 
dataSize); + return PixelDataResult{pixelBuffer, size.width, size.height, rgba.channels()}; +} + +StyleTransferResult StyleTransfer::generateFromString(std::string imageSource, + bool saveToFile) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + cv::Size originalSize = imageBGR.size(); + + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + + cv::Mat result = runInference(imageRGB, originalSize); + if (saveToFile) { + return image_processing::saveToTempFile(result); + } + return toPixelDataResult(result); +} + +PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + cv::Mat frame = extractFromFrame(runtime, frameData); + + return toPixelDataResult(runInference(frame, modelInputSize())); +} + +StyleTransferResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData, + bool saveToFile) { + cv::Mat image = extractFromPixels(pixelData); + + cv::Mat result = runInference(image, image.size()); + if (saveToFile) { + return image_processing::saveToTempFile(result); + } + return toPixelDataResult(result); } } // namespace rnexecutorch::models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index 73744c4d82..c15095bf5b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -9,25 +9,30 @@ #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include +#include +#include namespace rnexecutorch { namespace models::style_transfer { using namespace facebook; -using executorch::aten::Tensor; -using executorch::extension::TensorPtr; -class StyleTransfer : public BaseModel { +class StyleTransfer : public VisionModel { public: StyleTransfer(const 
std::string &modelSource, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] std::string - generate(std::string imageSource); -private: - std::string postprocess(const Tensor &tensor, cv::Size originalSize); + [[nodiscard("Registered non-void function")]] StyleTransferResult + generateFromString(std::string imageSource, bool saveToFile); + + [[nodiscard("Registered non-void function")]] PixelDataResult + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); - cv::Size modelImageSize{0, 0}; + [[nodiscard("Registered non-void function")]] StyleTransferResult + generateFromPixels(JSTensorViewIn pixelData, bool saveToFile); + +private: + cv::Mat runInference(cv::Mat image, cv::Size outputSize); }; } // namespace models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h new file mode 100644 index 0000000000..27df4ec6c6 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/Types.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch::models::style_transfer { + +struct PixelDataResult { + std::shared_ptr dataPtr; + int width; + int height; + int channels; +}; + +using StyleTransferResult = std::variant; + +} // namespace rnexecutorch::models::style_transfer diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp index 0f75d20152..fef78d7953 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -1,10 +1,11 @@ #include "VerticalOCR.h" -#include #include +#include #include #include #include #include +#include #include 
namespace rnexecutorch::models::ocr { @@ -16,12 +17,9 @@ VerticalOCR::VerticalOCR(const std::string &detectorSource, converter(symbols), independentCharacters(independentChars), callInvoker(invoker) {} -std::vector VerticalOCR::generate(std::string input) { - cv::Mat image = image_processing::readImage(input); - if (image.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, - "Failed to load image from path: " + input); - } +std::vector VerticalOCR::runInference(cv::Mat image) { + std::scoped_lock lock(inference_mutex_); + // 1. Large Detector std::vector largeBoxes = detector.generate(image, constants::kLargeDetectorWidth); @@ -44,6 +42,41 @@ std::vector VerticalOCR::generate(std::string input) { return predictions; } +std::vector +VerticalOCR::generateFromString(std::string input) { + cv::Mat image = image_processing::readImage(input); + if (image.empty()) { + throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, + "Failed to load image from path: " + input); + } + return runInference(image); +} + +std::vector +VerticalOCR::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { + cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); + cv::Mat bgr; +#ifdef __APPLE__ + cv::cvtColor(frame, bgr, cv::COLOR_BGRA2BGR); +#elif defined(__ANDROID__) + cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR); +#else + throw RnExecutorchError( + RnExecutorchErrorCode::PlatformNotSupported, + "generateFromFrame is not supported on this platform"); +#endif + return runInference(bgr); +} + +std::vector +VerticalOCR::generateFromPixels(JSTensorViewIn pixelData) { + cv::Mat image; + cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, + cv::COLOR_RGB2BGR); + return runInference(image); +} + std::size_t VerticalOCR::getMemoryLowerBound() const noexcept { return detector.getMemoryLowerBound() + recognizer.getMemoryLowerBound(); } @@ -176,6 +209,7 @@ types::OCRDetection VerticalOCR::_processSingleTextBox( } void 
VerticalOCR::unload() noexcept { + std::scoped_lock lock(inference_mutex_); detector.unload(); recognizer.unload(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h index e97fb90348..4016e28138 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h @@ -1,12 +1,14 @@ #pragma once #include +#include #include #include #include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include #include #include #include @@ -48,11 +50,17 @@ class VerticalOCR final { bool indpendentCharacters, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector - generate(std::string input); + generateFromString(std::string input); + [[nodiscard("Registered non-void function")]] std::vector + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); + [[nodiscard("Registered non-void function")]] std::vector + generateFromPixels(JSTensorViewIn pixelData); std::size_t getMemoryLowerBound() const noexcept; void unload() noexcept; private: + std::vector runInference(cv::Mat image); + std::pair _handleIndependentCharacters( const types::DetectorBBox &box, const cv::Mat &originalImage, const std::vector &characterBoxes, @@ -75,6 +83,7 @@ class VerticalOCR final { CTCLabelConverter converter; bool independentCharacters; std::shared_ptr callInvoker; + mutable std::mutex inference_mutex_; }; } // namespace models::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 426aafc1f3..bd359dcb8f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ 
b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -157,13 +157,33 @@ add_rn_test(ImageProcessingTest unit/ImageProcessingTest.cpp LIBS opencv_deps ) +add_rn_test(FrameProcessorTests unit/FrameProcessorTest.cpp + SOURCES + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${IMAGE_UTILS_SOURCES} + LIBS opencv_deps android +) + add_rn_test(BaseModelTests integration/BaseModelTest.cpp) +add_rn_test(VisionModelTests integration/VisionModelTest.cpp + SOURCES + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${IMAGE_UTILS_SOURCES} + LIBS opencv_deps android +) + add_rn_test(ClassificationTests integration/ClassificationTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/classification/Classification.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp @@ -181,8 +201,11 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp @@ -196,8 +219,11 @@ add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp add_rn_test(StyleTransferTests integration/StyleTransferTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/style_transfer/StyleTransfer.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp 
${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(VADTests integration/VoiceActivityDetectionTest.cpp @@ -273,8 +299,10 @@ add_rn_test(OCRTests integration/OCRTest.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp @@ -287,6 +315,8 @@ add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${IMAGE_UTILS_SOURCES} - LIBS opencv_deps + LIBS opencv_deps android ) diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp index 10aa663a4a..d164fcacb0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp @@ -1,6 +1,9 @@ #include "BaseModelTests.h" +#include "VisionModelTests.h" +#include #include #include +#include #include #include @@ -28,7 +31,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -36,43 +39,45 @@ template <> struct ModelTraits { using ClassificationTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest, 
ClassificationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest, + ClassificationTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(ClassificationGenerateTests, InvalidImagePathThrows) { Classification model(kValidClassificationModelPath, nullptr); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(ClassificationGenerateTests, EmptyImagePathThrows) { Classification model(kValidClassificationModelPath, nullptr); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(ClassificationGenerateTests, MalformedURIThrows) { Classification model(kValidClassificationModelPath, nullptr); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(ClassificationGenerateTests, ValidImageReturnsResults) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); EXPECT_FALSE(results.empty()); } TEST(ClassificationGenerateTests, ResultsHaveCorrectSize) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); auto expectedNumClasses = constants::kImagenet1kV1Labels.size(); EXPECT_EQ(results.size(), expectedNumClasses); } TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); float sum = 0.0f; 
for (const auto &[label, prob] : results) { @@ -85,7 +90,7 @@ TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) { TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) { Classification model(kValidClassificationModelPath, nullptr); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); float maxProb = 0.0f; for (const auto &[label, prob] : results) { @@ -115,3 +120,15 @@ TEST(ClassificationInheritedTests, GetMethodMetaWorks) { auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(ClassificationPixelTests, ValidPixelsReturnsResults) { + Classification model(kValidClassificationModelPath, nullptr); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(view); + EXPECT_FALSE(results.empty()); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp index 3a23746957..4982206614 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp @@ -1,7 +1,10 @@ #include "BaseModelTests.h" +#include "VisionModelTests.h" #include +#include #include #include +#include #include using namespace rnexecutorch; @@ -29,7 +32,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -37,37 +40,39 @@ template 
<> struct ModelTraits { using ImageEmbeddingsTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, CommonModelTest, ImageEmbeddingsTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, VisionModelTest, + ImageEmbeddingsTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(ImageEmbeddingsGenerateTests, InvalidImagePathThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, EmptyImagePathThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, MalformedURIThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, ValidImageReturnsResults) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); EXPECT_NE(result, nullptr); EXPECT_GT(result->size(), 0u); } TEST(ImageEmbeddingsGenerateTests, ResultsHaveCorrectSize) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); size_t numFloats = result->size() / sizeof(float); constexpr size_t kClipEmbeddingDimensions = 512; EXPECT_EQ(numFloats, kClipEmbeddingDimensions); @@ -77,7 +82,7 @@ TEST(ImageEmbeddingsGenerateTests, 
ResultsAreNormalized) { // TODO: Investigate the source of the issue; GTEST_SKIP() << "Expected to fail in emulator environments"; ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); const float *data = reinterpret_cast(result->data()); size_t numFloats = result->size() / sizeof(float); @@ -92,7 +97,7 @@ TEST(ImageEmbeddingsGenerateTests, ResultsAreNormalized) { TEST(ImageEmbeddingsGenerateTests, ResultsContainValidValues) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); + auto result = model.generateFromString(kValidTestImagePath); const float *data = reinterpret_cast(result->data()); size_t numFloats = result->size() / sizeof(float); @@ -122,3 +127,16 @@ TEST(ImageEmbeddingsInheritedTests, GetMethodMetaWorks) { auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(ImageEmbeddingsPixelTests, ValidPixelsReturnsEmbedding) { + ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto result = model.generateFromPixels(view); + EXPECT_NE(result, nullptr); + EXPECT_GT(result->size(), 0u); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp index 428fb5afb1..072c761164 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp @@ -1,6 +1,8 @@ #include "BaseModelTests.h" 
+#include #include #include +#include #include #include @@ -41,7 +43,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath); } }; } // namespace model_tests @@ -67,27 +69,27 @@ TEST(OCRCtorTests, EmptySymbolsThrows) { TEST(OCRGenerateTests, InvalidImagePathThrows) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(OCRGenerateTests, EmptyImagePathThrows) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(OCRGenerateTests, MalformedURIThrows) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(OCRGenerateTests, ValidImageReturnsResults) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); // May or may not have detections depending on image content EXPECT_GE(results.size(), 0u); } @@ -95,7 +97,7 @@ TEST(OCRGenerateTests, ValidImageReturnsResults) { TEST(OCRGenerateTests, DetectionsHaveValidBoundingBoxes) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { // Each bbox should have 4 points @@ 
-110,7 +112,7 @@ TEST(OCRGenerateTests, DetectionsHaveValidBoundingBoxes) { TEST(OCRGenerateTests, DetectionsHaveValidScores) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -121,8 +123,21 @@ TEST(OCRGenerateTests, DetectionsHaveValidScores) { TEST(OCRGenerateTests, DetectionsHaveNonEmptyText) { OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); - auto results = model.generate(kValidTestImagePath); + auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); } } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(OCRPixelTests, ValidPixelsReturnsResults) { + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, + createMockCallInvoker()); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(view); + EXPECT_GE(results.size(), 0u); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp index 8964f20133..6222f6d682 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp @@ -1,4 +1,5 @@ #include "BaseModelTests.h" +#include "VisionModelTests.h" #include #include #include @@ -57,6 +58,8 @@ template <> struct ModelTraits { using 
ObjectDetectionTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, CommonModelTest, ObjectDetectionTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest, + ObjectDetectionTypes); // ============================================================================ // Model-specific tests @@ -163,40 +166,6 @@ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) { EXPECT_GE(results.size(), 0u); } -TEST(ObjectDetectionPixelTests, WrongSizesLengthThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - std::vector pixelData(16, 0); - JSTensorViewIn tensorView{ - pixelData.data(), {4, 4}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5), - RnExecutorchError); -} - -TEST(ObjectDetectionPixelTests, WrongChannelCountThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - constexpr int32_t width = 4, height = 4, channels = 4; - std::vector pixelData(width * height * channels, 0); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5), - RnExecutorchError); -} - -TEST(ObjectDetectionPixelTests, WrongScalarTypeThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - constexpr int32_t width = 4, height = 4, channels = 3; - std::vector pixelData(width * height * channels, 0); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Float}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5), - RnExecutorchError); -} - TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp 
b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp index 76b213ca8f..957421f091 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp @@ -1,12 +1,13 @@ +#include +#include #include #include +#include #include #include #include #include -#include - using namespace rnexecutorch; using namespace rnexecutorch::models::semantic_segmentation; using executorch::extension::make_tensor_ptr; @@ -15,6 +16,15 @@ using executorch::runtime::EValue; constexpr auto kValidSemanticSegmentationModelPath = "deeplabV3_xnnpack_fp32.pte"; +constexpr auto kValidTestImagePath = + "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; + +static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, + int32_t w) { + buf.assign(static_cast(h * w * 3), 128); + return JSTensorViewIn{ + buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; +} // Test fixture for tests that need dummy input data class SemanticSegmentationForwardTest : public ::testing::Test { @@ -94,6 +104,95 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) { EXPECT_THROW((void)model->forward(EValue(inputTensor)), RnExecutorchError); } +// ============================================================================ +// generateFromString tests +// ============================================================================ +TEST(SemanticSegmentationGenerateTests, InvalidImagePathThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + EXPECT_THROW( + (void)model.generateFromString("nonexistent_image.jpg", {}, true), + RnExecutorchError); +} + +TEST(SemanticSegmentationGenerateTests, EmptyImagePathThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + EXPECT_THROW((void)model.generateFromString("", {}, true), 
RnExecutorchError); +} + +TEST(SemanticSegmentationGenerateTests, MalformedURIThrows) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + EXPECT_THROW( + (void)model.generateFromString("not_a_valid_uri://bad", {}, true), + RnExecutorchError); +} + +TEST(SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, {}, true); + EXPECT_NE(result.argmax, nullptr); + EXPECT_NE(result.classBuffers, nullptr); +} + +TEST(SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, {}, true); + ASSERT_NE(result.classBuffers, nullptr); + EXPECT_EQ(result.classBuffers->size(), 21u); +} + +TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::set> filter = {"PERSON", "CAT"}; + auto result = model.generateFromString(kValidTestImagePath, filter, true); + ASSERT_NE(result.classBuffers, nullptr); + // Only the requested classes should appear in classBuffers + for (const auto &[label, _] : *result.classBuffers) { + EXPECT_TRUE(filter.count(label) > 0); + } +} + +TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, {}, false); + EXPECT_NE(result.argmax, nullptr); +} + +// ============================================================================ +// generateFromPixels tests +// ============================================================================ +TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf; 
+ auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, {}, true); + EXPECT_NE(result.argmax, nullptr); + EXPECT_NE(result.classBuffers, nullptr); +} + +TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, {}, true); + ASSERT_NE(result.classBuffers, nullptr); + EXPECT_EQ(result.classBuffers->size(), 21u); +} + +TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { + SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + std::set> filter = {"PERSON"}; + auto result = model.generateFromPixels(view, filter, true); + ASSERT_NE(result.classBuffers, nullptr); + for (const auto &[label, _] : *result.classBuffers) { + EXPECT_EQ(label, "PERSON"); + } +} + +// ============================================================================ +// Inherited BaseModel tests +// ============================================================================ TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) { SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); auto shape = model.getInputShape("forward", 0); @@ -125,6 +224,9 @@ TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) { EXPECT_EQ(shape[2], shape[3]); // Height == Width for DeepLabV3 } +// ============================================================================ +// Constants tests +// ============================================================================ TEST(SemanticSegmentationConstantsTests, ClassLabelsHas21Entries) { EXPECT_EQ(constants::kDeeplabV3Resnet50Labels.size(), 21u); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp 
b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index d5427ce61b..a4511cad11 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -1,9 +1,11 @@ #include "BaseModelTests.h" -#include "utils/TestUtils.h" -#include +#include "VisionModelTests.h" +#include #include #include +#include #include +#include using namespace rnexecutorch; using namespace rnexecutorch::models::style_transfer; @@ -14,6 +16,13 @@ constexpr auto kValidStyleTransferModelPath = constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, + int32_t w) { + buf.assign(static_cast(h * w * 3), 128); + return JSTensorViewIn{ + buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; +} + // ============================================================================ // Common tests via typed test suite // ============================================================================ @@ -30,7 +39,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidTestImagePath); + (void)model.generateFromString(kValidTestImagePath, false); } }; } // namespace model_tests @@ -38,59 +47,125 @@ template <> struct ModelTraits { using StyleTransferTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest, StyleTransferTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, VisionModelTest, + StyleTransferTypes); // ============================================================================ -// Model-specific tests +// generateFromString tests // ============================================================================ TEST(StyleTransferGenerateTests, InvalidImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); 
- EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, EmptyImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, MalformedURIThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, ValidImageReturnsFilePath) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); - EXPECT_FALSE(result.empty()); + auto result = model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_NE(pr.dataPtr, nullptr); + EXPECT_GT(pr.width, 0); + EXPECT_GT(pr.height, 0); +} + +// ============================================================================ +// generateFromString saveToFile tests +// ============================================================================ +TEST(StyleTransferSaveToFileTests, SaveToFileFalseReturnsValidPixelData) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_NE(std::get(result).dataPtr, nullptr); +} + +TEST(StyleTransferSaveToFileTests, SaveToFileFalseHasPositiveDimensions) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + auto result = model.generateFromString(kValidTestImagePath, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_GT(pr.width, 0); + EXPECT_GT(pr.height, 0); } 
-TEST(StyleTransferGenerateTests, ResultIsValidFilePath) { +TEST(StyleTransferSaveToFileTests, SaveToFileTrueReturnsStringVariant) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); - test_utils::trimFilePrefix(result); - EXPECT_TRUE(std::filesystem::exists(result)); + auto result = model.generateFromString(kValidTestImagePath, true); + EXPECT_TRUE(std::holds_alternative(result)); } -TEST(StyleTransferGenerateTests, ResultFileHasContent) { +TEST(StyleTransferSaveToFileTests, SaveToFileTrueStringIsNonEmpty) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - auto result = model.generate(kValidTestImagePath); - test_utils::trimFilePrefix(result); - auto fileSize = std::filesystem::file_size(result); - EXPECT_GT(fileSize, 0u); + auto result = model.generateFromString(kValidTestImagePath, true); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_FALSE(std::get(result).empty()); } -TEST(StyleTransferGenerateTests, MultipleGeneratesWork) { +TEST(StyleTransferSaveToFileTests, SaveToFileTrueStringHasFileScheme) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_NO_THROW((void)model.generate(kValidTestImagePath)); - auto result1 = model.generate(kValidTestImagePath); - auto result2 = model.generate(kValidTestImagePath); - test_utils::trimFilePrefix(result1); - test_utils::trimFilePrefix(result2); - EXPECT_TRUE(std::filesystem::exists(result1)); - EXPECT_TRUE(std::filesystem::exists(result2)); + auto result = model.generateFromString(kValidTestImagePath, true); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_TRUE(std::get(result).starts_with("file://")); } +// ============================================================================ +// generateFromPixels tests +// ============================================================================ +TEST(StyleTransferPixelTests, ValidPixelsSaveToFileFalseReturnsPixelData) { + StyleTransfer 
model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, false); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_NE(std::get(result).dataPtr, nullptr); +} + +TEST(StyleTransferPixelTests, ValidPixelsSaveToFileFalseHasPositiveDimensions) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_GT(pr.width, 0); + EXPECT_GT(pr.height, 0); +} + +TEST(StyleTransferPixelTests, + ValidPixelsSaveToFileTrueReturnsFileSchemeString) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, true); + ASSERT_TRUE(std::holds_alternative(result)); + EXPECT_TRUE(std::get(result).starts_with("file://")); +} + +TEST(StyleTransferPixelTests, OutputDimensionsMatchInputSize) { + StyleTransfer model(kValidStyleTransferModelPath, nullptr); + std::vector buf; + auto view = makeRgbView(buf, 64, 64); + auto result = model.generateFromPixels(view, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto &pr = std::get(result); + EXPECT_EQ(pr.width, 64); + EXPECT_EQ(pr.height, 64); +} + +// ============================================================================ +// Inherited BaseModel tests +// ============================================================================ TEST(StyleTransferInheritedTests, GetInputShapeWorks) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); auto shape = model.getInputShape("forward", 0); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp index 7b1010a81e..fd6d59441d 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp @@ -1,6 +1,8 @@ #include "BaseModelTests.h" +#include #include #include +#include #include #include @@ -43,7 +45,7 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generate(kValidVerticalTestImagePath); + (void)model.generateFromString(kValidVerticalTestImagePath); } }; } // namespace model_tests @@ -85,34 +87,34 @@ TEST(VerticalOCRCtorTests, IndependentCharsFalseDoesntThrow) { TEST(VerticalOCRGenerateTests, IndependentCharsInvalidImageThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsEmptyImagePathThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsMalformedURIThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsValidImageReturnsResults) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); 
EXPECT_GE(results.size(), 0u); } TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidBBoxes) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_EQ(detection.bbox.size(), 4u); @@ -126,7 +128,7 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidBBoxes) { TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidScores) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -137,7 +139,7 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidScores) { TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveNonEmptyText) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); @@ -148,34 +150,34 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveNonEmptyText) { TEST(VerticalOCRGenerateTests, JointCharsInvalidImageThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("nonexistent_image.jpg"), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsEmptyImagePathThrows) { VerticalOCR 
model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generate(""), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsMalformedURIThrows) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generate("not_a_valid_uri://bad"), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsValidImageReturnsResults) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); EXPECT_GE(results.size(), 0u); } TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidBBoxes) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_EQ(detection.bbox.size(), 4u); @@ -189,7 +191,7 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidBBoxes) { TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidScores) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -200,7 +202,7 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidScores) { TEST(VerticalOCRGenerateTests, 
JointCharsDetectionsHaveNonEmptyText) { VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - auto results = model.generate(kValidVerticalTestImagePath); + auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); @@ -216,8 +218,10 @@ TEST(VerticalOCRStrategyTests, BothStrategiesRunSuccessfully) { kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_NO_THROW((void)independentModel.generate(kValidVerticalTestImagePath)); - EXPECT_NO_THROW((void)jointModel.generate(kValidVerticalTestImagePath)); + EXPECT_NO_THROW( + (void)independentModel.generateFromString(kValidVerticalTestImagePath)); + EXPECT_NO_THROW( + (void)jointModel.generateFromString(kValidVerticalTestImagePath)); } TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { @@ -229,10 +233,24 @@ TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { createMockCallInvoker()); auto independentResults = - independentModel.generate(kValidVerticalTestImagePath); - auto jointResults = jointModel.generate(kValidVerticalTestImagePath); + independentModel.generateFromString(kValidVerticalTestImagePath); + auto jointResults = + jointModel.generateFromString(kValidVerticalTestImagePath); // Both should return some results (or none if no text detected) EXPECT_GE(independentResults.size(), 0u); EXPECT_GE(jointResults.size(), 0u); } + +// ============================================================================ +// generateFromPixels smoke test +// ============================================================================ +TEST(VerticalOCRPixelTests, ValidPixelsReturnsResults) { + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, + ENGLISH_SYMBOLS, false, createMockCallInvoker()); + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{ + buf.data(), {64, 64, 3}, 
executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(view); + EXPECT_GE(results.size(), 0u); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp new file mode 100644 index 0000000000..6736454d6f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include +#include +#include + +using namespace rnexecutorch; +using namespace rnexecutorch::models; +using executorch::aten::ScalarType; + +// ============================================================================ +// TestableVisionModel — exposes protected methods for testing +// ============================================================================ +class TestableVisionModel : public VisionModel { +public: + explicit TestableVisionModel(const std::string &path) + : VisionModel(path, nullptr) {} + + cv::Mat preprocessPublic(const cv::Mat &img) const { return preprocess(img); } + + cv::Mat extractFromPixelsPublic(const JSTensorViewIn &v) const { + return extractFromPixels(v); + } + + void setInputShape(std::vector shape) { + modelInputShape_ = std::move(shape); + } +}; + +// Reuse the style_transfer .pte as a vehicle — we never call forward(). 
+constexpr auto kModelPath = "style_transfer_candy_xnnpack_fp32.pte"; + +// ============================================================================ +// preprocess() tests +// ============================================================================ +class VisionModelPreprocessTest : public ::testing::Test { +protected: + void SetUp() override { + model = std::make_unique(kModelPath); + } + std::unique_ptr model; +}; + +TEST_F(VisionModelPreprocessTest, CorrectSizeImageReturnedAsIs) { + model->setInputShape({1, 3, 64, 64}); + cv::Mat img(64, 64, CV_8UC3, cv::Scalar(100, 150, 200)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.size(), cv::Size(64, 64)); +} + +TEST_F(VisionModelPreprocessTest, CorrectSizeDataUnchanged) { + model->setInputShape({1, 3, 8, 8}); + cv::Mat img(8, 8, CV_8UC3, cv::Scalar(255, 0, 0)); + auto result = model->preprocessPublic(img); + auto pixel = result.at(0, 0); + EXPECT_EQ(pixel[0], 255); + EXPECT_EQ(pixel[1], 0); + EXPECT_EQ(pixel[2], 0); +} + +TEST_F(VisionModelPreprocessTest, LargerImageIsResizedDown) { + model->setInputShape({1, 3, 32, 32}); + cv::Mat img(128, 128, CV_8UC3, cv::Scalar(0)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.size(), cv::Size(32, 32)); +} + +TEST_F(VisionModelPreprocessTest, SmallerImageIsResizedUp) { + model->setInputShape({1, 3, 128, 128}); + cv::Mat img(32, 32, CV_8UC3, cv::Scalar(0)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.size(), cv::Size(128, 128)); +} + +TEST_F(VisionModelPreprocessTest, NonSquareTargetSize) { + model->setInputShape({1, 3, 48, 96}); + cv::Mat img(200, 100, CV_8UC3, cv::Scalar(0)); + auto result = model->preprocessPublic(img); + EXPECT_EQ(result.rows, 48); + EXPECT_EQ(result.cols, 96); +} + +// ============================================================================ +// extractFromPixels() tests +// ============================================================================ +class VisionModelExtractFromPixelsTest 
: public ::testing::Test { +protected: + void SetUp() override { + model = std::make_unique(kModelPath); + } + std::unique_ptr model; +}; + +TEST_F(VisionModelExtractFromPixelsTest, ValidInputReturnsCorrectDimensions) { + std::vector buf(64 * 64 * 3, 128); + JSTensorViewIn view{buf.data(), {64, 64, 3}, ScalarType::Byte}; + auto mat = model->extractFromPixelsPublic(view); + EXPECT_EQ(mat.rows, 64); + EXPECT_EQ(mat.cols, 64); + EXPECT_EQ(mat.channels(), 3); +} + +TEST_F(VisionModelExtractFromPixelsTest, TwoDimensionalSizesThrows) { + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4}, ScalarType::Byte}; + EXPECT_THROW(model->extractFromPixelsPublic(view), RnExecutorchError); +} + +TEST_F(VisionModelExtractFromPixelsTest, WrongChannelsThrows) { + std::vector buf(64, 0); + JSTensorViewIn view{buf.data(), {4, 4, 4}, ScalarType::Byte}; + EXPECT_THROW(model->extractFromPixelsPublic(view), RnExecutorchError); +} + +TEST_F(VisionModelExtractFromPixelsTest, WrongScalarTypeThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {4, 4, 3}, ScalarType::Float}; + EXPECT_THROW(model->extractFromPixelsPublic(view), RnExecutorchError); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h new file mode 100644 index 0000000000..04e407787d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTests.h @@ -0,0 +1,61 @@ +#pragma once + +#include "BaseModelTests.h" +#include +#include +#include +#include + +namespace model_tests { + +template class VisionModelTest : public ::testing::Test { +protected: + using Traits = ModelTraits; + using ModelType = typename Traits::ModelType; +}; + +TYPED_TEST_SUITE_P(VisionModelTest); + +TYPED_TEST_P(VisionModelTest, TwoConcurrentGeneratesDoNotCrash) { + SETUP_TRAITS(); + auto model = Traits::createValid(); + std::atomic 
successCount{0}; + std::atomic exceptionCount{0}; + + auto task = [&]() { + try { + Traits::callGenerate(model); + successCount++; + } catch (const rnexecutorch::RnExecutorchError &) { + exceptionCount++; + } + }; + + std::thread a(task); + std::thread b(task); + a.join(); + b.join(); + + EXPECT_EQ(successCount + exceptionCount, 2); +} + +TYPED_TEST_P(VisionModelTest, GenerateAndUnloadConcurrentlyDoesNotCrash) { + SETUP_TRAITS(); + auto model = Traits::createValid(); + + std::thread a([&]() { + try { + Traits::callGenerate(model); + } catch (const rnexecutorch::RnExecutorchError &) { + } + }); + std::thread b([&]() { model.unload(); }); + + a.join(); + b.join(); +} + +REGISTER_TYPED_TEST_SUITE_P(VisionModelTest, TwoConcurrentGeneratesDoNotCrash, + GenerateAndUnloadConcurrentlyDoesNotCrash); + +} // namespace model_tests diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp new file mode 100644 index 0000000000..cfea1eb2a4 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include +#include +#include + +using namespace rnexecutorch; +using namespace rnexecutorch::utils; +using executorch::aten::ScalarType; + +static JSTensorViewIn makeValidView(std::vector &buf, int32_t h, + int32_t w) { + buf.assign(static_cast(h * w * 3), 128); + return JSTensorViewIn{buf.data(), {h, w, 3}, ScalarType::Byte}; +} + +// ============================================================================ +// Valid input +// ============================================================================ +TEST(PixelsToMatValidInput, ProducesCorrectDimensions) { + std::vector buf; + auto view = makeValidView(buf, 48, 64); + auto mat = pixelsToMat(view); + EXPECT_EQ(mat.rows, 48); + EXPECT_EQ(mat.cols, 64); +} + +TEST(PixelsToMatValidInput, MatTypeIsCV_8UC3) { + 
std::vector buf; + auto view = makeValidView(buf, 4, 4); + auto mat = pixelsToMat(view); + EXPECT_EQ(mat.channels(), 3); + EXPECT_EQ(mat.type(), CV_8UC3); +} + +TEST(PixelsToMatValidInput, MatWrapsOriginalData) { + std::vector buf; + auto view = makeValidView(buf, 4, 4); + auto mat = pixelsToMat(view); + EXPECT_EQ(mat.data, buf.data()); +} + +// ============================================================================ +// Invalid sizes dimensionality +// ============================================================================ +TEST(PixelsToMatInvalidSizes, TwoDimensionalThrows) { + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidSizes, FourDimensionalThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {1, 4, 4, 3}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidSizes, EmptySizesThrows) { + std::vector buf(4, 0); + JSTensorViewIn view{buf.data(), {}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +// ============================================================================ +// Invalid channel count +// ============================================================================ +TEST(PixelsToMatInvalidChannels, OneChannelThrows) { + std::vector buf(16, 0); + JSTensorViewIn view{buf.data(), {4, 4, 1}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidChannels, FourChannelsThrows) { + std::vector buf(64, 0); + JSTensorViewIn view{buf.data(), {4, 4, 4}, ScalarType::Byte}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +// ============================================================================ +// Invalid scalar type +// ============================================================================ +TEST(PixelsToMatInvalidScalarType, FloatScalarTypeThrows) { + std::vector 
buf(48, 0); + JSTensorViewIn view{buf.data(), {4, 4, 3}, ScalarType::Float}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} + +TEST(PixelsToMatInvalidScalarType, IntScalarTypeThrows) { + std::vector buf(48, 0); + JSTensorViewIn view{buf.data(), {4, 4, 3}, ScalarType::Int}; + EXPECT_THROW(pixelsToMat(view), RnExecutorchError); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp index baae35dc35..d14c522184 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp @@ -47,8 +47,7 @@ cv::Mat extractFromCVPixelBuffer(void *pixelBuffer) { errorMessage); } - // Note: We don't unlock here - Vision Camera manages the lifecycle - // When frame.dispose() is called, Vision Camera will unlock and release + CVPixelBufferUnlockBaseAddress(buffer, kCVPixelBufferLock_ReadOnly); return mat; } @@ -88,8 +87,7 @@ cv::Mat extractFromAHardwareBuffer(void *hardwareBuffer) { errorMessage); } - // Note: We don't unlock here - Vision Camera manages the lifecycle - + AHardwareBuffer_unlock(buffer, nullptr); return mat; #else throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp index 30238ad5c4..19df5ba34e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp @@ -25,4 +25,52 @@ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData) { return extractFromNativeBuffer(bufferPtr); } + +cv::Mat frameToMat(jsi::Runtime &runtime, const jsi::Value &frameData) { + auto frameObj = frameData.asObject(runtime); + cv::Mat frame = 
extractFrame(runtime, frameObj); + + // Camera sensors deliver landscape frames; rotate to portrait orientation. + if (frame.cols > frame.rows) { + cv::Mat upright; + cv::rotate(frame, upright, cv::ROTATE_90_CLOCKWISE); + return upright; + } + return frame; +} + +cv::Mat pixelsToMat(const JSTensorViewIn &pixelData) { + if (pixelData.sizes.size() != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: sizes must have 3 elements " + "[height, width, channels], got %zu", + pixelData.sizes.size()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + int32_t height = pixelData.sizes[0]; + int32_t width = pixelData.sizes[1]; + int32_t channels = pixelData.sizes[2]; + + if (channels != 3) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Invalid pixel data: expected 3 channels (RGB), got %d", + channels); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + errorMessage); + } + + if (pixelData.scalarType != executorch::aten::ScalarType::Byte) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); + } + + auto *dataPtr = static_cast(pixelData.dataPtr); + return cv::Mat(height, width, CV_8UC3, dataPtr); +} + } // namespace rnexecutorch::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h index 403f4bde91..757fa95bbc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h @@ -2,6 +2,7 @@ #include #include +#include namespace rnexecutorch::utils { @@ -10,18 +11,30 @@ using namespace facebook; /** * @brief Extract cv::Mat from VisionCamera frame data via nativeBuffer * - * @param runtime JSI runtime - * @param frameData JSI 
object containing frame data from VisionCamera - * Expected properties: - * - nativeBuffer: BigInt pointer to native buffer - * - * @return cv::Mat wrapping the frame data (zero-copy) + * Returns an RGB mat (as delivered by the native buffer). * * @throws RnExecutorchError if nativeBuffer is not present or extraction fails - * * @note The returned cv::Mat does not own the data. - * Caller must ensure the source frame remains valid during use. */ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData); +/** + * @brief Convert a VisionCamera frame to a rotated RGB cv::Mat. + * + * Handles frame extraction and landscape→portrait rotation. + * Callers are responsible for any further colour space conversion. + */ +cv::Mat frameToMat(jsi::Runtime &runtime, const jsi::Value &frameData); + +/** + * @brief Validate a JSTensorViewIn and wrap its data in a RGB cv::Mat. + * + * Validates sizes (must be [H, W, 3]), scalar type (Byte), and returns a + * cv::Mat that wraps the raw pixel buffer without copying. + * Callers are responsible for any further colour space conversion. 
+ * + * @throws RnExecutorchError on invalid input + */ +cv::Mat pixelsToMat(const JSTensorViewIn &pixelData); + } // namespace rnexecutorch::utils diff --git a/packages/react-native-executorch/src/controllers/BaseOCRController.ts b/packages/react-native-executorch/src/controllers/BaseOCRController.ts index 614d42a212..5ef5f935bc 100644 --- a/packages/react-native-executorch/src/controllers/BaseOCRController.ts +++ b/packages/react-native-executorch/src/controllers/BaseOCRController.ts @@ -2,7 +2,8 @@ import { Logger } from '../common/Logger'; import { symbols } from '../constants/ocr/symbols'; import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; -import { ResourceSource } from '../types/common'; +import { isPixelData } from '../modules/computer_vision/VisionModule'; +import { Frame, PixelData, ResourceSource } from '../types/common'; import { OCRLanguage, OCRDetection } from '../types/ocr'; import { ResourceFetcher } from '../utils/ResourceFetcher'; @@ -87,7 +88,37 @@ export abstract class BaseOCRController { } }; - public forward = async (imageSource: string): Promise => { + get runOnFrame(): ((frame: Frame) => OCRDetection[]) | null { + if (!this.isReady) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling runOnFrame().' 
+ ); + } + + const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; + + return (frame: any): OCRDetection[] => { + 'worklet'; + + let nativeBuffer: any = null; + try { + nativeBuffer = frame.getNativeBuffer(); + const frameData = { + nativeBuffer: nativeBuffer.pointer, + }; + return nativeGenerateFromFrame(frameData); + } finally { + if (nativeBuffer?.release) { + nativeBuffer.release(); + } + } + }; + } + + public forward = async ( + input: string | PixelData + ): Promise => { if (!this.isReady) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -104,7 +135,17 @@ export abstract class BaseOCRController { try { this.isGenerating = true; this.isGeneratingCallback(this.isGenerating); - return await this.nativeModule.generate(imageSource); + + if (typeof input === 'string') { + return await this.nativeModule.generateFromString(input); + } else if (isPixelData(input)) { + return await this.nativeModule.generateFromPixels(input); + } else { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidArgument, + 'Invalid input: expected string path or PixelData object. For VisionCamera frames, use runOnFrame instead.' 
+ ); + } } catch (e) { throw parseUnknownError(e); } finally { diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts index c014d6b0ed..2e3b4e22a0 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts @@ -3,6 +3,7 @@ import { ClassificationProps, ClassificationType, } from '../../types/classification'; +import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; /** @@ -16,17 +17,30 @@ export const useClassification = ({ model, preventLoad = false, }: ClassificationProps): ClassificationType => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - ClassificationModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + ClassificationModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); - const forward = (imageSource: string) => + const forward = (imageSource: string | PixelData) => runForward((inst) => inst.forward(imageSource)); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts index b4e79c9263..8c0e6209e5 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts +++ 
b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts @@ -3,6 +3,7 @@ import { ImageEmbeddingsProps, ImageEmbeddingsType, } from '../../types/imageEmbeddings'; +import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; /** @@ -16,17 +17,30 @@ export const useImageEmbeddings = ({ model, preventLoad = false, }: ImageEmbeddingsProps): ImageEmbeddingsType => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - ImageEmbeddingsModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + ImageEmbeddingsModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); - const forward = (imageSource: string) => + const forward = (imageSource: string | PixelData) => runForward((inst) => inst.forward(imageSource)); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts index 473d3631ba..208824b8b8 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts @@ -1,4 +1,5 @@ import { useCallback, useEffect, useState } from 'react'; +import { Frame, PixelData } from '../../types/common'; import { OCRController } from '../../controllers/OCRController'; import { RnExecutorchError } from '../../errors/errorUtils'; import { OCRDetection, OCRProps, OCRType } from '../../types/ocr'; @@ -15,6 +16,9 @@ 
export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); const [error, setError] = useState(null); + const [runOnFrame, setRunOnFrame] = useState< + ((frame: Frame) => OCRDetection[]) | null + >(null); const [controller] = useState( () => @@ -38,7 +42,13 @@ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { setDownloadProgress ); + const worklet = controller.runOnFrame; + if (worklet) { + setRunOnFrame(() => worklet); + } + return () => { + setRunOnFrame(null); if (controller.isReady) { controller.delete(); } @@ -53,10 +63,17 @@ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { ]); const forward = useCallback( - (imageSource: string): Promise => + (imageSource: string | PixelData): Promise => controller.forward(imageSource), [controller] ); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + forward, + downloadProgress, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts index 81c81ce22f..5dfc552b63 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts @@ -7,7 +7,6 @@ import { ObjectDetectionProps, ObjectDetectionType, } from '../../types/objectDetection'; -import { useMemo } from 'react'; import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; @@ -31,7 +30,7 @@ export const useObjectDetection = ({ isGenerating, downloadProgress, runForward, - instance, + runOnFrame, } = useModuleFactory({ factory: (config, onProgress) => ObjectDetectionModule.fromModelName(config, onProgress), @@ -43,8 
+42,6 @@ export const useObjectDetection = ({ const forward = (input: string | PixelData, detectionThreshold?: number) => runForward((inst) => inst.forward(input, detectionThreshold)); - const runOnFrame = useMemo(() => instance?.runOnFrame ?? null, [instance]); - return { error, isReady, diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts index dd43aaf8b3..ae6ebed938 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts @@ -1,3 +1,4 @@ +import { PixelData } from '../..'; import { SemanticSegmentationModule, SegmentationLabels, @@ -34,17 +35,23 @@ export const useSemanticSegmentation = < }: SemanticSegmentationProps): SemanticSegmentationType< SegmentationLabels> > => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - SemanticSegmentationModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + SemanticSegmentationModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); const forward = >>( - imageSource: string, + imageSource: string | PixelData, classesOfInterest: K[] = [], resizeToInput: boolean = true ) => @@ -52,5 +59,12 @@ export const useSemanticSegmentation = < inst.forward(imageSource, classesOfInterest, resizeToInput) ); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + }; }; diff --git 
a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts index bfa42eee71..dfa9095cc7 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts @@ -1,4 +1,5 @@ import { StyleTransferModule } from '../../modules/computer_vision/StyleTransferModule'; +import { PixelData } from '../../types/common'; import { StyleTransferProps, StyleTransferType, @@ -16,17 +17,32 @@ export const useStyleTransfer = ({ model, preventLoad = false, }: StyleTransferProps): StyleTransferType => { - const { error, isReady, isGenerating, downloadProgress, runForward } = - useModuleFactory({ - factory: (config, onProgress) => - StyleTransferModule.fromModelName(config, onProgress), - config: model, - deps: [model.modelName, model.modelSource], - preventLoad, - }); + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + } = useModuleFactory({ + factory: (config, onProgress) => + StyleTransferModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); - const forward = (imageSource: string) => - runForward((inst) => inst.forward(imageSource)); + const forward = ( + imageSource: string | PixelData, + outputType?: O + ) => runForward((inst) => inst.forward(imageSource, outputType)); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + } as StyleTransferType; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts index 71774198fc..4f72c97d9c 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts +++ 
b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts @@ -1,4 +1,5 @@ import { useCallback, useEffect, useState } from 'react'; +import { Frame, PixelData } from '../../types/common'; import { VerticalOCRController } from '../../controllers/VerticalOCRController'; import { RnExecutorchError } from '../../errors/errorUtils'; import { OCRDetection, OCRType, VerticalOCRProps } from '../../types/ocr'; @@ -20,6 +21,10 @@ export const useVerticalOCR = ({ const [downloadProgress, setDownloadProgress] = useState(0); const [error, setError] = useState(null); + const [runOnFrame, setRunOnFrame] = useState< + ((frame: Frame) => OCRDetection[]) | null + >(null); + const [controller] = useState( () => new VerticalOCRController({ @@ -43,7 +48,13 @@ export const useVerticalOCR = ({ setDownloadProgress ); + const worklet = controller.runOnFrame; + if (worklet) { + setRunOnFrame(() => worklet); + } + return () => { + setRunOnFrame(null); if (controller.isReady) { controller.delete(); } @@ -59,10 +70,17 @@ export const useVerticalOCR = ({ ]); const forward = useCallback( - (imageSource: string): Promise => + (imageSource: string | PixelData): Promise => controller.forward(imageSource), [controller] ); - return { error, isReady, isGenerating, downloadProgress, forward }; + return { + error, + isReady, + isGenerating, + forward, + downloadProgress, + runOnFrame, + }; }; diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index cc1fc1ef2e..632e94d3ea 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ b/packages/react-native-executorch/src/hooks/useModule.ts @@ -76,6 +76,8 @@ export const useModule = < return () => { isMounted = false; + setIsReady(false); + setRunOnFrame(null); moduleInstance.delete(); }; diff --git a/packages/react-native-executorch/src/hooks/useModuleFactory.ts b/packages/react-native-executorch/src/hooks/useModuleFactory.ts index 
3d7f474052..bb3140518d 100644 --- a/packages/react-native-executorch/src/hooks/useModuleFactory.ts +++ b/packages/react-native-executorch/src/hooks/useModuleFactory.ts @@ -1,9 +1,11 @@ -import { useState, useEffect } from 'react'; +import { useState, useEffect, useMemo } from 'react'; import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; type Deletable = { delete: () => void }; +type RunOnFrame = M extends { runOnFrame: infer R } ? R : never; + /** * Shared hook for modules that are instantiated via an async static factory * (i.e. `SomeModule.fromModelName(config, onProgress)`). @@ -92,6 +94,14 @@ export function useModuleFactory({ } }; + const runOnFrame = useMemo( + () => + instance && 'runOnFrame' in instance + ? (instance.runOnFrame as RunOnFrame | null) + : null, + [instance] + ); + return { error, isReady, @@ -99,5 +109,6 @@ export function useModuleFactory({ downloadProgress, runForward, instance, + runOnFrame, }; } diff --git a/packages/react-native-executorch/src/modules/BaseLabeledModule.ts b/packages/react-native-executorch/src/modules/BaseLabeledModule.ts deleted file mode 100644 index 01678f83d5..0000000000 --- a/packages/react-native-executorch/src/modules/BaseLabeledModule.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { ResourceFetcher } from '../utils/ResourceFetcher'; -import { LabelEnum, ResourceSource } from '../types/common'; -import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; -import { RnExecutorchError } from '../errors/errorUtils'; -import { BaseModule } from './BaseModule'; - -/** - * Fetches a model binary and returns its local path, throwing if the download - * was interrupted (paused or cancelled). 
- * - * @internal - */ -export async function fetchModelPath( - source: ResourceSource, - onDownloadProgress: (progress: number) => void -): Promise { - const paths = await ResourceFetcher.fetch(onDownloadProgress, source); - if (!paths?.[0]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. Please retry.' - ); - } - return paths[0]; -} - -/** - * Given a model configs record (mapping model names to `{ labelMap }`) and a - * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to - * the label map for that model or `T` itself. - * - * @internal - */ -export type ResolveLabels< - T, - Configs extends Record, -> = T extends keyof Configs - ? Configs[T]['labelMap'] - : T extends LabelEnum - ? T - : never; - -/** - * Base class for vision modules that carry a type-safe label map. - * - * @typeParam LabelMap - The resolved {@link LabelEnum} for the model's output classes. - * @internal - */ -export abstract class BaseLabeledModule< - LabelMap extends LabelEnum, -> extends BaseModule { - protected readonly labelMap: LabelMap; - - protected constructor(labelMap: LabelMap, nativeModule: unknown) { - super(); - this.labelMap = labelMap; - this.nativeModule = nativeModule; - } -} diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts index 43691c2047..d9ef0d7f73 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts @@ -1,22 +1,23 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; +import { PixelData, ResourceSource } from '../../types/common'; import { ClassificationModelName } from '../../types/classification'; -import { BaseModule } from '../BaseModule'; import 
{ RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { Logger } from '../../common/Logger'; +import { VisionModule } from './VisionModule'; /** * Module for image classification tasks. * * @category Typescript API */ -export class ClassificationModule extends BaseModule { +export class ClassificationModule extends VisionModule<{ + [category: string]: number; +}> { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; } - /** * Creates a classification instance for a built-in model. * @@ -74,18 +75,9 @@ export class ClassificationModule extends BaseModule { ); } - /** - * Executes the model's forward pass to classify the provided image. - * - * @param imageSource - A string image source (file path, URI, or Base64). - * @returns A Promise resolving to an object mapping category labels to confidence scores. - */ - async forward(imageSource: string): Promise<{ [category: string]: number }> { - if (this.nativeModule == null) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' 
- ); - return await this.nativeModule.generate(imageSource); + async forward( + input: string | PixelData + ): Promise<{ [category: string]: number }> { + return super.forward(input); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts index dd81408e36..c4cd57b889 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts @@ -1,22 +1,21 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; import { ImageEmbeddingsModelName } from '../../types/imageEmbeddings'; +import { ResourceSource, PixelData } from '../../types/common'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; -import { BaseModule } from '../BaseModule'; import { Logger } from '../../common/Logger'; +import { VisionModule } from './VisionModule'; /** * Module for generating image embeddings from input images. * * @category Typescript API */ -export class ImageEmbeddingsModule extends BaseModule { +export class ImageEmbeddingsModule extends VisionModule { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; } - /** * Creates an image embeddings instance for a built-in model. * @@ -74,18 +73,7 @@ export class ImageEmbeddingsModule extends BaseModule { ); } - /** - * Executes the model's forward pass to generate an embedding for the provided image. - * - * @param imageSource - A string image source (file path, URI, or Base64). - * @returns A Promise resolving to a `Float32Array` containing the image embedding vector. 
- */ - async forward(imageSource: string): Promise { - if (this.nativeModule == null) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - return new Float32Array(await this.nativeModule.generate(imageSource)); + async forward(input: string | PixelData): Promise { + return super.forward(input); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index c24bbd1369..20eb51db56 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -13,8 +13,8 @@ import { import { fetchModelPath, ResolveLabels as ResolveLabelsFor, -} from '../BaseLabeledModule'; -import { VisionLabeledModule } from './VisionLabeledModule'; + VisionLabeledModule, +} from './VisionLabeledModule'; const ModelConfigs = { 'ssdlite-320-mobilenet-v3-large': { diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index d24e988930..a35df92d83 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -1,4 +1,4 @@ -import { ResourceSource, LabelEnum } from '../../types/common'; +import { ResourceSource, LabelEnum, PixelData } from '../../types/common'; import { DeeplabLabel, ModelNameOf, @@ -7,14 +7,12 @@ import { SemanticSegmentationModelName, SelfieSegmentationLabel, } from '../../types/semanticSegmentation'; -import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; -import { RnExecutorchError } from '../../errors/errorUtils'; import { 
IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision'; import { - BaseLabeledModule, fetchModelPath, ResolveLabels as ResolveLabelsFor, -} from '../BaseLabeledModule'; + VisionLabeledModule, +} from './VisionLabeledModule'; const PascalVocSegmentationConfig = { labelMap: DeeplabLabel, @@ -79,7 +77,10 @@ type ResolveLabels = */ export class SemanticSegmentationModule< T extends SemanticSegmentationModelName | LabelEnum, -> extends BaseLabeledModule> { +> extends VisionLabeledModule< + Record<'ARGMAX', Int32Array> & Record, Float32Array>, + ResolveLabels +> { private constructor(labelMap: ResolveLabels, nativeModule: unknown) { super(labelMap, nativeModule); } @@ -184,35 +185,24 @@ export class SemanticSegmentationModule< /** * Executes the model's forward pass to perform semantic segmentation on the provided image. * - * @param imageSource - A string representing the image source (e.g., a file path, URI, or Base64-encoded string). + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded. 
*/ - async forward>( - imageSource: string, + override async forward>( + input: string | PixelData, classesOfInterest: K[] = [], resizeToInput: boolean = true ): Promise & Record> { - if (this.nativeModule == null) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded.' - ); - } - - const classesOfInterestNames = classesOfInterest.map((label) => - String(label) - ); - - const nativeResult = await this.nativeModule.generate( - imageSource, - classesOfInterestNames, - resizeToInput - ); - - return nativeResult as Record<'ARGMAX', Int32Array> & - Record; + const classesOfInterestNames = classesOfInterest.map(String); + return super.forward(input, classesOfInterestNames, resizeToInput); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index 6027d1fd2d..6519a29b91 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -1,22 +1,21 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; import { StyleTransferModelName } from '../../types/styleTransfer'; -import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { ResourceSource, PixelData } from '../../types/common'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; -import { BaseModule } from '../BaseModule'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { Logger } from '../../common/Logger'; +import { VisionModule } from './VisionModule'; /** * Module for style transfer tasks. 
* * @category Typescript API */ -export class StyleTransferModule extends BaseModule { +export class StyleTransferModule extends VisionModule { private constructor(nativeModule: unknown) { super(); this.nativeModule = nativeModule; } - /** * Creates a style transfer instance for a built-in model. * @@ -72,18 +71,12 @@ export class StyleTransferModule extends BaseModule { ); } - /** - * Executes the model's forward pass to apply the selected style to the provided image. - * - * @param imageSource - A string image source (file path, URI, or Base64). - * @returns A Promise resolving to the stylized image as a Base64-encoded string. - */ - async forward(imageSource: string): Promise { - if (this.nativeModule == null) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' - ); - return await this.nativeModule.generate(imageSource); + async forward( + input: string | PixelData, + outputType?: O + ): Promise { + return super.forward(input, outputType === 'url') as Promise< + O extends 'url' ? 
string : PixelData + >; } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts index 61a0bab091..188b03c8c9 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts @@ -1,6 +1,31 @@ -import { LabelEnum } from '../../types/common'; +import { LabelEnum, ResourceSource } from '../../types/common'; +import { ResourceFetcher } from '../../utils/ResourceFetcher'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError } from '../../errors/errorUtils'; import { VisionModule } from './VisionModule'; +export { ResolveLabels } from '../../types/computerVision'; + +/** + * Fetches a model binary and returns its local path, throwing if the download + * was interrupted (paused or cancelled). + * + * @internal + */ +export async function fetchModelPath( + source: ResourceSource, + onDownloadProgress: (progress: number) => void +): Promise { + const paths = await ResourceFetcher.fetch(onDownloadProgress, source); + if (!paths?.[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. Please retry.' + ); + } + return paths[0]; +} + /** * Base class for computer vision modules that carry a type-safe label map * and support the full VisionModule API (string/PixelData forward + runOnFrame). 
@@ -10,8 +35,8 @@ import { VisionModule } from './VisionModule'; * @internal */ export abstract class VisionLabeledModule< - TOutput, - LabelMap extends LabelEnum, + TOutput = unknown, + LabelMap extends LabelEnum = LabelEnum, > extends VisionModule { protected readonly labelMap: LabelMap; diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts index 762d09987e..d2c78edf0d 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionModule.ts @@ -3,19 +3,7 @@ import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError } from '../../errors/errorUtils'; import { Frame, PixelData, ScalarType } from '../../types/common'; -/** - * Base class for computer vision models that support multiple input types. - * - * VisionModule extends BaseModule with: - * - Unified `forward()` API accepting string paths or raw pixel data - * - `runOnFrame` getter for real-time VisionCamera frame processing - * - Shared frame processor creation logic - * - * Subclasses should only implement model-specific loading logic. - * - * @category Typescript API - */ -function isPixelData(input: unknown): input is PixelData { +export function isPixelData(input: unknown): input is PixelData { return ( typeof input === 'object' && input !== null && @@ -29,6 +17,18 @@ function isPixelData(input: unknown): input is PixelData { ); } +/** + * Base class for computer vision models that support multiple input types. + * + * VisionModule extends BaseModule with: + * - Unified `forward()` API accepting string paths or raw pixel data + * - `runOnFrame` getter for real-time VisionCamera frame processing + * - Shared frame processor creation logic + * + * Subclasses implement model-specific loading logic and may override `forward` for typed signatures. 
+ * + * @category Typescript API + */ export abstract class VisionModule extends BaseModule { /** * Synchronous worklet function for real-time VisionCamera frame processing. @@ -59,8 +59,11 @@ export abstract class VisionModule extends BaseModule { * ``` */ get runOnFrame(): ((frame: Frame, ...args: any[]) => TOutput) | null { - if (!this.nativeModule?.generateFromFrame) { - return null; + if (!this.nativeModule) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling runOnFrame().' + ); } // Extract pure JSI function reference (runs on JS thread) @@ -73,9 +76,7 @@ export abstract class VisionModule extends BaseModule { let nativeBuffer: any = null; try { nativeBuffer = frame.getNativeBuffer(); - const frameData = { - nativeBuffer: nativeBuffer.pointer, - }; + const frameData = { nativeBuffer: nativeBuffer.pointer }; return nativeGenerateFromFrame(frameData, ...args); } finally { if (nativeBuffer?.release) { @@ -127,7 +128,6 @@ export abstract class VisionModule extends BaseModule { RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling forward().' ); - // Type detection and routing if (typeof input === 'string') { return await this.nativeModule.generateFromString(input, ...args); diff --git a/packages/react-native-executorch/src/types/classification.ts b/packages/react-native-executorch/src/types/classification.ts index 144f2af5ae..994d72a05c 100644 --- a/packages/react-native-executorch/src/types/classification.ts +++ b/packages/react-native-executorch/src/types/classification.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { ResourceSource, PixelData, Frame } from './common'; /** * Union of all built-in classification model names. 
@@ -53,9 +53,32 @@ export interface ClassificationType { /** * Executes the model's forward pass to classify the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be classified. - * @returns A Promise that resolves to the classification result (typically containing labels and confidence scores). + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) + * @returns A Promise that resolves to the classification result (labels and confidence scores). * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ - forward: (imageSource: string) => Promise<{ [category: string]: number }>; + forward: ( + input: string | PixelData + ) => Promise<{ [category: string]: number }>; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns Object mapping class labels to confidence scores. 
+ */ + runOnFrame: ((frame: Frame) => { [category: string]: number }) | null; } diff --git a/packages/react-native-executorch/src/types/computerVision.ts b/packages/react-native-executorch/src/types/computerVision.ts index e69de29bb2..62357f1402 100644 --- a/packages/react-native-executorch/src/types/computerVision.ts +++ b/packages/react-native-executorch/src/types/computerVision.ts @@ -0,0 +1,17 @@ +import { LabelEnum } from './common'; + +/** + * Given a model configs record (mapping model names to `{ labelMap }`) and a + * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to + * the label map for that model or `T` itself. + * + * @internal + */ +export type ResolveLabels< + T, + Configs extends Record, +> = T extends keyof Configs + ? Configs[T]['labelMap'] + : T extends LabelEnum + ? T + : never; diff --git a/packages/react-native-executorch/src/types/imageEmbeddings.ts b/packages/react-native-executorch/src/types/imageEmbeddings.ts index 88308ddd6f..7130ac5b84 100644 --- a/packages/react-native-executorch/src/types/imageEmbeddings.ts +++ b/packages/react-native-executorch/src/types/imageEmbeddings.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { ResourceSource, PixelData, Frame } from './common'; /** * Union of all built-in image embeddings model names. @@ -53,9 +53,30 @@ export interface ImageEmbeddingsType { /** * Executes the model's forward pass to generate embeddings (a feature vector) for the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. 
+ * + * @param input - Image source (string or {@link PixelData} object) * @returns A Promise that resolves to a `Float32Array` containing the generated embedding vector. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ - forward: (imageSource: string) => Promise; + forward: (input: string | PixelData) => Promise; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns Float32Array containing the embedding vector for the frame. + */ + runOnFrame: ((frame: Frame) => Float32Array) | null; } diff --git a/packages/react-native-executorch/src/types/objectDetection.ts b/packages/react-native-executorch/src/types/objectDetection.ts index aa25e9c412..38dc4bd12d 100644 --- a/packages/react-native-executorch/src/types/objectDetection.ts +++ b/packages/react-native-executorch/src/types/objectDetection.ts @@ -110,6 +110,19 @@ export interface ObjectDetectionType { * @param detectionThreshold - An optional number between 0 and 1 representing the minimum confidence score. Default is 0.7. * @returns A Promise that resolves to an array of `Detection` objects. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. 
+ * + * @example + * ```typescript + * // String path + * const detections1 = await model.forward('file:///path/to/image.jpg'); + * + * // Pixel data + * const detections2 = await model.forward({ + * dataPtr: new Uint8Array(rgbPixels), + * sizes: [480, 640, 3], + * scalarType: ScalarType.BYTE + * }); + * ``` */ forward: ( input: string | PixelData, diff --git a/packages/react-native-executorch/src/types/ocr.ts b/packages/react-native-executorch/src/types/ocr.ts index e31d618478..b4cd8c8b69 100644 --- a/packages/react-native-executorch/src/types/ocr.ts +++ b/packages/react-native-executorch/src/types/ocr.ts @@ -1,6 +1,6 @@ import { symbols } from '../constants/ocr/symbols'; import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { Frame, PixelData, ResourceSource } from './common'; /** * OCRDetection represents a single detected text instance in an image, @@ -110,11 +110,35 @@ export interface OCRType { /** * Executes the OCR pipeline (detection and recognition) on the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. - * @returns A Promise that resolves to the OCR results (typically containing the recognized text strings and their bounding boxes). + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) + * @returns A Promise that resolves to the OCR results (recognized text and bounding boxes). * @throws {RnExecutorchError} If the models are not loaded or are currently processing another image. 
*/ - forward: (imageSource: string) => Promise; + forward: (input: string | PixelData) => Promise; + + /** + * Synchronous worklet function for VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * **Note**: OCR is a two-stage pipeline (detection + recognition) and may not + * achieve real-time frame rates. Frames may be dropped if inference is still running. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns Array of OCRDetection results for the frame. + */ + runOnFrame: ((frame: Frame) => OCRDetection[]) | null; } /** diff --git a/packages/react-native-executorch/src/types/semanticSegmentation.ts b/packages/react-native-executorch/src/types/semanticSegmentation.ts index 109f6f094e..10784721b9 100644 --- a/packages/react-native-executorch/src/types/semanticSegmentation.ts +++ b/packages/react-native-executorch/src/types/semanticSegmentation.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { LabelEnum, Triple, ResourceSource } from './common'; +import { LabelEnum, Triple, ResourceSource, PixelData, Frame } from './common'; /** * Configuration for a custom semantic segmentation model. @@ -148,15 +148,44 @@ export interface SemanticSegmentationType { /** * Executes the model's forward pass to perform semantic segmentation on the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. 
+ * + * @param input - Image source (string or PixelData object) * @param classesOfInterest - An optional array of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. * @returns A Promise resolving to an object with an `'ARGMAX'` `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ forward: ( - imageSource: string, + input: string | PixelData, classesOfInterest?: K[], resizeToInput?: boolean ) => Promise & Record>; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @param classesOfInterest - Labels for which to return per-class probability masks. + * @param resizeToInput - Whether to resize masks to original frame dimensions. Defaults to `true`. + * @returns Object with `ARGMAX` Int32Array and per-class Float32Array masks. 
+ */ + runOnFrame: + | (( + frame: Frame, + classesOfInterest?: string[], + resizeToInput?: boolean + ) => Record<'ARGMAX', Int32Array> & Record) + | null; } diff --git a/packages/react-native-executorch/src/types/styleTransfer.ts b/packages/react-native-executorch/src/types/styleTransfer.ts index 2571203ee1..a325f94d25 100644 --- a/packages/react-native-executorch/src/types/styleTransfer.ts +++ b/packages/react-native-executorch/src/types/styleTransfer.ts @@ -1,5 +1,5 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { ResourceSource, PixelData, Frame } from './common'; /** * Union of all built-in style transfer model names. @@ -59,9 +59,34 @@ export interface StyleTransferType { /** * Executes the model's forward pass to apply the specific artistic style to the provided image. - * @param imageSource - A string representing the input image source (e.g., a file path, URI, or base64 string) to be stylized. - * @returns A Promise that resolves to a string containing the stylized image (typically as a base64 string or a file URI). + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * + * **Note**: For VisionCamera frame processing, use `runOnFrame` instead. + * + * @param input - Image source (string or PixelData object) + * @param outputType - Output format: `'pixelData'` (default) returns raw RGBA pixel data; `'url'` saves the result to a temp file and returns its `file://` path. + * @returns A Promise resolving to `PixelData` when `outputType` is `'pixelData'` (default), or a `file://` URL string when `outputType` is `'url'`. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. 
*/ - forward: (imageSource: string) => Promise; + forward( + input: string | PixelData, + outputType?: O + ): Promise; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + * Automatically handles native buffer extraction and cleanup. + * + * **Use this for VisionCamera frame processing in worklets.** + * For async processing, use `forward()` instead. + * + * Available after model is loaded (`isReady: true`). + * + * @param frame - VisionCamera Frame object + * @returns PixelData containing the stylized frame as raw RGB pixel data. + */ + runOnFrame: ((frame: Frame) => PixelData) | null; } diff --git a/yarn.lock b/yarn.lock index 12cb5c31b7..0ed00cdfd9 100644 --- a/yarn.lock +++ b/yarn.lock @@ -7270,14 +7270,14 @@ __metadata: react-native-gesture-handler: "npm:~2.28.0" react-native-image-picker: "npm:^7.2.2" react-native-loading-spinner-overlay: "npm:^3.0.1" - react-native-nitro-image: "npm:^0.12.0" - react-native-nitro-modules: "npm:^0.33.9" + react-native-nitro-image: "npm:0.13.0" + react-native-nitro-modules: "npm:0.35.0" react-native-reanimated: "npm:~4.2.2" react-native-safe-area-context: "npm:~5.6.0" react-native-screens: "npm:~4.16.0" react-native-svg: "npm:15.15.3" react-native-svg-transformer: "npm:^1.5.3" - react-native-vision-camera: "npm:5.0.0-beta.2" + react-native-vision-camera: "npm:5.0.0-beta.6" react-native-worklets: "npm:0.7.4" languageName: unknown linkType: soft @@ -14475,24 +14475,24 @@ __metadata: languageName: node linkType: hard -"react-native-nitro-image@npm:^0.12.0": - version: 0.12.0 - resolution: "react-native-nitro-image@npm:0.12.0" +"react-native-nitro-image@npm:0.13.0": + version: 0.13.0 + resolution: "react-native-nitro-image@npm:0.13.0" peerDependencies: react: "*" react-native: "*" react-native-nitro-modules: "*" - checksum: 10/03f165381c35e060d4d05eae3ce029b32a4009482f327e9526840f306181ca87a862b335e12667c55d4ee9f2069542ca93dd112feb7f1822bf7d2ddc38fe58f0 + checksum: 
10/77f04e0c262fed839aa16276a31cce6d6969788d2aada55594fd083959dec9e00bd75f7bd8333cc59f3768bf329592736da60688e827a8d25d18de8bbda9b2d7 languageName: node linkType: hard -"react-native-nitro-modules@npm:^0.33.9": - version: 0.33.9 - resolution: "react-native-nitro-modules@npm:0.33.9" +"react-native-nitro-modules@npm:0.35.0": + version: 0.35.0 + resolution: "react-native-nitro-modules@npm:0.35.0" peerDependencies: react: "*" react-native: "*" - checksum: 10/4ebf4db46d1e4987a0e52054724081aa9712bcd1d505a6dbdd47aebc6afe72a7abaa0e947651d9f3cc594e4eb3dba47fc6f59db27c5a5ed383946e40d96543a0 + checksum: 10/6c9166a115a03bfc26d3cb9a75761a1fdf33a06bdfb853779539cfe3d7dc2239e242a7fbd4cbb7e9dc0af90a373606fc607fa8a09ef4ff49f7ff29ccff736bbf languageName: node linkType: hard @@ -14588,16 +14588,16 @@ __metadata: languageName: node linkType: hard -"react-native-vision-camera@npm:5.0.0-beta.2": - version: 5.0.0-beta.2 - resolution: "react-native-vision-camera@npm:5.0.0-beta.2" +"react-native-vision-camera@npm:5.0.0-beta.6": + version: 5.0.0-beta.6 + resolution: "react-native-vision-camera@npm:5.0.0-beta.6" peerDependencies: react: "*" react-native: "*" react-native-nitro-image: "*" react-native-nitro-modules: "*" react-native-worklets: "*" - checksum: 10/1f38d097d001c10b8544d0b931a9387a91c5df1e0677ae53e639962a90589586af02ca658ca5e99a5ca179af8d86bc8365227cf70750f2df4bfb775f4a26fc6d + checksum: 10/5c1f3104869e51b173d2bba88a69397c08949314a47a14d70e9ea01b6531c8c74ea6e15f8de4bab6c55eb82c673f8231b6be223d8b4b44a8df433f1cd35c0376 languageName: node linkType: hard