-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathClassificationTask.tsx
More file actions
124 lines (113 loc) · 3.52 KB
/
ClassificationTask.tsx
File metadata and controls
124 lines (113 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import React, { useCallback, useEffect, useRef, useState } from 'react';
import { StyleSheet, Text, View } from 'react-native';
import { Frame, useFrameOutput } from 'react-native-vision-camera';
import { scheduleOnRN } from 'react-native-worklets';
import { EFFICIENTNET_V2_S, useClassification } from 'react-native-executorch';
import { TaskProps } from './types';
/** Keys of the shared `TaskProps` that the classification task does not take. */
type UnusedTaskPropKeys = 'activeModel' | 'canvasSize' | 'cameraPositionSync';

/** Props for `ClassificationTask`: the shared task props minus the keys listed above. */
type Props = Omit<TaskProps, UnusedTaskPropKeys>;
/**
 * Live image-classification overlay.
 *
 * Loads EfficientNet V2 S through `useClassification` and mirrors the model's
 * `isReady`, `downloadProgress` and `isGenerating` values to the parent
 * callbacks. It also publishes an RGB frame output (dropping frames while busy)
 * via `onFrameOutputChange`.
 *
 * For each frame, the worklet runs the model, keeps the highest-scoring label,
 * and schedules it onto the React Native thread. That label is rendered as a
 * centered name plus percentage. Nothing is rendered until a non-empty label
 * has been received.
 */
export default function ClassificationTask({
  frameKillSwitch,
  onFrameOutputChange,
  onReadyChange,
  onProgressChange,
  onGeneratingChange,
  onFpsChange,
}: Props) {
  const { isReady, downloadProgress, isGenerating, runOnFrame } =
    useClassification({ model: EFFICIENTNET_V2_S });
  const [classResult, setClassResult] = useState({ label: '', score: 0 });
  // Wall-clock time of the previous result. It is seeded at mount, so the first
  // FPS sample spans from mount to the first classified frame.
  const prevResultAtRef = useRef(Date.now());

  useEffect(() => {
    onReadyChange(isReady);
  }, [isReady, onReadyChange]);

  useEffect(() => {
    onProgressChange(downloadProgress);
  }, [downloadProgress, onProgressChange]);

  useEffect(() => {
    onGeneratingChange(isGenerating);
  }, [isGenerating, onGeneratingChange]);

  // Invoked on the RN thread through scheduleOnRN. It stores the winning class
  // and reports the interval since the previous result as (fps, milliseconds).
  const reportTopClass = useCallback(
    (top: { label: string; score: number }) => {
      setClassResult(top);
      const now = Date.now();
      const elapsedMs = now - prevResultAtRef.current;
      if (elapsedMs > 0) {
        onFpsChange(Math.round(1000 / elapsedMs), elapsedMs);
      }
      prevResultAtRef.current = now;
    },
    [onFpsChange]
  );

  const handleFrame = useCallback(
    (frame: Frame) => {
      'worklet';
      // When the kill switch is dirty, release the frame without running the model.
      if (frameKillSwitch.getDirty()) {
        frame.dispose();
        return;
      }
      try {
        // NOTE(review): presumably runOnFrame is unset until the model has loaded.
        if (!runOnFrame) return;
        const scores = runOnFrame(frame);
        if (!scores) return;
        // Argmax over the result, presumably a label -> score map (TODO confirm).
        // The strict `>` keeps the first label on ties; an empty map yields '' / -1.
        const pairs = Object.entries(scores);
        let topLabel = '';
        let topScore = -1;
        for (let idx = 0; idx < pairs.length; idx += 1) {
          const pair = pairs[idx]!;
          const value = pair[1] as number;
          if (value > topScore) {
            topScore = value;
            topLabel = pair[0];
          }
        }
        scheduleOnRN(reportTopClass, { label: topLabel, score: topScore });
      } catch {
        // Frame may be disposed before processing completes — transient, safe to ignore.
      } finally {
        frame.dispose();
      }
    },
    [runOnFrame, frameKillSwitch, reportTopClass]
  );

  const frameOutput = useFrameOutput({
    pixelFormat: 'rgb',
    dropFramesWhileBusy: true,
    enablePreviewSizedOutputBuffers: true,
    onFrame: handleFrame,
  });

  useEffect(() => {
    onFrameOutputChange(frameOutput);
  }, [frameOutput, onFrameOutputChange]);

  if (!classResult.label) {
    return null;
  }

  // NOTE(review): assumes score is a 0..1 probability; it is displayed as a percentage.
  const percent = (classResult.score * 100).toFixed(1);
  return (
    <View style={styles.overlay} pointerEvents="none">
      <Text style={styles.label}>{classResult.label}</Text>
      <Text style={styles.score}>{percent}%</Text>
    </View>
  );
}
/** Text shadow shared by the `label` and `score` styles below. */
const overlayTextShadow = {
  textShadowColor: 'rgba(0,0,0,0.8)',
  textShadowOffset: { width: 0, height: 1 },
  textShadowRadius: 6,
};

/** Styles for the centered classification overlay. */
const styles = StyleSheet.create({
  // Fills the parent and centers its children on both axes.
  overlay: {
    ...StyleSheet.absoluteFillObject,
    justifyContent: 'center',
    alignItems: 'center',
  },
  // Predicted class name.
  label: {
    color: 'white',
    fontSize: 28,
    fontWeight: '700',
    textAlign: 'center',
    ...overlayTextShadow,
    paddingHorizontal: 24,
  },
  // Confidence percentage shown beneath the label.
  score: {
    color: 'rgba(255,255,255,0.75)',
    fontSize: 18,
    fontWeight: '500',
    marginTop: 4,
    ...overlayTextShadow,
  },
});