Skip to content

Commit bdb3616

Browse files
chmjkb authored and msluszniak committed
chore!: resize the image segmentation output to the input size by default, change param name (#801)
## Description This PR changes the param name from `resize` to `resizeToInput` in image segmentation APIs. It also defaults to true now, as the performance impact is acceptable. ### Introduces a breaking change? - [x] Yes - [ ] No ### Type of change - [ ] Bug fix (change which fixes an issue) - [ ] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [x] Other (chores, tests, code style improvements etc.) ### Tested on - [ ] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. --> --------- Co-authored-by: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com>
1 parent c4f63e8 commit bdb3616

7 files changed

Lines changed: 77 additions & 76 deletions

File tree

apps/computer-vision/app/image_segmentation/index.tsx

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,13 @@ import {
1212
Skia,
1313
AlphaType,
1414
ColorType,
15+
SkImage,
1516
} from '@shopify/react-native-skia';
1617
import { View, StyleSheet, Image } from 'react-native';
1718
import React, { useContext, useEffect, useState } from 'react';
1819
import { GeneratingContext } from '../../context';
1920
import ScreenWrapper from '../../ScreenWrapper';
2021

21-
const width = 224;
22-
const height = 224;
23-
24-
let pixels = new Uint8Array(width * height * 4);
25-
pixels.fill(255);
26-
27-
let data = Skia.Data.fromBytes(pixels);
28-
let img = Skia.Image.MakeImage(
29-
{
30-
width: width,
31-
height: height,
32-
alphaType: AlphaType.Opaque,
33-
colorType: ColorType.RGBA_8888,
34-
},
35-
data,
36-
width * 4
37-
);
38-
3922
const numberToColor: number[][] = [
4023
[255, 87, 51], // 0 Red
4124
[51, 255, 87], // 1 Green
@@ -67,48 +50,58 @@ export default function ImageSegmentationScreen() {
6750
setGlobalGenerating(model.isGenerating);
6851
}, [model.isGenerating, setGlobalGenerating]);
6952
const [imageUri, setImageUri] = useState('');
53+
const [imageSize, setImageSize] = useState({ width: 0, height: 0 });
54+
const [segImage, setSegImage] = useState<SkImage | null>(null);
55+
const [canvasSize, setCanvasSize] = useState({ width: 0, height: 0 });
7056

7157
const handleCameraPress = async (isCamera: boolean) => {
7258
const image = await getImage(isCamera);
73-
const uri = image?.uri;
74-
setImageUri(uri as string);
59+
if (!image?.uri) return;
60+
setImageUri(image.uri);
61+
setImageSize({
62+
width: image.width ?? 0,
63+
height: image.height ?? 0,
64+
});
65+
setSegImage(null);
7566
};
7667

77-
const [resultPresent, setResultPresent] = useState(false);
78-
7968
const runForward = async () => {
80-
if (imageUri) {
81-
try {
82-
const output = await model.forward(imageUri);
83-
pixels = new Uint8Array(width * height * 4);
69+
if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return;
70+
try {
71+
const { width, height } = imageSize;
72+
const output = await model.forward(imageUri, [DeeplabLabel.ARGMAX]);
73+
const argmax = output[DeeplabLabel.ARGMAX] || [];
74+
const uniqueValues = new Set<number>();
75+
for (let i = 0; i < argmax.length; i++) {
76+
uniqueValues.add(argmax[i]);
77+
}
78+
const pixels = new Uint8Array(width * height * 4);
8479

85-
for (let x = 0; x < width; x++) {
86-
for (let y = 0; y < height; y++) {
87-
for (let i = 0; i < 3; i++) {
88-
pixels[(x * height + y) * 4 + i] =
89-
numberToColor[
90-
(output[DeeplabLabel.ARGMAX] || [])[x * height + y]
91-
][i];
92-
}
93-
pixels[(x * height + y) * 4 + 3] = 255;
94-
}
80+
for (let row = 0; row < height; row++) {
81+
for (let col = 0; col < width; col++) {
82+
const idx = row * width + col;
83+
const color = numberToColor[argmax[idx]] || [0, 0, 0];
84+
pixels[idx * 4] = color[0];
85+
pixels[idx * 4 + 1] = color[1];
86+
pixels[idx * 4 + 2] = color[2];
87+
pixels[idx * 4 + 3] = 255;
9588
}
96-
97-
data = Skia.Data.fromBytes(pixels);
98-
img = Skia.Image.MakeImage(
99-
{
100-
width: width,
101-
height: height,
102-
alphaType: AlphaType.Opaque,
103-
colorType: ColorType.RGBA_8888,
104-
},
105-
data,
106-
width * 4
107-
);
108-
setResultPresent(true);
109-
} catch (e) {
110-
console.error(e);
11189
}
90+
91+
const data = Skia.Data.fromBytes(pixels);
92+
const img = Skia.Image.MakeImage(
93+
{
94+
width,
95+
height,
96+
alphaType: AlphaType.Opaque,
97+
colorType: ColorType.RGBA_8888,
98+
},
99+
data,
100+
width * 4
101+
);
102+
setSegImage(img);
103+
} catch (e) {
104+
console.error(e);
112105
}
113106
};
114107

@@ -135,16 +128,24 @@ export default function ImageSegmentationScreen() {
135128
}
136129
/>
137130
</View>
138-
{resultPresent && (
139-
<View style={styles.canvasContainer}>
131+
{segImage && (
132+
<View
133+
style={styles.canvasContainer}
134+
onLayout={(e) =>
135+
setCanvasSize({
136+
width: e.nativeEvent.layout.width,
137+
height: e.nativeEvent.layout.height,
138+
})
139+
}
140+
>
140141
<Canvas style={styles.canvas}>
141142
<SkiaImage
142-
image={img}
143+
image={segImage}
143144
fit="contain"
144145
x={0}
145146
y={0}
146-
width={width}
147-
height={height}
147+
width={canvasSize.width}
148+
height={canvasSize.height}
148149
/>
149150
</Canvas>
150151
</View>
@@ -181,7 +182,7 @@ const styles = StyleSheet.create({
181182
padding: 4,
182183
},
183184
canvas: {
184-
width: width,
185-
height: height,
185+
width: '100%',
186+
height: '100%',
186187
},
187188
});

docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ To run the model, you can use the [`forward`](../../06-api-reference/interfaces/
5555

5656
- The image can be a remote URL, a local file URI, or a base64-encoded image.
5757
- The [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes.
58-
- The [`resize`](../../06-api-reference/interfaces/ImageSegmentationType.md#resize) flag says whether the output will be rescaled back to the size of the image you put in. The default is `false`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image.
58+
- The [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizeToInput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image.
5959

6060
:::warning
6161
Setting `resizeToInput` to true will make `forward` slower.

docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ To run the model, you can use the [`forward`](../../06-api-reference/classes/Ima
5252

5353
- The image can be a remote URL, a local file URI, or a base64-encoded image.
5454
- The [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes.
55-
- The [`resize`](../../06-api-reference/classes/ImageSegmentationModule.md#resize) flag says whether the output will be rescaled back to the size of the image you put in. The default is `false`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for the `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image.
55+
- The [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizeToInput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for the `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image.
5656

5757
:::warning
5858
Setting `resizeToInput` to true will make `forward` slower.

docs/docs/06-api-reference/classes/ImageSegmentationModule.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Class: ImageSegmentationModule
22

3-
Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:13](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L13)
3+
Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:13](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L13)
44

55
Module for image segmentation tasks.
66

@@ -28,7 +28,7 @@ Module for image segmentation tasks.
2828

2929
> **nativeModule**: `any` = `null`
3030
31-
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
31+
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
3232

3333
Native module instance
3434

@@ -42,7 +42,7 @@ Native module instance
4242

4343
> **delete**(): `void`
4444
45-
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
45+
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
4646

4747
Unloads the model from memory.
4848

@@ -58,9 +58,9 @@ Unloads the model from memory.
5858

5959
### forward()
6060

61-
> **forward**(`imageSource`, `classesOfInterest?`, `resize?`): `Promise`\<`Partial`\<`Record`\<[`DeeplabLabel`](../enumerations/DeeplabLabel.md), `number`[]\>\>\>
61+
> **forward**(`imageSource`, `classesOfInterest?`, `resizeToInput?`): `Promise`\<`Partial`\<`Record`\<[`DeeplabLabel`](../enumerations/DeeplabLabel.md), `number`[]\>\>\>
6262
63-
Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:46](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L46)
63+
Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:46](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L46)
6464

6565
Executes the model's forward pass
6666

@@ -78,11 +78,11 @@ a fetchable resource or a Base64-encoded string.
7878

7979
an optional list of DeeplabLabel used to indicate additional arrays of probabilities to output (see section "Running the model"). The default is an empty list.
8080

81-
##### resize?
81+
##### resizeToInput?
8282

8383
`boolean`
8484

85-
an optional boolean to indicate whether the output should be resized to the original image dimensions, or left in the size of the model (see section "Running the model"). The default is `false`.
85+
an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`.
8686

8787
#### Returns
8888

@@ -96,7 +96,7 @@ A dictionary where keys are `DeeplabLabel` and values are arrays of probabilitie
9696

9797
> `protected` **forwardET**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\>
9898
99-
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
99+
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
100100

101101
Runs the model's forward method with the given input tensors.
102102
It returns the output tensors that mimic the structure of output from ExecuTorch.
@@ -125,7 +125,7 @@ Array of output tensors.
125125

126126
> **getInputShape**(`methodName`, `index`): `Promise`\<`number`[]\>
127127
128-
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
128+
Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
129129

130130
Gets the input shape for a given method and index.
131131

@@ -159,7 +159,7 @@ The input shape as an array of numbers.
159159

160160
> **load**(`model`, `onDownloadProgressCallback`): `Promise`\<`void`\>
161161
162-
Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:21](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L21)
162+
Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:21](https://github.com/software-mansion/react-native-executorch/blob/b5006f04ed89e0ab316675cb5fc7fabdaa345c32/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L21)
163163

164164
Loads the model, where `modelSource` is a string that specifies the location of the model binary.
165165
To track the download progress, supply a callback function `onDownloadProgressCallback`.

docs/docs/06-api-reference/functions/useImageSegmentation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
> **useImageSegmentation**(`ImageSegmentationProps`): [`ImageSegmentationType`](../interfaces/ImageSegmentationType.md)
44
5-
Defined in: [packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts:15](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts#L15)
5+
Defined in: [packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts:15](https://github.com/software-mansion/react-native-executorch/blob/9e79b9bf2a34159a71071fbfdaed3ddd9393702f/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts#L15)
66

77
React hook for managing an Image Segmentation model instance.
88

packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ export class ImageSegmentationModule extends BaseModule {
4040
*
4141
* @param imageSource - a fetchable resource or a Base64-encoded string.
4242
* @param classesOfInterest - an optional list of DeeplabLabel used to indicate additional arrays of probabilities to output (see section "Running the model"). The default is an empty list.
43-
* @param resize - an optional boolean to indicate whether the output should be resized to the original image dimensions, or left in the size of the model (see section "Running the model"). The default is `false`.
43+
* @param resizeToInput - an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`.
4444
* @returns A dictionary where keys are `DeeplabLabel` and values are arrays of probabilities for each pixel belonging to the corresponding class.
4545
*/
4646
async forward(
4747
imageSource: string,
4848
classesOfInterest?: DeeplabLabel[],
49-
resize?: boolean
49+
resizeToInput?: boolean
5050
): Promise<Partial<Record<DeeplabLabel, number[]>>> {
5151
if (this.nativeModule == null) {
5252
throw new RnExecutorchError(
@@ -58,7 +58,7 @@ export class ImageSegmentationModule extends BaseModule {
5858
const stringDict = await this.nativeModule.generate(
5959
imageSource,
6060
(classesOfInterest || []).map((label) => DeeplabLabel[label]),
61-
resize || false
61+
resizeToInput ?? true
6262
);
6363

6464
let enumDict: { [key in DeeplabLabel]?: number[] } = {};

packages/react-native-executorch/src/types/imageSegmentation.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ export interface ImageSegmentationType {
7676
* Executes the model's forward pass to perform semantic segmentation on the provided image.
7777
* @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed.
7878
* @param classesOfInterest - An optional array of `DeeplabLabel` enums. If provided, the model will only return segmentation masks for these specific classes.
79-
* @param resize - An optional boolean indicating whether the output segmentation masks should be resized to match the original image dimensions. Defaults to standard model behavior if undefined.
79+
* @param resizeToInput - an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`.
8080
* @returns A Promise that resolves to an object mapping each detected `DeeplabLabel` to its corresponding segmentation mask (represented as a flattened array of numbers).
8181
* @throws {RnExecutorchError} If the model is not loaded or is currently processing another image.
8282
*/
8383
forward: (
8484
imageSource: string,
8585
classesOfInterest?: DeeplabLabel[],
86-
resize?: boolean
86+
resizeToInput?: boolean
8787
) => Promise<Partial<Record<DeeplabLabel, number[]>>>;
8888
}

0 commit comments

Comments (0)