Skip to content

Commit dbf86df

Browse files
feat: support all vision models
1 parent fc4eb98 commit dbf86df

File tree

16 files changed

+537
-164
lines changed

16 files changed

+537
-164
lines changed

apps/computer-vision/app/vision_camera/index.tsx

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,15 @@ import ColorPalette from '../../colors';
3131
import ClassificationTask from '../../components/vision_camera/tasks/ClassificationTask';
3232
import ObjectDetectionTask from '../../components/vision_camera/tasks/ObjectDetectionTask';
3333
import SegmentationTask from '../../components/vision_camera/tasks/SegmentationTask';
34+
import OCRTask from '../../components/vision_camera/tasks/OCRTask';
35+
import StyleTransferTask from '../../components/vision_camera/tasks/StyleTransferTask';
3436

35-
type TaskId = 'classification' | 'objectDetection' | 'segmentation';
37+
type TaskId =
38+
| 'classification'
39+
| 'objectDetection'
40+
| 'segmentation'
41+
| 'ocr'
42+
| 'styleTransfer';
3643
type ModelId =
3744
| 'classification'
3845
| 'objectDetectionSsdlite'
@@ -43,7 +50,10 @@ type ModelId =
4350
| 'segmentationLraspp'
4451
| 'segmentationFcnResnet50'
4552
| 'segmentationFcnResnet101'
46-
| 'segmentationSelfie';
53+
| 'segmentationSelfie'
54+
| 'ocr'
55+
| 'styleTransferCandy'
56+
| 'styleTransferMosaic';
4757

4858
type TaskVariant = { id: ModelId; label: string };
4959
type Task = { id: TaskId; label: string; variants: TaskVariant[] };
@@ -75,6 +85,19 @@ const TASKS: Task[] = [
7585
{ id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' },
7686
],
7787
},
88+
{
89+
id: 'ocr',
90+
label: 'OCR',
91+
variants: [{ id: 'ocr', label: 'English' }],
92+
},
93+
{
94+
id: 'styleTransfer',
95+
label: 'Style',
96+
variants: [
97+
{ id: 'styleTransferCandy', label: 'Candy' },
98+
{ id: 'styleTransferMosaic', label: 'Mosaic' },
99+
],
100+
},
78101
];
79102

80103
// Module-level consts so worklets in task components can always reference the same stable objects.
@@ -225,6 +248,15 @@ export default function VisionCameraScreen() {
225248
}
226249
/>
227250
)}
251+
{activeTask === 'ocr' && <OCRTask {...taskProps} />}
252+
{activeTask === 'styleTransfer' && (
253+
<StyleTransferTask
254+
{...taskProps}
255+
activeModel={
256+
activeModel as 'styleTransferCandy' | 'styleTransferMosaic'
257+
}
258+
/>
259+
)}
228260

