-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathindex.tsx
More file actions
116 lines (105 loc) · 3.38 KB
/
index.tsx
File metadata and controls
116 lines (105 loc) · 3.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import Spinner from '../../components/Spinner';
import { BottomBar } from '../../components/BottomBar';
import { ModelPicker, ModelOption } from '../../components/ModelPicker';
import { getImage } from '../../utils';
import {
useStyleTransfer,
STYLE_TRANSFER_CANDY_QUANTIZED,
STYLE_TRANSFER_MOSAIC_QUANTIZED,
STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED,
STYLE_TRANSFER_UDNIE_QUANTIZED,
StyleTransferModelName,
ResourceSource,
} from 'react-native-executorch';
import { View, StyleSheet, Image } from 'react-native';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';
/**
 * Configuration of one selectable style-transfer model, as stored in this
 * screen's `selectedModel` state and passed to `useStyleTransfer`.
 */
type StyleTransferModelSources = {
/** Name of the model, typed by the library's `StyleTransferModelName`. */
modelName: StyleTransferModelName;
/** Source of the model resource, typed by the library's `ResourceSource`. */
modelSource: ResourceSource;
};
/**
 * Entries handed to the `ModelPicker` on this screen: each pairs a label with
 * one of the quantized style-transfer presets imported from
 * `react-native-executorch`.
 */
const MODELS: ModelOption<StyleTransferModelSources>[] = [
{ label: 'Candy', value: STYLE_TRANSFER_CANDY_QUANTIZED },
{ label: 'Mosaic', value: STYLE_TRANSFER_MOSAIC_QUANTIZED },
{ label: 'Rain Princess', value: STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED },
{ label: 'Udnie', value: STYLE_TRANSFER_UDNIE_QUANTIZED },
];
/**
 * Screen that applies a style-transfer model to a user-selected image.
 *
 * Flow, as implemented below:
 * - the selected model configuration is fed to `useStyleTransfer`;
 * - until the hook reports `isReady`, only a `Spinner` with the download
 *   progress is rendered;
 * - `BottomBar` triggers image selection (`handleCameraPress`) and inference
 *   (`runForward`);
 * - the styled image is shown when available, otherwise the picked image,
 *   otherwise the bundled ExecuTorch logo;
 * - the hook's `isGenerating` flag is mirrored into `GeneratingContext`.
 */
export default function StyleTransferScreen() {
  const [selectedModel, setSelectedModel] = useState<StyleTransferModelSources>(
    STYLE_TRANSFER_CANDY_QUANTIZED
  );
  const model = useStyleTransfer({ model: selectedModel });
  const { setGlobalGenerating } = useContext(GeneratingContext);

  // Keep the shared "generating" flag in sync with this screen's model.
  useEffect(() => {
    setGlobalGenerating(model.isGenerating);
  }, [model.isGenerating, setGlobalGenerating]);

  const [imageUri, setImageUri] = useState('');
  const [styledUri, setStyledUri] = useState('');
  const [inferenceTime, setInferenceTime] = useState<number | null>(null);

  /**
   * Discards the current styled output together with its timing.
   *
   * Both values describe the same inference run, so they are always cleared
   * as a pair; otherwise `StatsBar` would keep reporting the duration of a
   * result that is no longer on screen.
   */
  const clearResult = () => {
    setStyledUri('');
    setInferenceTime(null);
  };

  /**
   * Obtains an image through `getImage` and makes it the current input.
   *
   * @param isCamera Forwarded unchanged to `getImage`.
   *   NOTE(review): presumably selects camera capture vs. gallery — confirm
   *   against `utils`.
   */
  const handleCameraPress = async (isCamera: boolean) => {
    const image = await getImage(isCamera);
    const uri = image?.uri;
    // `getImage` may yield nothing usable; keep the current state then.
    if (typeof uri === 'string') {
      setImageUri(uri);
      clearResult();
    }
  };

  /**
   * Runs the model on the current image and stores the resulting URI along
   * with the wall-clock duration of the call in milliseconds. Does nothing
   * when no image has been picked; failures are logged and leave the
   * previous state untouched.
   */
  const runForward = async () => {
    if (!imageUri) {
      return;
    }
    try {
      const start = Date.now();
      const uri = await model.forward(imageUri, 'url');
      setInferenceTime(Date.now() - start);
      setStyledUri(uri);
    } catch (e) {
      console.error(e);
    }
  };

  if (!model.isReady) {
    return (
      <Spinner
        visible={!model.isReady}
        textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`}
      />
    );
  }

  // Preference order: styled output, then the picked image, then the logo.
  const displayedSource = styledUri
    ? { uri: styledUri }
    : imageUri
      ? { uri: imageUri }
      : require('../../assets/icons/executorch_logo.png');

  return (
    <ScreenWrapper>
      <View style={styles.imageContainer}>
        <Image
          style={styles.image}
          resizeMode="contain"
          source={displayedSource}
        />
      </View>
      <ModelPicker
        models={MODELS}
        selectedModel={selectedModel}
        disabled={model.isGenerating}
        onSelect={(m) => {
          setSelectedModel(m);
          // The previous output and its timing came from the old model.
          clearResult();
        }}
      />
      <StatsBar inferenceTime={inferenceTime} />
      <BottomBar
        handleCameraPress={handleCameraPress}
        runForward={runForward}
      />
    </ScreenWrapper>
  );
}
// Styles for the preview area: `imageContainer` is applied to the wrapping
// View and `image` to the Image rendered inside it.
const styles = StyleSheet.create({
  imageContainer: {
    flex: 6,
    width: '100%',
    padding: 16,
  },
  image: {
    flex: 1,
    borderRadius: 8,
    width: '100%',
  },
});