Skip to content

Commit 50e868e

Browse files
committed
chore: add model selector where applicable
1 parent 621617a commit 50e868e

File tree

4 files changed

+212
-6
lines changed

4 files changed

+212
-6
lines changed

apps/computer-vision/app/classification/index.tsx

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,16 @@ import Spinner from '../../components/Spinner';
22
import { getImage } from '../../utils';
33
import {
44
useClassification,
5+
EFFICIENTNET_V2_S,
56
EFFICIENTNET_V2_S_QUANTIZED,
7+
ClassificationModelSources,
68
} from 'react-native-executorch';
9+
import { ModelPicker, ModelOption } from '../../components/ModelPicker';
10+
11+
const MODELS: ModelOption<ClassificationModelSources>[] = [
12+
{ label: 'EfficientNet V2 S Quantized', value: EFFICIENTNET_V2_S_QUANTIZED },
13+
{ label: 'EfficientNet V2 S', value: EFFICIENTNET_V2_S },
14+
];
715
import { View, StyleSheet, Image, Text, ScrollView } from 'react-native';
816
import { BottomBar } from '../../components/BottomBar';
917
import React, { useContext, useEffect, useState } from 'react';
@@ -13,6 +21,8 @@ import { StatsBar } from '../../components/StatsBar';
1321
import ErrorBanner from '../../components/ErrorBanner';
1422

