Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions apps/computer-vision/app/classification/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import { BottomBar } from '../../components/BottomBar';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';

export default function ClassificationScreen() {
const [results, setResults] = useState<{ label: string; score: number }[]>(
[]
);
const [imageUri, setImageUri] = useState('');
const [inferenceTime, setInferenceTime] = useState<number | null>(null);

const model = useClassification({ model: EFFICIENTNET_V2_S_QUANTIZED });
const { setGlobalGenerating } = useContext(GeneratingContext);
Expand All @@ -28,13 +30,16 @@ export default function ClassificationScreen() {
if (typeof uri === 'string') {
setImageUri(uri as string);
setResults([]);
setInferenceTime(null);
}
};

const runForward = async () => {
if (imageUri) {
try {
const start = Date.now();
const output = await model.forward(imageUri);
setInferenceTime(Date.now() - start);
const top10 = Object.entries(output)
.sort(([, a], [, b]) => (b as number) - (a as number))
.slice(0, 10)
Expand Down Expand Up @@ -80,6 +85,7 @@ export default function ClassificationScreen() {
</View>
)}
</View>
<StatsBar inferenceTime={inferenceTime} />
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
48 changes: 46 additions & 2 deletions apps/computer-vision/app/instance_segmentation/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import Spinner from '../../components/Spinner';
import { BottomBar } from '../../components/BottomBar';
import { getImage } from '../../utils';
import { useInstanceSegmentation, YOLO26N_SEG } from 'react-native-executorch';
import { ModelPicker, ModelOption } from '../../components/ModelPicker';
import {
useInstanceSegmentation,
YOLO26N_SEG,
YOLO26S_SEG,
YOLO26M_SEG,
YOLO26L_SEG,
YOLO26X_SEG,
RF_DETR_NANO_SEG,
InstanceSegmentationModelSources,
} from 'react-native-executorch';
import {
View,
StyleSheet,
Expand All @@ -16,8 +26,22 @@ import ImageWithMasks, {
buildDisplayInstances,
DisplayInstance,
} from '../../components/ImageWithMasks';
import { StatsBar } from '../../components/StatsBar';

/**
 * Instance segmentation models the user can choose from on this screen.
 * Each entry pairs a display label with a model source constant imported
 * from `react-native-executorch`; the list is handed to `ModelPicker` via
 * its `models` prop.
 */
const MODELS: ModelOption<InstanceSegmentationModelSources>[] = [
{ label: 'Yolo26N', value: YOLO26N_SEG },
{ label: 'Yolo26S', value: YOLO26S_SEG },
{ label: 'Yolo26M', value: YOLO26M_SEG },
{ label: 'Yolo26L', value: YOLO26L_SEG },
{ label: 'Yolo26X', value: YOLO26X_SEG },
{ label: 'RF-DeTR Nano', value: RF_DETR_NANO_SEG },
];

export default function InstanceSegmentationScreen() {
const [selectedModel, setSelectedModel] =
useState<InstanceSegmentationModelSources>(YOLO26N_SEG);
const [inferenceTime, setInferenceTime] = useState<number | null>(null);

const { setGlobalGenerating } = useContext(GeneratingContext);

const {
Expand All @@ -28,7 +52,7 @@ export default function InstanceSegmentationScreen() {
error,
getAvailableInputSizes,
} = useInstanceSegmentation({
model: YOLO26N_SEG,
model: selectedModel,
});

const [imageUri, setImageUri] = useState('');
Expand Down Expand Up @@ -60,12 +84,14 @@ export default function InstanceSegmentationScreen() {
height: image.height ?? 0,
});
setInstances([]);
setInferenceTime(null);
Comment thread
msluszniak marked this conversation as resolved.
};

const runForward = async () => {
if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return;

try {
const start = Date.now();
const output = await forward(imageUri, {
confidenceThreshold: 0.5,
iouThreshold: 0.55,
Expand All @@ -74,6 +100,8 @@ export default function InstanceSegmentationScreen() {
inputSize: selectedInputSize ?? undefined,
});

setInferenceTime(Date.now() - start);

// Convert raw masks → small Skia images immediately.
// Raw Uint8Array mask buffers (backed by native OwningArrayBuffer)
// go out of scope here and become eligible for GC right away.
Expand Down Expand Up @@ -168,6 +196,22 @@ export default function InstanceSegmentationScreen() {
)}
</View>

<ModelPicker
models={MODELS}
selectedModel={selectedModel}
disabled={isGenerating}
onSelect={(m) => {
setSelectedModel(m);
setInstances([]);
setInferenceTime(null);
}}
/>

<StatsBar
inferenceTime={inferenceTime}
detectionCount={instances.length > 0 ? instances.length : null}
/>

<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
9 changes: 9 additions & 0 deletions apps/computer-vision/app/object_detection/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ImageWithBboxes from '../../components/ImageWithBboxes';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';

const MODELS: ModelOption<ObjectDetectionModelSources>[] = [
{ label: 'RF-DeTR Nano', value: RF_DETR_NANO },
Expand All @@ -29,6 +30,7 @@ export default function ObjectDetectionScreen() {
}>();
const [selectedModel, setSelectedModel] =
useState<ObjectDetectionModelSources>(RF_DETR_NANO);
const [inferenceTime, setInferenceTime] = useState<number | null>(null);

const model = useObjectDetection({ model: selectedModel });
const { setGlobalGenerating } = useContext(GeneratingContext);
Expand All @@ -46,13 +48,16 @@ export default function ObjectDetectionScreen() {
setImageUri(image.uri as string);
setImageDimensions({ width: width as number, height: height as number });
setResults([]);
setInferenceTime(null);
}
};

const runForward = async () => {
if (imageUri) {
try {
const start = Date.now();
const output = await model.forward(imageUri);
setInferenceTime(Date.now() - start);
setResults(output);
} catch (e) {
console.error(e);
Expand Down Expand Up @@ -100,6 +105,10 @@ export default function ObjectDetectionScreen() {
setResults([]);
}}
/>
<StatsBar
inferenceTime={inferenceTime}
detectionCount={results.length > 0 ? results.length : null}
/>
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
9 changes: 9 additions & 0 deletions apps/computer-vision/app/ocr/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import ImageWithBboxes2 from '../../components/ImageWithOCRBboxes';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';

/** Type of the `model` property of `OCRProps`, named so it can be reused in this file. */
type OCRModelSources = OCRProps['model'];

Expand All @@ -40,6 +41,7 @@ export default function OCRScreen() {
}>();
const [selectedModel, setSelectedModel] =
useState<OCRModelSources>(OCR_ENGLISH);
const [inferenceTime, setInferenceTime] = useState<number | null>(null);

const model = useOCR({
model: selectedModel,
Expand All @@ -58,12 +60,15 @@ export default function OCRScreen() {
if (typeof uri === 'string') {
setImageUri(uri as string);
setResults([]);
setInferenceTime(null);
}
};

const runForward = async () => {
try {
const start = Date.now();
const output = await model.forward(imageUri);
setInferenceTime(Date.now() - start);
setResults(output);
} catch (e) {
console.error(e);
Expand Down Expand Up @@ -123,6 +128,10 @@ export default function OCRScreen() {
setResults([]);
}}
/>
<StatsBar
inferenceTime={inferenceTime}
detectionCount={results.length > 0 ? results.length : null}
/>
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
9 changes: 9 additions & 0 deletions apps/computer-vision/app/ocr_vertical/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ImageWithBboxes2 from '../../components/ImageWithOCRBboxes';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';

export default function VerticalOCRScree() {
const [imageUri, setImageUri] = useState('');
Expand All @@ -15,6 +16,7 @@ export default function VerticalOCRScree() {
width: number;
height: number;
}>();
const [inferenceTime, setInferenceTime] = useState<number | null>(null);
const model = useVerticalOCR({
model: OCR_ENGLISH,
independentCharacters: true,
Expand All @@ -33,12 +35,15 @@ export default function VerticalOCRScree() {
if (typeof uri === 'string') {
setImageUri(uri as string);
setResults([]);
setInferenceTime(null);
}
};

const runForward = async () => {
try {
const start = Date.now();
const output = await model.forward(imageUri);
setInferenceTime(Date.now() - start);
setResults(output);
} catch (e) {
console.error(e);
Expand Down Expand Up @@ -89,6 +94,10 @@ export default function VerticalOCRScree() {
</View>
)}
</View>
<StatsBar
inferenceTime={inferenceTime}
detectionCount={results.length > 0 ? results.length : null}
/>
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
6 changes: 6 additions & 0 deletions apps/computer-vision/app/semantic_segmentation/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { View, StyleSheet, Image } from 'react-native';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';

const numberToColor: number[][] = [
[255, 87, 51], // 0 Red
Expand Down Expand Up @@ -75,6 +76,7 @@ export default function SemanticSegmentationScreen() {
const [imageSize, setImageSize] = useState({ width: 0, height: 0 });
const [segImage, setSegImage] = useState<SkImage | null>(null);
const [canvasSize, setCanvasSize] = useState({ width: 0, height: 0 });
const [inferenceTime, setInferenceTime] = useState<number | null>(null);

useEffect(() => {
setGlobalGenerating(isGenerating);
Expand All @@ -86,11 +88,13 @@ export default function SemanticSegmentationScreen() {
setImageUri(image.uri);
setImageSize({ width: image.width ?? 0, height: image.height ?? 0 });
setSegImage(null);
setInferenceTime(null);
};

const runForward = async () => {
if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return;
try {
const start = Date.now();
const { width, height } = imageSize;
const output = await forward(imageUri, [], true);
const argmax = output.ARGMAX || [];
Expand Down Expand Up @@ -119,6 +123,7 @@ export default function SemanticSegmentationScreen() {
width * 4
);
setSegImage(img);
setInferenceTime(Date.now() - start);
} catch (e) {
console.error(e);
}
Expand Down Expand Up @@ -179,6 +184,7 @@ export default function SemanticSegmentationScreen() {
setSegImage(null);
}}
/>
<StatsBar inferenceTime={inferenceTime} />
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
6 changes: 6 additions & 0 deletions apps/computer-vision/app/style_transfer/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { View, StyleSheet, Image } from 'react-native';
import React, { useContext, useEffect, useState } from 'react';
import { GeneratingContext } from '../../context';
import ScreenWrapper from '../../ScreenWrapper';
import { StatsBar } from '../../components/StatsBar';

type StyleTransferModelSources = {
modelName: StyleTransferModelName;
Expand All @@ -42,20 +43,24 @@ export default function StyleTransferScreen() {

const [imageUri, setImageUri] = useState('');
const [styledUri, setStyledUri] = useState('');
const [inferenceTime, setInferenceTime] = useState<number | null>(null);

/**
 * Obtains an image through `getImage` (forwarding `isCamera`) and, when the
 * result carries a string URI, makes it the current input image while
 * clearing the previously styled output and the last inference time.
 * Does nothing when no string URI is returned.
 */
const handleCameraPress = async (isCamera: boolean) => {
  const picked = await getImage(isCamera);
  const pickedUri = picked?.uri;

  // Nothing usable came back: leave the current state untouched.
  if (typeof pickedUri !== 'string') {
    return;
  }

  setImageUri(pickedUri);
  // A new input invalidates whatever was produced for the previous image.
  setStyledUri('');
  setInferenceTime(null);
};

const runForward = async () => {
if (imageUri) {
try {
const start = Date.now();
const uri = await model.forward(imageUri, 'url');
setInferenceTime(Date.now() - start);
setStyledUri(uri);
} catch (e) {
console.error(e);
Expand Down Expand Up @@ -96,6 +101,7 @@ export default function StyleTransferScreen() {
setStyledUri('');
}}
/>
<StatsBar inferenceTime={inferenceTime} />
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
Expand Down
Loading
Loading