Skip to content

Commit 514bf8d

Browse files
committed
wip
1 parent cc39ca2 commit 514bf8d

8 files changed

Lines changed: 307 additions & 317 deletions

File tree

packages/react-native-executorch/src/constants/commonVision.ts

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,101 @@ import { Triple } from '../types/common';
22

33
export const IMAGENET1K_MEAN: Triple<number> = [0.485, 0.456, 0.406];
44
export const IMAGENET1K_STD: Triple<number> = [0.229, 0.224, 0.225];
5+
6+
/**
7+
* COCO dataset class labels used for object detection.
8+
*
9+
* @category Types
10+
*/
11+
export enum CocoLabel {
12+
PERSON = 1,
13+
BICYCLE = 2,
14+
CAR = 3,
15+
MOTORCYCLE = 4,
16+
AIRPLANE = 5,
17+
BUS = 6,
18+
TRAIN = 7,
19+
TRUCK = 8,
20+
BOAT = 9,
21+
TRAFFIC_LIGHT = 10,
22+
FIRE_HYDRANT = 11,
23+
STREET_SIGN = 12,
24+
STOP_SIGN = 13,
25+
PARKING = 14,
26+
BENCH = 15,
27+
BIRD = 16,
28+
CAT = 17,
29+
DOG = 18,
30+
HORSE = 19,
31+
SHEEP = 20,
32+
COW = 21,
33+
ELEPHANT = 22,
34+
BEAR = 23,
35+
ZEBRA = 24,
36+
GIRAFFE = 25,
37+
HAT = 26,
38+
BACKPACK = 27,
39+
UMBRELLA = 28,
40+
SHOE = 29,
41+
EYE = 30,
42+
HANDBAG = 31,
43+
TIE = 32,
44+
SUITCASE = 33,
45+
FRISBEE = 34,
46+
SKIS = 35,
47+
SNOWBOARD = 36,
48+
SPORTS = 37,
49+
KITE = 38,
50+
BASEBALL = 39,
51+
SKATEBOARD = 41,
52+
SURFBOARD = 42,
53+
TENNIS_RACKET = 43,
54+
BOTTLE = 44,
55+
PLATE = 45,
56+
WINE_GLASS = 46,
57+
CUP = 47,
58+
FORK = 48,
59+
KNIFE = 49,
60+
SPOON = 50,
61+
BOWL = 51,
62+
BANANA = 52,
63+
APPLE = 53,
64+
SANDWICH = 54,
65+
ORANGE = 55,
66+
BROCCOLI = 56,
67+
CARROT = 57,
68+
HOT_DOG = 58,
69+
PIZZA = 59,
70+
DONUT = 60,
71+
CAKE = 61,
72+
CHAIR = 62,
73+
COUCH = 63,
74+
POTTED_PLANT = 64,
75+
BED = 65,
76+
MIRROR = 66,
77+
DINING_TABLE = 67,
78+
WINDOW = 68,
79+
DESK = 69,
80+
TOILET = 70,
81+
DOOR = 71,
82+
TV = 72,
83+
LAPTOP = 73,
84+
MOUSE = 74,
85+
REMOTE = 75,
86+
KEYBOARD = 76,
87+
CELL_PHONE = 77,
88+
MICROWAVE = 78,
89+
OVEN = 79,
90+
TOASTER = 80,
91+
SINK = 81,
92+
REFRIGERATOR = 82,
93+
BLENDER = 83,
94+
BOOK = 84,
95+
CLOCK = 85,
96+
VASE = 86,
97+
SCISSORS = 87,
98+
TEDDY_BEAR = 88,
99+
HAIR_DRIER = 89,
100+
TOOTHBRUSH = 90,
101+
HAIR_BRUSH = 91,
102+
}
Lines changed: 15 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { useState, useEffect } from 'react';
21
import {
32
ImageSegmentationModule,
43
SegmentationLabels,
@@ -9,8 +8,7 @@ import {
98
ModelNameOf,
109
ModelSources,
1110
} from '../../types/imageSegmentation';
12-
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
13-
import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
11+
import { useModuleFactory } from '../useModuleFactory';
1412

1513
/**
1614
* React hook for managing an Image Segmentation model instance.
@@ -34,82 +32,22 @@ export const useImageSegmentation = <C extends ModelSources>({
3432
}: ImageSegmentationProps<C>): ImageSegmentationType<
3533
SegmentationLabels<ModelNameOf<C>>
3634
> => {
37-
const [error, setError] = useState<RnExecutorchError | null>(null);
38-
const [isReady, setIsReady] = useState(false);
39-
const [isGenerating, setIsGenerating] = useState(false);
40-
const [downloadProgress, setDownloadProgress] = useState(0);
41-
const [instance, setInstance] = useState<ImageSegmentationModule<
42-
ModelNameOf<C>
43-
> | null>(null);
44-
45-
useEffect(() => {
46-
if (preventLoad) return;
47-
48-
let isMounted = true;
49-
let currentInstance: ImageSegmentationModule<ModelNameOf<C>> | null = null;
50-
51-
(async () => {
52-
setDownloadProgress(0);
53-
setError(null);
54-
setIsReady(false);
55-
try {
56-
currentInstance = await ImageSegmentationModule.fromModelName(
57-
model,
58-
(progress) => {
59-
if (isMounted) setDownloadProgress(progress);
60-
}
61-
);
62-
if (isMounted) {
63-
setInstance(currentInstance);
64-
setIsReady(true);
65-
}
66-
} catch (err) {
67-
if (isMounted) setError(parseUnknownError(err));
68-
}
69-
})();
70-
71-
return () => {
72-
isMounted = false;
73-
currentInstance?.delete();
74-
};
75-
76-
// eslint-disable-next-line react-hooks/exhaustive-deps
77-
}, [model.modelName, model.modelSource, preventLoad]);
78-
79-
const forward = async <K extends keyof SegmentationLabels<ModelNameOf<C>>>(
35+
const { error, isReady, isGenerating, downloadProgress, runForward } =
36+
useModuleFactory({
37+
factory: (config, onProgress) =>
38+
ImageSegmentationModule.fromModelName(config, onProgress),
39+
config: model,
40+
preventLoad,
41+
});
42+
43+
const forward = <K extends keyof SegmentationLabels<ModelNameOf<C>>>(
8044
imageSource: string,
8145
classesOfInterest: K[] = [],
8246
resizeToInput: boolean = true
83-
) => {
84-
if (!isReady || !instance) {
85-
throw new RnExecutorchError(
86-
RnExecutorchErrorCode.ModuleNotLoaded,
87-
'The model is currently not loaded. Please load the model before calling forward().'
88-
);
89-
}
90-
if (isGenerating) {
91-
throw new RnExecutorchError(
92-
RnExecutorchErrorCode.ModelGenerating,
93-
'The model is currently generating. Please wait until previous model run is complete.'
94-
);
95-
}
96-
try {
97-
setIsGenerating(true);
98-
return await instance.forward(
99-
imageSource,
100-
classesOfInterest,
101-
resizeToInput
102-
);
103-
} finally {
104-
setIsGenerating(false);
105-
}
106-
};
47+
) =>
48+
runForward((inst) =>
49+
inst.forward(imageSource, classesOfInterest, resizeToInput)
50+
);
10751

108-
return {
109-
error,
110-
isReady,
111-
isGenerating,
112-
downloadProgress,
113-
forward,
114-
};
52+
return { error, isReady, isGenerating, downloadProgress, forward };
11553
};
Lines changed: 13 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { useState, useEffect } from 'react';
21
import {
32
ObjectDetectionModule,
43
ObjectDetectionLabels,
@@ -8,8 +7,7 @@ import {
87
ObjectDetectionProps,
98
ObjectDetectionType,
109
} from '../../types/objectDetection';
11-
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
12-
import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
10+
import { useModuleFactory } from '../useModuleFactory';
1311

1412
/**
1513
* React hook for managing an Object Detection model instance.
@@ -25,75 +23,16 @@ export const useObjectDetection = <C extends ObjectDetectionModelSources>({
2523
}: ObjectDetectionProps<C>): ObjectDetectionType<
2624
ObjectDetectionLabels<C['modelName']>
2725
> => {
28-
const [error, setError] = useState<RnExecutorchError | null>(null);
29-
const [isReady, setIsReady] = useState(false);
30-
const [isGenerating, setIsGenerating] = useState(false);
31-
const [downloadProgress, setDownloadProgress] = useState(0);
32-
const [instance, setInstance] = useState<ObjectDetectionModule<
33-
C['modelName']
34-
> | null>(null);
35-
36-
useEffect(() => {
37-
if (preventLoad) return;
38-
39-
let currentInstance: ObjectDetectionModule<C['modelName']> | null = null;
40-
41-
(async () => {
42-
setDownloadProgress(0);
43-
setError(null);
44-
setIsReady(false);
45-
try {
46-
currentInstance = await ObjectDetectionModule.fromModelName(
47-
model,
48-
setDownloadProgress
49-
);
50-
setInstance(currentInstance);
51-
setIsReady(true);
52-
} catch (err) {
53-
setError(parseUnknownError(err));
54-
}
55-
})();
56-
57-
return () => {
58-
currentInstance?.delete();
59-
};
60-
61-
// eslint-disable-next-line react-hooks/exhaustive-deps
62-
}, [model.modelName, model.modelSource, preventLoad]);
63-
64-
const forward = async (imageSource: string, detectionThreshold?: number) => {
65-
if (!isReady || !instance) {
66-
throw new RnExecutorchError(
67-
RnExecutorchErrorCode.ModuleNotLoaded,
68-
'The model is currently not loaded. Please load the model before calling forward().'
69-
);
70-
}
71-
if (isGenerating) {
72-
throw new RnExecutorchError(
73-
RnExecutorchErrorCode.ModelGenerating,
74-
'The model is currently generating. Please wait until previous model run is complete.'
75-
);
76-
}
77-
try {
78-
setIsGenerating(true);
79-
return (await instance.forward(
80-
imageSource,
81-
detectionThreshold
82-
)) as Awaited<
83-
ReturnType<
84-
ObjectDetectionType<ObjectDetectionLabels<C['modelName']>>['forward']
85-
>
86-
>;
87-
} finally {
88-
setIsGenerating(false);
89-
}
90-
};
91-
92-
return {
93-
error,
94-
isReady,
95-
isGenerating,
96-
downloadProgress,
97-
forward,
98-
};
26+
const { error, isReady, isGenerating, downloadProgress, runForward } =
27+
useModuleFactory({
28+
factory: (config, onProgress) =>
29+
ObjectDetectionModule.fromModelName(config, onProgress),
30+
config: model,
31+
preventLoad,
32+
});
33+
34+
const forward = (imageSource: string, detectionThreshold?: number) =>
35+
runForward((inst) => inst.forward(imageSource, detectionThreshold));
36+
37+
return { error, isReady, isGenerating, downloadProgress, forward };
9938
};

0 commit comments

Comments
 (0)