229261
{!isReady && (
230262
<View style={styles.loadingOverlay}>

apps/computer-vision/components/vision_camera/tasks/ClassificationTask.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ import { scheduleOnRN } from 'react-native-worklets';
55
import { EFFICIENTNET_V2_S, useClassification } from 'react-native-executorch';
66
import { TaskProps } from './types';
77

8-
type Props = Omit<TaskProps, 'activeModel' | 'canvasSize'>;
8+
type Props = Omit<
9+
TaskProps,
10+
'activeModel' | 'canvasSize' | 'cameraPositionSync'
11+
>;
912

1013
export default function ClassificationTask({
1114
frameKillSwitch,
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import React, { useCallback, useEffect, useRef, useState } from 'react';
2+
import { StyleSheet, View } from 'react-native';
3+
import { Frame, useFrameOutput } from 'react-native-vision-camera';
4+
import { scheduleOnRN } from 'react-native-worklets';
5+
import { OCR_ENGLISH, OCRDetection, useOCR } from 'react-native-executorch';
6+
import Svg, { Polygon, Text as SvgText } from 'react-native-svg';
7+
import { TaskProps } from './types';
8+
9+
type Props = TaskProps & { activeModel: string };
10+
11+
export default function OCRTask({
12+
canvasSize,
13+
cameraPositionSync,
14+
frameKillSwitch,
15+
onFrameOutputChange,
16+
onReadyChange,
17+
onProgressChange,
18+
onGeneratingChange,
19+
onFpsChange,
20+
}: Props) {
21+
const model = useOCR({ model: OCR_ENGLISH });
22+
const [detections, setDetections] = useState<OCRDetection[]>([]);
23+
const [imageSize, setImageSize] = useState({ width: 1, height: 1 });
24+
const lastFrameTimeRef = useRef(Date.now());
25+
26+
useEffect(() => {
27+
onReadyChange(model.isReady);
28+
}, [model.isReady, onReadyChange]);
29+
30+
useEffect(() => {
31+
onProgressChange(model.downloadProgress);
32+
}, [model.downloadProgress, onProgressChange]);
33+
34+
useEffect(() => {
35+
onGeneratingChange(model.isGenerating);
36+
}, [model.isGenerating, onGeneratingChange]);
37+
38+
const ocrRof = model.runOnFrame;
39+
40+
const updateDetections = useCallback(
41+
(p: { results: OCRDetection[]; frameW: number; frameH: number }) => {
42+
setDetections(p.results);
43+
setImageSize({ width: p.frameW, height: p.frameH });
44+
const now = Date.now();
45+
const diff = now - lastFrameTimeRef.current;
46+
if (diff > 0) onFpsChange(Math.round(1000 / diff), diff);
47+
lastFrameTimeRef.current = now;
48+
},
49+
[onFpsChange]
50+
);
51+
52+
const frameOutput = useFrameOutput({
53+
pixelFormat: 'rgb',
54+
dropFramesWhileBusy: true,
55+
enablePreviewSizedOutputBuffers: true,
56+
onFrame: useCallback(
57+
(frame: Frame) => {
58+
'worklet';
59+
if (frameKillSwitch.getDirty()) {
60+
frame.dispose();
61+
return;
62+
}
63+
try {
64+
if (!ocrRof) return;
65+
const isMirrored = cameraPositionSync.getDirty() === 'front';
66+
const result = ocrRof(frame, isMirrored);
67+
if (result) {
68+
scheduleOnRN(updateDetections, {
69+
results: result,
70+
frameW: frame.height,
71+
frameH: frame.width,
72+
});
73+
}
74+
} catch {
75+
// ignore
76+
} finally {
77+
frame.dispose();
78+
}
79+
},
80+
[cameraPositionSync, frameKillSwitch, ocrRof, updateDetections]
81+
),
82+
});
83+
84+
useEffect(() => {
85+
onFrameOutputChange(frameOutput);
86+
}, [frameOutput, onFrameOutputChange]);
87+
88+
const scale = Math.max(
89+
canvasSize.width / imageSize.width,
90+
canvasSize.height / imageSize.height
91+
);
92+
const offsetX = (canvasSize.width - imageSize.width * scale) / 2;
93+
const offsetY = (canvasSize.height - imageSize.height * scale) / 2;
94+
95+
if (!detections.length) return null;
96+
97+
return (
98+
<View style={StyleSheet.absoluteFill} pointerEvents="none">
99+
<Svg
100+
width={canvasSize.width}
101+
height={canvasSize.height}
102+
style={StyleSheet.absoluteFill}
103+
>
104+
{detections.map((det, i) => {
105+
const pts = det.bbox
106+
.map((p) => `${p.x * scale + offsetX},${p.y * scale + offsetY}`)
107+
.join(' ');
108+
const labelX = det.bbox[0]!.x * scale + offsetX;
109+
const labelY = det.bbox[0]!.y * scale + offsetY - 4;
110+
return (
111+
<React.Fragment key={i}>
112+
<Polygon
113+
points={pts}
114+
fill="none"
115+
stroke="cyan"
116+
strokeWidth={2}
117+
/>
118+
<SvgText
119+
x={labelX}
120+
y={labelY}
121+
fill="white"
122+
fontSize={12}
123+
fontWeight="bold"
124+
>
125+
{det.text}
126+
</SvgText>
127+
</React.Fragment>
128+
);
129+
})}
130+
</Svg>
131+
</View>
132+
);
133+
}

apps/computer-vision/components/vision_camera/tasks/SegmentationTask.tsx

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,17 @@ export default function SegmentationTask({
153153
const result = segRof(frame, isMirrored, [], false);
154154
if (result?.ARGMAX) {
155155
const argmax: Int32Array = result.ARGMAX;
156-
const side = Math.round(Math.sqrt(argmax.length));
157-
const pixels = new Uint8Array(side * side * 4);
156+
const screenW = frame.height;
157+
const screenH = frame.width;
158+
const maskW =
159+
argmax.length === screenW * screenH
160+
? screenW
161+
: Math.round(Math.sqrt(argmax.length));
162+
const maskH =
163+
argmax.length === screenW * screenH
164+
? screenH
165+
: Math.round(Math.sqrt(argmax.length));
166+
const pixels = new Uint8Array(maskW * maskH * 4);
158167
for (let i = 0; i < argmax.length; i++) {
159168
const color = colors[argmax[i]!] ?? [0, 0, 0, 0];
160169
pixels[i * 4] = color[0]!;
@@ -165,13 +174,13 @@ export default function SegmentationTask({
165174
const skData = Skia.Data.fromBytes(pixels);
166175
const img = Skia.Image.MakeImage(
167176
{
168-
width: side,
169-
height: side,
177+
width: maskW,
178+
height: maskH,
170179
alphaType: AlphaType.Unpremul,
171180
colorType: ColorType.RGBA_8888,
172181
},
173182
skData,
174-
side * 4
183+
maskW * 4
175184
);
176185
if (img) scheduleOnRN(updateMask, img);
177186
}

0 commit comments

Comments
 (0)