Skip to content

Commit 653cde9

Browse files
refactor: add SegmentationTask component
1 parent 88edb62 commit 653cde9

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
// apps/computer-vision/app/vision_camera/tasks/SegmentationTask.tsx
2+
import React, { useCallback, useEffect, useRef, useState } from 'react';
3+
import { StyleSheet, View } from 'react-native';
4+
import { Frame, useFrameOutput } from 'react-native-vision-camera';
5+
import { scheduleOnRN } from 'react-native-worklets';
6+
import {
7+
DEEPLAB_V3_RESNET50_QUANTIZED,
8+
DEEPLAB_V3_RESNET101_QUANTIZED,
9+
DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED,
10+
FCN_RESNET50_QUANTIZED,
11+
FCN_RESNET101_QUANTIZED,
12+
LRASPP_MOBILENET_V3_LARGE_QUANTIZED,
13+
SELFIE_SEGMENTATION,
14+
useSemanticSegmentation,
15+
} from 'react-native-executorch';
16+
import {
17+
AlphaType,
18+
Canvas,
19+
ColorType,
20+
Image as SkiaImage,
21+
Skia,
22+
SkImage,
23+
} from '@shopify/react-native-skia';
24+
import { CLASS_COLORS } from '../utils/colors';
25+
import { TaskProps } from './types';
26+
27+
type SegModelId =
28+
| 'segmentation_deeplab_resnet50'
29+
| 'segmentation_deeplab_resnet101'
30+
| 'segmentation_deeplab_mobilenet'
31+
| 'segmentation_lraspp'
32+
| 'segmentation_fcn_resnet50'
33+
| 'segmentation_fcn_resnet101'
34+
| 'segmentation_selfie';
35+
36+
type Props = TaskProps & { activeModel: SegModelId };
37+
38+
export default function SegmentationTask({
39+
activeModel,
40+
canvasSize,
41+
cameraPosition,
42+
frameKillSwitch,
43+
onFrameOutputChange,
44+
onReadyChange,
45+
onProgressChange,
46+
onGeneratingChange,
47+
onFpsChange,
48+
}: Props) {
49+
const segDeeplabResnet50 = useSemanticSegmentation({
50+
model: DEEPLAB_V3_RESNET50_QUANTIZED,
51+
preventLoad: activeModel !== 'segmentation_deeplab_resnet50',
52+
});
53+
const segDeeplabResnet101 = useSemanticSegmentation({
54+
model: DEEPLAB_V3_RESNET101_QUANTIZED,
55+
preventLoad: activeModel !== 'segmentation_deeplab_resnet101',
56+
});
57+
const segDeeplabMobilenet = useSemanticSegmentation({
58+
model: DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED,
59+
preventLoad: activeModel !== 'segmentation_deeplab_mobilenet',
60+
});
61+
const segLraspp = useSemanticSegmentation({
62+
model: LRASPP_MOBILENET_V3_LARGE_QUANTIZED,
63+
preventLoad: activeModel !== 'segmentation_lraspp',
64+
});
65+
const segFcnResnet50 = useSemanticSegmentation({
66+
model: FCN_RESNET50_QUANTIZED,
67+
preventLoad: activeModel !== 'segmentation_fcn_resnet50',
68+
});
69+
const segFcnResnet101 = useSemanticSegmentation({
70+
model: FCN_RESNET101_QUANTIZED,
71+
preventLoad: activeModel !== 'segmentation_fcn_resnet101',
72+
});
73+
const segSelfie = useSemanticSegmentation({
74+
model: SELFIE_SEGMENTATION,
75+
preventLoad: activeModel !== 'segmentation_selfie',
76+
});
77+
78+
const active = {
79+
segmentation_deeplab_resnet50: segDeeplabResnet50,
80+
segmentation_deeplab_resnet101: segDeeplabResnet101,
81+
segmentation_deeplab_mobilenet: segDeeplabMobilenet,
82+
segmentation_lraspp: segLraspp,
83+
segmentation_fcn_resnet50: segFcnResnet50,
84+
segmentation_fcn_resnet101: segFcnResnet101,
85+
segmentation_selfie: segSelfie,
86+
}[activeModel];
87+
88+
const [maskImage, setMaskImage] = useState<SkImage | null>(null);
89+
const lastFrameTimeRef = useRef(Date.now());
90+
91+
useEffect(() => {
92+
onReadyChange(active.isReady);
93+
}, [active.isReady, onReadyChange]);
94+
95+
useEffect(() => {
96+
onProgressChange(active.downloadProgress);
97+
}, [active.downloadProgress, onProgressChange]);
98+
99+
useEffect(() => {
100+
onGeneratingChange(active.isGenerating);
101+
}, [active.isGenerating, onGeneratingChange]);
102+
103+
// Clear stale mask when the segmentation model variant changes
104+
useEffect(() => {
105+
setMaskImage((prev) => {
106+
prev?.dispose();
107+
return null;
108+
});
109+
}, [activeModel]);
110+
111+
// Dispose native Skia image on unmount to prevent memory leaks
112+
useEffect(() => {
113+
return () => {
114+
setMaskImage((prev) => {
115+
prev?.dispose();
116+
return null;
117+
});
118+
};
119+
}, []);
120+
121+
const segRof = active.runOnFrame;
122+
123+
const updateMask = useCallback(
124+
(img: SkImage) => {
125+
setMaskImage((prev) => {
126+
prev?.dispose();
127+
return img;
128+
});
129+
const now = Date.now();
130+
const diff = now - lastFrameTimeRef.current;
131+
if (diff > 0) onFpsChange(Math.round(1000 / diff), diff);
132+
lastFrameTimeRef.current = now;
133+
},
134+
[onFpsChange]
135+
);
136+
137+
// CLASS_COLORS captured directly in closure — worklets cannot import modules
138+
const colors = CLASS_COLORS;
139+
140+
const frameOutput = useFrameOutput({
141+
pixelFormat: 'rgb',
142+
dropFramesWhileBusy: true,
143+
onFrame: useCallback(
144+
(frame: Frame) => {
145+
'worklet';
146+
if (frameKillSwitch.getDirty()) {
147+
frame.dispose();
148+
return;
149+
}
150+
try {
151+
if (!segRof) return;
152+
const result = segRof(frame, [], false);
153+
if (result?.ARGMAX) {
154+
const argmax: Int32Array = result.ARGMAX;
155+
const side = Math.round(Math.sqrt(argmax.length));
156+
const pixels = new Uint8Array(side * side * 4);
157+
for (let i = 0; i < argmax.length; i++) {
158+
const color = colors[argmax[i]!] ?? [0, 0, 0, 0];
159+
pixels[i * 4] = color[0]!;
160+
pixels[i * 4 + 1] = color[1]!;
161+
pixels[i * 4 + 2] = color[2]!;
162+
pixels[i * 4 + 3] = color[3]!;
163+
}
164+
const skData = Skia.Data.fromBytes(pixels);
165+
const img = Skia.Image.MakeImage(
166+
{
167+
width: side,
168+
height: side,
169+
alphaType: AlphaType.Unpremul,
170+
colorType: ColorType.RGBA_8888,
171+
},
172+
skData,
173+
side * 4
174+
);
175+
if (img) scheduleOnRN(updateMask, img);
176+
}
177+
} catch {
178+
// ignore
179+
} finally {
180+
frame.dispose();
181+
}
182+
},
183+
[colors, frameKillSwitch, segRof, updateMask]
184+
),
185+
});
186+
187+
useEffect(() => {
188+
onFrameOutputChange(frameOutput);
189+
}, [frameOutput, onFrameOutputChange]);
190+
191+
if (!maskImage) return null;
192+
193+
return (
194+
<View
195+
style={[
196+
StyleSheet.absoluteFill,
197+
cameraPosition === 'front' && { transform: [{ scaleX: -1 }] },
198+
]}
199+
pointerEvents="none"
200+
>
201+
<Canvas style={StyleSheet.absoluteFill}>
202+
<SkiaImage
203+
image={maskImage}
204+
fit="cover"
205+
x={0}
206+
y={0}
207+
width={canvasSize.width}
208+
height={canvasSize.height}
209+
/>
210+
</Canvas>
211+
</View>
212+
);
213+
}

0 commit comments

Comments
 (0)