1523
export default function ClassificationScreen() {
24+
const [selectedModel, setSelectedModel] =
25+
useState<ClassificationModelSources>(EFFICIENTNET_V2_S_QUANTIZED);
1626
const [results, setResults] = useState<{ label: string; score: number }[]>(
1727
[]
1828
);
@@ -21,7 +31,7 @@ export default function ClassificationScreen() {
2131

2232
const [error, setError] = useState<string | null>(null);
2333

24-
const model = useClassification({ model: EFFICIENTNET_V2_S_QUANTIZED });
34+
const model = useClassification({ model: selectedModel });
2535
const { setGlobalGenerating } = useContext(GeneratingContext);
2636

2737
useEffect(() => {
@@ -106,6 +116,15 @@ export default function ClassificationScreen() {
106116
</View>
107117
)}
108118
</View>
119+
<ModelPicker
120+
models={MODELS}
121+
selectedModel={selectedModel}
122+
disabled={model.isGenerating}
123+
onSelect={(m) => {
124+
setSelectedModel(m);
125+
setResults([]);
126+
}}
127+
/>
109128
<StatsBar inferenceTime={inferenceTime} />
110129
<BottomBar
111130
handleCameraPress={handleCameraPress}

apps/text-embeddings/app/clip-embeddings/index.tsx

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,21 @@ import {
1616
useTextEmbeddings,
1717
useImageEmbeddings,
1818
CLIP_VIT_BASE_PATCH32_TEXT,
19+
CLIP_VIT_BASE_PATCH32_IMAGE,
1920
CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED,
21+
ImageEmbeddingsProps,
2022
} from 'react-native-executorch';
23+
24+
type ImageEmbeddingModel = ImageEmbeddingsProps['model'];
25+
26+
const IMAGE_MODELS: { label: string; value: ImageEmbeddingModel }[] = [
27+
{ label: 'ViT-B/32 Quantized', value: CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED },
28+
{ label: 'ViT-B/32 FP32', value: CLIP_VIT_BASE_PATCH32_IMAGE },
29+
];
2130
import { launchImageLibrary } from 'react-native-image-picker';
2231
import { useIsFocused } from '@react-navigation/native';
2332
import { dotProduct } from '../../utils/math';
33+
import { ModelPicker } from '../../components/ModelPicker';
2434

2535
const DEFAULT_LABELS = [
2636
'a photo of a dog',
@@ -37,10 +47,11 @@ export default function ClipEmbeddingsScreenWrapper() {
3747
}
3848

3949
function ClipEmbeddingsScreen() {
50+
const [selectedImageModel, setSelectedImageModel] =
51+
useState<ImageEmbeddingModel>(CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED);
52+
4053
const textModel = useTextEmbeddings({ model: CLIP_VIT_BASE_PATCH32_TEXT });
41-
const imageModel = useImageEmbeddings({
42-
model: CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED,
43-
});
54+
const imageModel = useImageEmbeddings({ model: selectedImageModel });
4455

4556
const [imageUri, setImageUri] = useState<string | null>(null);
4657
const [newLabel, setNewLabel] = useState('');
@@ -131,6 +142,15 @@ function ClipEmbeddingsScreen() {
131142
</Text>
132143
</View>
133144

145+
<ModelPicker
146+
models={IMAGE_MODELS}
147+
selectedModel={selectedImageModel}
148+
onSelect={(m) => {
149+
setSelectedImageModel(m);
150+
setResults([]);
151+
}}
152+
/>
153+
134154
{/* Image picker */}
135155
<TouchableOpacity style={styles.imagePicker} onPress={pickImage}>
136156
{imageUri ? (

apps/text-embeddings/app/text-embeddings/index.tsx

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,24 @@ import {
1111
Platform,
1212
} from 'react-native';
1313
import { Ionicons } from '@expo/vector-icons';
14-
import { useTextEmbeddings, ALL_MINILM_L6_V2 } from 'react-native-executorch';
14+
import { ModelPicker } from '../../components/ModelPicker';
15+
import {
16+
useTextEmbeddings,
17+
ALL_MINILM_L6_V2,
18+
ALL_MPNET_BASE_V2,
19+
MULTI_QA_MINILM_L6_COS_V1,
20+
MULTI_QA_MPNET_BASE_DOT_V1,
21+
TextEmbeddingsProps,
22+
} from 'react-native-executorch';
23+
24+
type TextEmbeddingModel = TextEmbeddingsProps['model'];
25+
26+
const MODELS: { label: string; value: TextEmbeddingModel }[] = [
27+
{ label: 'MiniLM L6', value: ALL_MINILM_L6_V2 },
28+
{ label: 'MPNet Base', value: ALL_MPNET_BASE_V2 },
29+
{ label: 'MultiQA MiniLM', value: MULTI_QA_MINILM_L6_COS_V1 },
30+
{ label: 'MultiQA MPNet', value: MULTI_QA_MPNET_BASE_DOT_V1 },
31+
];
1532
import { useIsFocused } from '@react-navigation/native';
1633
import { dotProduct } from '../../utils/math';
1734
import ErrorBanner from '../../components/ErrorBanner';
@@ -23,7 +40,9 @@ export default function TextEmbeddingsScreenWrapper() {
2340
}
2441

2542
function TextEmbeddingsScreen() {
26-
const model = useTextEmbeddings({ model: ALL_MINILM_L6_V2 });
43+
const [selectedModel, setSelectedModel] =
44+
useState<TextEmbeddingModel>(ALL_MINILM_L6_V2);
45+
const model = useTextEmbeddings({ model: selectedModel });
2746
const [error, setError] = useState<string | null>(null);
2847

2948
const [inputSentence, setInputSentence] = useState('');
@@ -132,6 +151,15 @@ function TextEmbeddingsScreen() {
132151
<ScrollView contentContainerStyle={styles.scrollContainer}>
133152
<Text style={styles.heading}>Text Embeddings Playground</Text>
134153
<Text style={styles.sectionTitle}>{getModelStatusText()}</Text>
154+
<ModelPicker
155+
models={MODELS}
156+
selectedModel={selectedModel}
157+
onSelect={(m) => {
158+
setSelectedModel(m);
159+
setSentencesWithEmbeddings([]);
160+
setTopMatches([]);
161+
}}
162+
/>
135163
<ErrorBanner message={error} onDismiss={() => setError(null)} />
136164

137165
<View style={styles.card}>
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import React, { useEffect, useState } from 'react';
2+
import {
3+
View,
4+
StyleSheet,
5+
Text,
6+
TouchableOpacity,
7+
ScrollView,
8+
} from 'react-native';
9+
10+
export type ModelOption<T> = {
11+
label: string;
12+
value: T;
13+
};
14+
15+
type Props<T> = {
16+
models: ModelOption<T>[];
17+
selectedModel: T;
18+
onSelect: (model: T) => void;
19+
label?: string;
20+
disabled?: boolean;
21+
};
22+
23+
export function ModelPicker<T>({
24+
models,
25+
selectedModel,
26+
onSelect,
27+
label,
28+
disabled,
29+
}: Props<T>) {
30+
const [open, setOpen] = useState(false);
31+
const selected = models.find((m) => m.value === selectedModel);
32+
33+
useEffect(() => {
34+
if (disabled) setOpen(false);
35+
}, [disabled]);
36+
37+
return (
38+
<View style={styles.container}>
39+
<TouchableOpacity
40+
style={[styles.trigger, disabled && styles.triggerDisabled]}
41+
onPress={() => !disabled && setOpen((v) => !v)}
42+
activeOpacity={disabled ? 1 : 0.7}
43+
>
44+
{label && <Text style={styles.label}>{label}</Text>}
45+
<Text style={styles.triggerText}>{selected?.label ?? '—'}</Text>
46+
<Text style={styles.chevron}>{open ? '▲' : '▼'}</Text>
47+
</TouchableOpacity>
48+
49+
{open && (
50+
<ScrollView
51+
style={styles.dropdown}
52+
nestedScrollEnabled
53+
keyboardShouldPersistTaps="handled"
54+
>
55+
{models.map((item) => {
56+
const isSelected = item.value === selectedModel;
57+
return (
58+
<TouchableOpacity
59+
key={item.label}
60+
style={[styles.option, isSelected && styles.optionSelected]}
61+
onPress={() => {
62+
onSelect(item.value);
63+
setOpen(false);
64+
}}
65+
>
66+
<Text
67+
style={[
68+
styles.optionText,
69+
isSelected && styles.optionTextSelected,
70+
]}
71+
>
72+
{item.label}
73+
</Text>
74+
</TouchableOpacity>
75+
);
76+
})}
77+
</ScrollView>
78+
)}
79+
</View>
80+
);
81+
}
82+
83+
const styles = StyleSheet.create({
84+
container: { marginHorizontal: 12, marginVertical: 4, alignSelf: 'stretch' },
85+
trigger: {
86+
flexDirection: 'row',
87+
alignItems: 'center',
88+
borderWidth: 1,
89+
borderColor: '#C1C6E5',
90+
borderRadius: 8,
91+
paddingHorizontal: 12,
92+
paddingVertical: 10,
93+
backgroundColor: '#f5f5f5',
94+
},
95+
triggerDisabled: {
96+
opacity: 0.4,
97+
},
98+
label: {
99+
fontSize: 12,
100+
color: '#888',
101+
marginRight: 6,
102+
},
103+
triggerText: {
104+
flex: 1,
105+
fontSize: 14,
106+
color: '#001A72',
107+
fontWeight: '500',
108+
},
109+
chevron: {
110+
fontSize: 10,
111+
color: '#888',
112+
marginLeft: 6,
113+
},
114+
dropdown: {
115+
borderWidth: 1,
116+
borderColor: '#C1C6E5',
117+
borderRadius: 8,
118+
backgroundColor: '#fff',
119+
maxHeight: 200,
120+
marginTop: 2,
121+
},
122+
option: {
123+
paddingHorizontal: 12,
124+
paddingVertical: 10,
125+
borderBottomWidth: 1,
126+
borderBottomColor: '#f0f0f0',
127+
},
128+
optionSelected: {
129+
backgroundColor: '#e8ecf8',
130+
},
131+
optionText: {
132+
fontSize: 14,
133+
color: '#333',
134+
},
135+
optionTextSelected: {
136+
color: '#001A72',
137+
fontWeight: '600',
138+
},
139+
});

0 commit comments

Comments
 (0)