-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathindex.tsx
More file actions
125 lines (116 loc) · 3.46 KB
/
index.tsx
File metadata and controls
125 lines (116 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import Spinner from '../../components/Spinner';
import { BottomBar } from '../../components/BottomBar';
import { ModelPicker, ModelOption } from '../../components/ModelPicker';
import { getImage } from '../../utils';
import {
Detection,
useObjectDetection,
RF_DETR_NANO,
SSDLITE_320_MOBILENET_V3_LARGE,
ObjectDetectionModelSources,
} from 'react-native-executorch';
import { View, StyleSheet, Image } from 'react-native';
import ImageWithBboxes from '../../components/ImageWithBboxes';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
/**
 * Options offered by the model picker on this screen. Each entry pairs a
 * human-readable label with the model source constant handed to
 * `useObjectDetection`.
 */
const MODELS: ModelOption<ObjectDetectionModelSources>[] = [
  {
    label: 'RF-DeTR Nano',
    value: RF_DETR_NANO,
  },
  {
    label: 'SSDLite MobileNet',
    value: SSDLITE_320_MOBILENET_V3_LARGE,
  },
];
/**
 * Screen that lets the user pick an image, choose an object-detection model,
 * run the model on the image and hand the resulting detections to
 * `ImageWithBboxes` for display.
 *
 * While the model is not ready, only a `Spinner` with the download progress
 * is rendered. The model's `isGenerating` flag is mirrored into
 * `GeneratingContext` so other parts of the app can observe it.
 */
export default function ObjectDetectionScreen() {
  const [imageUri, setImageUri] = useState('');
  const [results, setResults] = useState<Detection[]>([]);
  const [imageDimensions, setImageDimensions] = useState<{
    width: number;
    height: number;
  }>();
  const [selectedModel, setSelectedModel] =
    useState<ObjectDetectionModelSources>(RF_DETR_NANO);

  const model = useObjectDetection({ model: selectedModel });
  const { setGlobalGenerating } = useContext(GeneratingContext);

  // Keep the app-wide "generating" flag in sync with this screen's model.
  useEffect(() => {
    setGlobalGenerating(model.isGenerating);
  }, [model.isGenerating, setGlobalGenerating]);

  /**
   * Obtains an image through `getImage` and, when it carries a URI and
   * non-zero dimensions, stores it and clears detections from the previous
   * image.
   *
   * @param isCamera Forwarded unchanged to `getImage`; presumably selects the
   *   camera over the photo library (confirm against `getImage`).
   */
  const handleCameraPress = async (isCamera: boolean) => {
    const image = await getImage(isCamera);
    const uri = image?.uri;
    const width = image?.width;
    const height = image?.height;
    // The truthiness guard narrows all three locals, so they can be used
    // directly without type assertions.
    if (uri && width && height) {
      setImageUri(uri);
      setImageDimensions({ width, height });
      setResults([]);
    }
  };

  /**
   * Runs the selected model on the current image and stores its output.
   * Does nothing when no image has been chosen; failures are logged and the
   * previous results are left in place.
   */
  const runForward = async () => {
    if (!imageUri) {
      return;
    }
    try {
      const output = await model.forward(imageUri);
      setResults(output);
    } catch (e) {
      console.error(e);
    }
  };

  // All hooks have been called above, so returning early here is safe.
  if (!model.isReady) {
    return (
      <Spinner
        visible={!model.isReady}
        textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`}
      />
    );
  }

  return (
    <ScreenWrapper>
      <View style={styles.imageContainer}>
        <View style={styles.image}>
          {/* Show the chosen image once its URI and size are known; otherwise show the placeholder logo. */}
          {imageUri && imageDimensions?.width && imageDimensions?.height ? (
            <ImageWithBboxes
              imageUri={imageUri}
              imageWidth={imageDimensions.width}
              imageHeight={imageDimensions.height}
              detections={results}
            />
          ) : (
            <Image
              style={styles.fullSizeImage}
              resizeMode="contain"
              source={require('../../assets/icons/executorch_logo.png')}
            />
          )}
        </View>
      </View>
      <ModelPicker
        models={MODELS}
        selectedModel={selectedModel}
        onSelect={(m) => {
          // Detections from the previous model no longer apply.
          setSelectedModel(m);
          setResults([]);
        }}
      />
      <BottomBar
        handleCameraPress={handleCameraPress}
        runForward={runForward}
      />
    </ScreenWrapper>
  );
}
/** Layout styles for the object-detection screen. */
const styles = StyleSheet.create({
  // Outer padded region that holds the image area.
  imageContainer: {
    padding: 16,
    width: '100%',
    flex: 6,
  },
  // Wrapper around either the selected image or the placeholder logo.
  image: {
    width: '100%',
    borderRadius: 8,
    flex: 2,
  },
  // Makes the placeholder logo fill its wrapper.
  fullSizeImage: {
    height: '100%',
    width: '100%',
  },
});