Skip to content

Commit f070fcb

Browse files
committed
add type for hook return
1 parent dc38a5d commit f070fcb

3 files changed

Lines changed: 55 additions & 7 deletions

File tree

apps/computer-vision/app/image_segmentation/index.tsx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import Spinner from '../../components/Spinner';
22
import { BottomBar } from '../../components/BottomBar';
33
import { getImage } from '../../utils';
4-
import { useImageSegmentation } from 'react-native-executorch';
4+
import {
5+
DEEPLAB_V3_RESNET50,
6+
useImageSegmentation,
7+
} from 'react-native-executorch';
58
import {
69
Canvas,
710
Image as SkiaImage,
@@ -43,10 +46,7 @@ export default function ImageSegmentationScreen() {
4346
const { setGlobalGenerating } = useContext(GeneratingContext);
4447
const { isReady, isGenerating, downloadProgress, forward } =
4548
useImageSegmentation({
46-
model: {
47-
modelName: 'deeplab-v3',
48-
modelSource: 'https://ai.swmansion.com/storage/jc_tests/selfie_seg.pte',
49-
},
49+
model: DEEPLAB_V3_RESNET50,
5050
});
5151
const [imageUri, setImageUri] = useState('');
5252
const [imageSize, setImageSize] = useState({ width: 0, height: 0 });
@@ -72,7 +72,7 @@ export default function ImageSegmentationScreen() {
7272
if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return;
7373
try {
7474
const { width, height } = imageSize;
75-
const output = await forward(imageUri, ['PERSON'], true);
75+
const output = await forward(imageUri, [], true);
7676
const argmax = output['ARGMAX'] || [];
7777
const pixels = new Uint8Array(width * height * 4);
7878

packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
} from '../../modules/computer_vision/ImageSegmentationModule';
66
import {
77
ImageSegmentationProps,
8+
ImageSegmentationType,
89
ModelNameOf,
910
ModelSources,
1011
} from '../../types/imageSegmentation';
@@ -30,7 +31,9 @@ import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
3031
export const useImageSegmentation = <C extends ModelSources>({
3132
model,
3233
preventLoad = false,
33-
}: ImageSegmentationProps<C>) => {
34+
}: ImageSegmentationProps<C>): ImageSegmentationType<
35+
SegmentationLabels<ModelNameOf<C>>
36+
> => {
3437
const [error, setError] = useState<RnExecutorchError | null>(null);
3538
const [isReady, setIsReady] = useState(false);
3639
const [isGenerating, setIsGenerating] = useState(false);

packages/react-native-executorch/src/types/imageSegmentation.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { RnExecutorchError } from '../errors/errorUtils';
12
import { LabelEnum, Triple, ResourceSource } from './common';
23

34
/**
@@ -95,3 +96,47 @@ export interface ImageSegmentationProps<C extends ModelSources> {
9596
model: C;
9697
preventLoad?: boolean;
9798
}
99+
100+
/**
101+
* Return type for the `useImageSegmentation` hook.
102+
* Manages the state and operations for image segmentation models.
103+
*
104+
* @typeParam L - The {@link LabelEnum} representing the model's class labels.
105+
*
106+
* @category Types
107+
*/
108+
export interface ImageSegmentationType<L extends LabelEnum> {
109+
/**
110+
* Contains the error object if the model failed to load, download, or encountered a runtime error during segmentation.
111+
*/
112+
error: RnExecutorchError | null;
113+
114+
/**
115+
* Indicates whether the segmentation model is loaded and ready to process images.
116+
*/
117+
isReady: boolean;
118+
119+
/**
120+
* Indicates whether the model is currently processing an image.
121+
*/
122+
isGenerating: boolean;
123+
124+
/**
125+
* Represents the download progress of the model binary as a value between 0 and 1.
126+
*/
127+
downloadProgress: number;
128+
129+
/**
130+
* Executes the model's forward pass to perform semantic segmentation on the provided image.
131+
* @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed.
132+
* @param classesOfInterest - An optional array of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless.
133+
* @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`.
134+
* @returns A Promise resolving to an object with an `'ARGMAX'` `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities.
135+
* @throws {RnExecutorchError} If the model is not loaded or is currently processing another image.
136+
*/
137+
forward: <K extends keyof L>(
138+
imageSource: string,
139+
classesOfInterest?: K[],
140+
resizeToInput?: boolean
141+
) => Promise<Record<'ARGMAX', Int32Array> & Record<K, Float32Array>>;
142+
}

0 commit comments

Comments
 (0)