Commit e51547e

chmjkb and mkopcins authored
feat!: make image segmentation generic, general refactor (software-mansion#814)
## Description

Refactors image segmentation into a generic, multi-model architecture. Previously the module was hardcoded to DeepLab V3 — now it supports multiple built-in models (DeepLab V3, selfie segmentation, RF-DETR) and custom user-provided models with type-safe label maps.

**Key changes:**

- **C++ base class**: Extracted `BaseImageSegmentation` with virtual `preprocess()`/`postprocess()` methods. `ImageSegmentation` is now a thin subclass. This allows future models to override preprocessing (e.g. different normalization) or postprocessing without duplicating the pipeline.
- **Optional normalization in C++**: `readImageToTensor` now accepts optional `normMean`/`normStd` params, eliminating duplicated normalization logic. **also, imo it would be a good idea to do such factories for the entire API**
- **Generic TypeScript module**: `ImageSegmentationModule<T>` is generic over model name or custom `LabelEnum`. Two static factories: `fromModelName()` (built-in models with auto label resolution) and `fromCustomConfig()` (custom models with user-provided labels).
- **Generic hook**: `useImageSegmentation` infers the model's label types from the config — no explicit generic parameter needed. `forward()` return type narrows based on `classesOfInterest` passed in.
- **Correct return types**: `forward()` now returns `Record<'ARGMAX', Int32Array> & Record<K, Float32Array>` matching what the native side actually produces (was incorrectly typed as `number[]`).
- **ARGMAX always returned**: Removed `'ARGMAX'` from `classesOfInterest` — it's always in the output regardless, and the return type reflects this.

### Introduces a breaking change?

- [x] Yes
- [ ] No

### Type of change

- [ ] Bug fix (change which fixes an issue)
- [x] New feature (change which adds functionality)
- [ ] Documentation update (improves or adds clarity to existing documentation)
- [ ] Other (chores, tests, code style improvements etc.)

### Tested on

- [ ] iOS
- [x] Android

### Testing instructions

1. Build and run the `computer-vision` demo app
2. Navigate to Image Segmentation screen
3. Pick an image and run segmentation — verify the ARGMAX overlay renders correctly
4. Verify the hook API works as expected:

```ts
const { isReady, forward } = useImageSegmentation({
  model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 },
});

// Returns Record<'ARGMAX', Int32Array> — no generic needed
const result = await forward(imageUri);

// Narrows return type to include 'PERSON' key as Float32Array
const result2 = await forward(imageUri, ['PERSON']);
```

5. Verify TypeScript autocompletion: `classesOfInterest` should only suggest valid label keys for the chosen model (e.g. `'PERSON'`, `'CAR'` for DeepLab, `'SELFIE'`/`'BACKGROUND'` for selfie segmentation)
6. You can also try changing the parameters, to say selfie segmentation and see how the return types react. Please contact me for weights for selfie segmentation as I'm not pushing them to HF yet

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

<!-- Link related issues here using #issue-number -->

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes

The `ImageSegmentationModule.fromCustomConfig()` API allows users to bring their own segmentation model with a custom label map:

```ts
const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;

const seg = await ImageSegmentationModule.fromCustomConfig(
  'https://example.com/model.pte',
  { labelMap: MyLabels },
);
```

---------

Co-authored-by: Mateusz Kopcinski <120639731+mkopcins@users.noreply.github.com>
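
The return-type narrowing described in the message can be tried in isolation. The block below is a minimal sketch, not the library's implementation: `mockForward`, `SegmentationResult`, and `LabelMap` are hypothetical names, and the function allocates zero-filled arrays instead of running inference. It only illustrates how a result type of the form `Record<'ARGMAX', Int32Array> & Record<K, Float32Array>` can follow from the keys passed in.

```ts
// Hypothetical label map, shaped like the one in the commit message.
const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;

type LabelMap = Readonly<Record<string, number>>;

// ARGMAX is always present; each requested key K adds a Float32Array mask.
type SegmentationResult<K extends string> = Record<'ARGMAX', Int32Array> &
  Record<K, Float32Array>;

// Stand-in for a native forward() call: returns zero-filled buffers.
function mockForward<L extends LabelMap, K extends keyof L & string = never>(
  labels: L,
  pixelCount: number,
  classesOfInterest: readonly K[] = []
): SegmentationResult<K> {
  const out: Record<string, Int32Array | Float32Array> = {
    ARGMAX: new Int32Array(pixelCount),
  };
  for (const key of classesOfInterest) {
    if (!(key in labels)) throw new Error(`Unknown label: ${key}`);
    out[key] = new Float32Array(pixelCount);
  }
  // The runtime object matches the narrowed type by construction.
  return out as unknown as SegmentationResult<K>;
}

// Only ARGMAX is available on this result.
const onlyArgmax = mockForward(MyLabels, 4);
// FOREGROUND is now a typed Float32Array; accessing BACKGROUND would not compile.
const withMask = mockForward(MyLabels, 4, ['FOREGROUND']);
console.log(onlyArgmax.ARGMAX.length, withMask.FOREGROUND.length);
```

Defaulting `K` to `never` is what keeps the no-argument call from exposing any per-class keys; whether the library uses the same default is not shown in this commit page.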
1 parent 2d58410 commit e51547e

31 files changed

Lines changed: 954 additions & 341 deletions

.cspell-wordlist.txt

Lines changed: 4 additions & 1 deletion
@@ -111,4 +111,7 @@ logprob
 RNFS
 pogodin
 kesha
-antonov
+antonov
+rfdetr
+basemodule
+IMAGENET

apps/computer-vision/app/image_segmentation/index.tsx

Lines changed: 14 additions & 15 deletions
@@ -2,9 +2,8 @@ import Spinner from '../../components/Spinner';
 import { BottomBar } from '../../components/BottomBar';
 import { getImage } from '../../utils';
 import {
-  useImageSegmentation,
   DEEPLAB_V3_RESNET50,
-  DeeplabLabel,
+  useImageSegmentation,
 } from 'react-native-executorch';
 import {
   Canvas,
@@ -44,16 +43,20 @@ const numberToColor: number[][] = [
 ];
 
 export default function ImageSegmentationScreen() {
-  const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 });
   const { setGlobalGenerating } = useContext(GeneratingContext);
-  useEffect(() => {
-    setGlobalGenerating(model.isGenerating);
-  }, [model.isGenerating, setGlobalGenerating]);
+  const { isReady, isGenerating, downloadProgress, forward } =
+    useImageSegmentation({
+      model: DEEPLAB_V3_RESNET50,
+    });
   const [imageUri, setImageUri] = useState('');
   const [imageSize, setImageSize] = useState({ width: 0, height: 0 });
   const [segImage, setSegImage] = useState<SkImage | null>(null);
   const [canvasSize, setCanvasSize] = useState({ width: 0, height: 0 });
 
+  useEffect(() => {
+    setGlobalGenerating(isGenerating);
+  }, [isGenerating, setGlobalGenerating]);
+
   const handleCameraPress = async (isCamera: boolean) => {
     const image = await getImage(isCamera);
     if (!image?.uri) return;
@@ -69,12 +72,8 @@ export default function ImageSegmentationScreen() {
     if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return;
     try {
       const { width, height } = imageSize;
-      const output = await model.forward(imageUri, [DeeplabLabel.ARGMAX]);
-      const argmax = output[DeeplabLabel.ARGMAX] || [];
-      const uniqueValues = new Set<number>();
-      for (let i = 0; i < argmax.length; i++) {
-        uniqueValues.add(argmax[i]);
-      }
+      const output = await forward(imageUri, [], true);
+      const argmax = output.ARGMAX || [];
       const pixels = new Uint8Array(width * height * 4);
 
       for (let row = 0; row < height; row++) {
@@ -105,11 +104,11 @@ export default function ImageSegmentationScreen() {
     }
   };
 
-  if (!model.isReady) {
+  if (!isReady) {
     return (
       <Spinner
-        visible={!model.isReady}
-        textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`}
+        visible={!isReady}
+        textContent={`Loading the model ${(downloadProgress * 100).toFixed(0)} %`}
       />
     );
   }
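
The demo screen turns the `ARGMAX` output into an RGBA buffer for the overlay, but the loop body is elided in the diff. The following is a minimal sketch of that kind of conversion, under stated assumptions: `argmaxToRgba` and `palette` are hypothetical names (the app uses its own `numberToColor` table), and the output is assumed to be row-major with one class index per pixel.

```ts
// Hypothetical palette: one [r, g, b, a] entry per class index.
const palette: number[][] = [
  [0, 0, 0, 0], // class 0: fully transparent
  [255, 0, 0, 255], // class 1: opaque red
];

// Expands per-pixel class indices into an RGBA byte buffer.
function argmaxToRgba(
  argmax: Int32Array,
  width: number,
  height: number,
  colors: number[][]
): Uint8Array {
  if (argmax.length !== width * height) {
    throw new Error('ARGMAX length does not match width * height');
  }
  const pixels = new Uint8Array(width * height * 4);
  argmax.forEach((classIndex, i) => {
    // Class indices without a palette entry fall back to transparent black.
    const color = colors[classIndex] ?? [0, 0, 0, 0];
    pixels.set(color, i * 4);
  });
  return pixels;
}

// A 2x2 image whose middle two pixels belong to class 1.
const rgba = argmaxToRgba(new Int32Array([0, 1, 1, 0]), 2, 2, palette);
console.log(rgba.length); // 16 bytes: 4 pixels * 4 channels
```

The length check matters when `resizeToInput` is `false`, because the output then has the model's dimensions rather than those of the original image.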

docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md

Lines changed: 54 additions & 22 deletions
@@ -21,12 +21,15 @@ import {
   DEEPLAB_V3_RESNET50,
 } from 'react-native-executorch';
 
-const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 });
+const model = useImageSegmentation({
+  model: DEEPLAB_V3_RESNET50,
+});
 
 const imageUri = 'file::///Users/.../cute_cat.png';
 
 try {
-  const outputDict = await model.forward(imageUri);
+  const result = await model.forward(imageUri);
+  // result.ARGMAX is an Int32Array of per-pixel class indices
 } catch (error) {
   console.error(error);
 }
@@ -36,9 +39,13 @@ try {
 
 `useImageSegmentation` takes [`ImageSegmentationProps`](../../06-api-reference/interfaces/ImageSegmentationProps.md) that consists of:
 
-- `model` containing [`modelSource`](../../06-api-reference/interfaces/ImageSegmentationProps.md#modelsource).
+- `model` - An object containing:
+  - `modelName` - The name of a built-in model. See [`ModelSources`](../../06-api-reference/type-aliases/ModelSources.md) for the list of supported models.
+  - `modelSource` - The location of the model binary (a URL or a bundled resource).
 - An optional flag [`preventLoad`](../../06-api-reference/interfaces/ImageSegmentationProps.md#preventload) which prevents auto-loading of the model.
 
+The hook is generic over the model config — TypeScript automatically infers the correct label type based on the `modelName` you provide. No explicit generic parameter is needed.
+
 You need more details? Check the following resources:
 
 - For detailed information about `useImageSegmentation` arguments check this section: [`useImageSegmentation` arguments](../../06-api-reference/functions/useImageSegmentation.md#parameters).
@@ -47,45 +54,70 @@ You need more details? Check the following resources:
 
 ### Returns
 
-`useImageSegmentation` returns an object called `ImageSegmentationType` containing bunch of functions to interact with image segmentation models. To get more details please read: [`ImageSegmentationType` API Reference](../../06-api-reference/interfaces/ImageSegmentationType.md).
+`useImageSegmentation` returns an [`ImageSegmentationType`](../../06-api-reference/interfaces/ImageSegmentationType.md) object containing:
+
+- `isReady` - Whether the model is loaded and ready to process images.
+- `isGenerating` - Whether the model is currently processing an image.
+- `error` - An error object if the model failed to load or encountered a runtime error.
+- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary.
+- `forward` - A function to run inference on an image.
 
 ## Running the model
 
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) method. It accepts three arguments: a required image - can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64), an optional list of classes, and an optional flag whether to resize the output to the original dimensions.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) method. It accepts three arguments:
 
-- The image can be a remote URL, a local file URI, or a base64-encoded image.
-- The [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes.
-- The [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizetoinput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image.
+- [`imageSource`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64).
+- [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter.
+- [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`).
 
 :::warning
 Setting `resizeToInput` to `false` will make `forward` faster.
 :::
 
-[`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) returns a promise which can resolve either to an error or a dictionary containing number arrays with size depending on [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizetoinput):
+`forward` returns a promise resolving to an object containing:
+
+- `ARGMAX` - An `Int32Array` where each element is the class index with the highest probability for that pixel.
+- For each label included in `classesOfInterest`, a `Float32Array` of per-pixel probabilities for that class.
 
-- For the key [`DeeplabLabel.ARGMAX`](../../06-api-reference/enumerations/DeeplabLabel.md#argmax) the array contains for each pixel an integer corresponding to the class with the highest probability.
-- For every other key from [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md), if the label was included in [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) the dictionary will contain an array of floats corresponding to the probability of this class for every pixel.
+The return type is fully typed — TypeScript narrows it based on the labels you pass in `classesOfInterest`.
 
 ## Example
 
 ```typescript
+import {
+  useImageSegmentation,
+  DEEPLAB_V3_RESNET50,
+  DeeplabLabel,
+} from 'react-native-executorch';
+
 function App() {
-  const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 });
+  const model = useImageSegmentation({
+    model: DEEPLAB_V3_RESNET50,
+  });
 
-  // ...
-  const imageUri = 'file::///Users/.../cute_cat.png';
+  const handleSegment = async () => {
+    if (!model.isReady) return;
+
+    const imageUri = 'file::///Users/.../cute_cat.png';
+
+    try {
+      const result = await model.forward(imageUri, ['CAT', 'PERSON'], true);
+
+      // result.ARGMAX — Int32Array of per-pixel class indices
+      // result.CAT — Float32Array of per-pixel probabilities for CAT
+      // result.PERSON — Float32Array of per-pixel probabilities for PERSON
+    } catch (error) {
+      console.error(error);
+    }
+  };
 
-  try {
-    const outputDict = await model.forward(imageUri, [DeeplabLabel.CAT], true);
-  } catch (error) {
-    console.error(error);
-  }
   // ...
 }
 ```
 
 ## Supported models
 
-| Model | Number of classes | Class list |
-| ------------------------------------------------------------------------------------------------ | ----------------- | ------------------------------------------------------------------- |
-| [deeplabv3_resnet50](https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3) | 21 | [DeeplabLabel](../../06-api-reference/enumerations/DeeplabLabel.md) |
+| Model | Number of classes | Class list |
+| ------------------------------------------------------------------------------------------------ | ----------------- | ----------------------------------------------------------------------------------------- |
+| [deeplabv3_resnet50](https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3) | 21 | [DeeplabLabel](../../06-api-reference/enumerations/DeeplabLabel.md) |
+| selfie-segmentation | 2 | [SelfieSegmentationLabel](../../06-api-reference/enumerations/SelfieSegmentationLabel.md) |
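
The updated docs describe each requested class as a `Float32Array` of per-pixel probabilities. One common way to consume such a mask is to threshold it into a binary mask. The sketch below assumes that usage; `binarizeMask` and `coverage` are hypothetical helpers, the 0.5 threshold is an arbitrary choice, and the sample data stands in for a real `forward` result.

```ts
// Converts per-pixel probabilities into a 0/1 mask using a cutoff.
function binarizeMask(
  probabilities: Float32Array,
  threshold: number = 0.5
): Uint8Array {
  const mask = new Uint8Array(probabilities.length);
  probabilities.forEach((p, i) => {
    mask[i] = p >= threshold ? 1 : 0;
  });
  return mask;
}

// Fraction of pixels selected by a binary mask (0 for an empty mask).
function coverage(mask: Uint8Array): number {
  if (mask.length === 0) return 0;
  let selected = 0;
  mask.forEach((value) => {
    selected += value;
  });
  return selected / mask.length;
}

// Stand-in for something like result.PERSON on a 2x2 output.
const personProbabilities = new Float32Array([0.1, 0.9, 0.5, 0.25]);
const personMask = binarizeMask(personProbabilities);
console.log(Array.from(personMask), coverage(personMask));
```

Because the class masks are only produced for labels listed in `classesOfInterest`, a helper like this applies only to keys that were explicitly requested.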

docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md

Lines changed: 62 additions & 20 deletions
@@ -19,14 +19,15 @@ import {
 
 const imageUri = 'path/to/image.png';
 
-// Creating an instance
-const imageSegmentationModule = new ImageSegmentationModule();
-
-// Loading the model
-await imageSegmentationModule.load(DEEPLAB_V3_RESNET50);
+// Creating an instance from a built-in model
+const segmentation = await ImageSegmentationModule.fromModelName({
+  modelName: 'deeplab-v3',
+  modelSource: DEEPLAB_V3_RESNET50,
+});
 
 // Running the model
-const outputDict = await imageSegmentationModule.forward(imageUri);
+const result = await segmentation.forward(imageUri);
+// result.ARGMAX — Int32Array of per-pixel class indices
 ```
 
 ### Methods
@@ -35,34 +36,75 @@ All methods of `ImageSegmentationModule` are explained in details here: [`ImageS
 
 ## Loading the model
 
-To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/ImageSegmentationModule.md#load) method with the following parameters:
+`ImageSegmentationModule` uses static factory methods instead of `new()` + `load()`. There are two ways to create an instance:
+
+### Built-in models — `fromModelName`
+
+Use [`fromModelName`](../../06-api-reference/classes/ImageSegmentationModule.md#frommodelname) for models that ship with built-in label maps and preprocessing configs:
+
+```typescript
+const segmentation = await ImageSegmentationModule.fromModelName(
+  DEEPLAB_V3_RESNET50,
+  (progress) => console.log(`Download: ${Math.round(progress * 100)}%`)
+);
+```
 
-- [`model`](../../06-api-reference/classes/ImageSegmentationModule.md#model) - Object containing:
-  - [`modelSource`](../../06-api-reference/classes/ImageSegmentationModule.md#modelsource) - Location of the used model.
+The `config` parameter is a discriminated union — TypeScript ensures you provide the correct fields for each model name. Available built-in models: `'deeplab-v3'`, `'selfie-segmentation'`.
 
-- [`onDownloadProgressCallback`](../../06-api-reference/classes/ImageSegmentationModule.md#ondownloadprogresscallback) - Callback to track download progress.
+### Custom models — `fromCustomConfig`
 
-This method returns a promise, which can resolve to an error or void.
+Use [`fromCustomConfig`](../../06-api-reference/classes/ImageSegmentationModule.md#fromcustomconfig) for custom-exported segmentation models with your own label map:
+
+```typescript
+const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;
+
+const segmentation = await ImageSegmentationModule.fromCustomConfig(
+  'https://example.com/custom_model.pte',
+  {
+    labelMap: MyLabels,
+    preprocessorConfig: {
+      normMean: [0.485, 0.456, 0.406],
+      normStd: [0.229, 0.224, 0.225],
+    },
+  }
+);
+```
+
+The `preprocessorConfig` is optional. If omitted, no input normalization is applied. The module instance will be typed to your custom label map — `forward` will accept and return keys from `MyLabels`.
 
 For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
 
 ## Running the model
 
-To run the model, you can use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method on the module object. It accepts three arguments: a required image - can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64), an optional list of classes, and an optional flag whether to resize the output to the original dimensions.
+To run the model, use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method. It accepts three arguments:
 
-- The image can be a remote URL, a local file URI, or a base64-encoded image.
-- The [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes.
-- The [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizetoinput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for the `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image.
+- [`imageSource`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64).
+- [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless.
+- [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions.
 
 :::warning
-Setting `resize` to true will make `forward` slower.
+Setting `resizeToInput` to `false` will make `forward` faster.
 :::
 
-[`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) returns a promise which can resolve either to an error or a dictionary containing number arrays with size depending on [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizetoinput):
+`forward` returns a promise resolving to an object containing:
+
+- `ARGMAX` - An `Int32Array` where each element is the class index with the highest probability for that pixel.
+- For each label included in `classesOfInterest`, a `Float32Array` of per-pixel probabilities for that class.
 
-- For the key [`DeeplabLabel.ARGMAX`](../../06-api-reference/enumerations/DeeplabLabel.md#argmax) the array contains for each pixel an integer corresponding to the class with the highest probability.
-- For every other key from [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md), if the label was included in [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) the dictionary will contain an array of floats corresponding to the probability of this class for every pixel.
+The return type narrows based on the labels passed in `classesOfInterest`:
+
+```typescript
+// Only ARGMAX in the result
+const result = await segmentation.forward(imageUri);
+result.ARGMAX; // Int32Array
+
+// ARGMAX + requested class masks
+const result = await segmentation.forward(imageUri, ['CAT', 'DOG']);
+result.ARGMAX; // Int32Array
+result.CAT; // Float32Array
+result.DOG; // Float32Array
+```
 
 ## Managing memory
 
-The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) unless you load the module again.
+The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) unless you create a new instance.
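
The `preprocessorConfig` in the docs above carries `normMean` and `normStd`, but this page does not show how the native side applies them. The sketch below illustrates the conventional per-channel formula `(value / 255 - mean) / std` on interleaved RGB bytes. Both the formula and the layout are assumptions based on the usual ImageNet convention, and `normalizeRgb` is a hypothetical helper rather than anything exported by the library. When no config is given, the sketch only scales to the 0..1 range, which is also an assumption.

```ts
// Shape mirrors the preprocessorConfig object shown in the docs.
interface PreprocessorConfig {
  normMean: readonly number[];
  normStd: readonly number[];
}

// Normalizes interleaved RGB bytes (r, g, b, r, g, b, ...) channel by channel.
function normalizeRgb(
  rgb: Uint8Array,
  config?: PreprocessorConfig
): Float32Array {
  if (rgb.length % 3 !== 0) {
    throw new Error('Expected interleaved RGB data');
  }
  if (config && (config.normMean.length !== 3 || config.normStd.length !== 3)) {
    throw new Error('normMean and normStd must have three entries');
  }
  const out = new Float32Array(rgb.length);
  rgb.forEach((value, i) => {
    const scaled = value / 255;
    if (!config) {
      // Assumed behaviour without a config: scale only.
      out[i] = scaled;
      return;
    }
    const channel = i % 3;
    const mean = config.normMean[channel] ?? 0;
    const std = config.normStd[channel] ?? 1;
    out[i] = (scaled - mean) / std;
  });
  return out;
}

// One white pixel, normalized with the values from the docs example.
const normalized = normalizeRgb(new Uint8Array([255, 255, 255]), {
  normMean: [0.485, 0.456, 0.406],
  normStd: [0.229, 0.224, 0.225],
});
console.log(Array.from(normalized));
```

A custom model generally needs the same normalization at inference time that it saw during training, so values like these would normally be taken from the training pipeline.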
