Skip to content

Commit 91eea3f

Browse files
committed
Return class index, tweak mask drawing performance
1 parent 9722e7c commit 91eea3f

File tree

10 files changed

+275
-114
lines changed

10 files changed

+275
-114
lines changed

apps/computer-vision/app/instance_segmentation/index.tsx

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import Spinner from '../../components/Spinner';
22
import { BottomBar } from '../../components/BottomBar';
33
import { getImage } from '../../utils';
4-
import {
5-
useInstanceSegmentation,
6-
YOLO26X_SEG,
7-
} from 'react-native-executorch';
4+
import { useInstanceSegmentation, YOLO26N_SEG } from 'react-native-executorch';
85
import {
96
View,
107
StyleSheet,
@@ -15,7 +12,10 @@ import {
1512
import React, { useContext, useEffect, useState } from 'react';
1613
import { GeneratingContext } from '../../context';
1714
import ScreenWrapper from '../../ScreenWrapper';
18-
import ImageWithMasks from '../../components/ImageWithMasks';
15+
import ImageWithMasks, {
16+
buildDisplayInstances,
17+
DisplayInstance,
18+
} from '../../components/ImageWithMasks';
1919

2020
const AVAILABLE_INPUT_SIZES = [384, 512, 640];
2121

@@ -24,12 +24,12 @@ export default function InstanceSegmentationScreen() {
2424

2525
const { isReady, isGenerating, downloadProgress, forward, error } =
2626
useInstanceSegmentation({
27-
model: YOLO26X_SEG,
27+
model: YOLO26N_SEG,
2828
});
2929

3030
const [imageUri, setImageUri] = useState('');
3131
const [imageSize, setImageSize] = useState({ width: 0, height: 0 });
32-
const [instances, setInstances] = useState<any[]>([]);
32+
const [instances, setInstances] = useState<DisplayInstance[]>([]);
3333
const [selectedInputSize, setSelectedInputSize] = useState(
3434
AVAILABLE_INPUT_SIZES[0]
3535
);
@@ -61,7 +61,10 @@ export default function InstanceSegmentationScreen() {
6161
inputSize: selectedInputSize,
6262
});
6363

64-
setInstances(output);
64+
// Convert raw masks → small Skia images immediately.
65+
// Raw Uint8Array mask buffers (backed by native OwningArrayBuffer)
66+
// go out of scope here and become eligible for GC right away.
67+
setInstances(buildDisplayInstances(output));
6568
} catch (e) {
6669
console.error(e);
6770
}

apps/computer-vision/components/ImageWithMasks.tsx

Lines changed: 98 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import React, { useState } from 'react';
1+
import React, { useEffect, useState } from 'react';
22
import { Image, StyleSheet, View, Text } from 'react-native';
33
import {
44
Canvas,
@@ -10,8 +10,6 @@ import {
1010
Rect,
1111
Group,
1212
} from '@shopify/react-native-skia';
13-
import type { SegmentedInstance } from 'react-native-executorch';
14-
import type { LabelEnum } from 'react-native-executorch';
1513

1614
const INSTANCE_COLORS = [
1715
[255, 87, 51, 180],
@@ -26,39 +24,101 @@ const INSTANCE_COLORS = [
2624
[131, 51, 255, 180],
2725
];
2826

29-
interface Props {
30-
imageUri: string;
31-
instances: SegmentedInstance<LabelEnum>[];
32-
imageWidth: number;
33-
imageHeight: number;
27+
const MAX_MASK_DIM = 256;
28+
29+
/** Display-only data — no raw mask buffers. */
30+
export interface DisplayInstance {
31+
bbox: { x1: number; y1: number; x2: number; y2: number };
32+
label: string;
33+
score: number;
34+
maskImage: SkImage;
35+
}
36+
37+
/**
38+
* Convert raw segmentation output into lightweight display instances.
39+
* Call this eagerly (in the forward callback) so raw Uint8Array masks
40+
* can be garbage-collected immediately.
41+
*/
42+
export function buildDisplayInstances(
43+
rawInstances: {
44+
bbox: { x1: number; y1: number; x2: number; y2: number };
45+
mask: Uint8Array;
46+
maskWidth: number;
47+
maskHeight: number;
48+
label: string | number;
49+
score: number;
50+
}[]
51+
): DisplayInstance[] {
52+
return rawInstances
53+
.map((inst, i) => {
54+
const color = INSTANCE_COLORS[i % INSTANCE_COLORS.length];
55+
const img = createMaskImage(
56+
inst.mask,
57+
inst.maskWidth,
58+
inst.maskHeight,
59+
color
60+
);
61+
if (!img) return null;
62+
return {
63+
bbox: inst.bbox,
64+
label: String(inst.label),
65+
score: inst.score,
66+
maskImage: img,
67+
};
68+
})
69+
.filter((d): d is DisplayInstance => d !== null);
3470
}
3571

3672
function createMaskImage(
3773
mask: Uint8Array,
38-
width: number,
39-
height: number,
74+
srcW: number,
75+
srcH: number,
4076
color: number[]
4177
): SkImage | null {
42-
const pixels = new Uint8Array(width * height * 4);
43-
for (let j = 0; j < mask.length; j++) {
44-
if (mask[j] > 0) {
45-
pixels[j * 4] = color[0];
46-
pixels[j * 4 + 1] = color[1];
47-
pixels[j * 4 + 2] = color[2];
48-
pixels[j * 4 + 3] = color[3];
78+
const downscale = Math.min(1, MAX_MASK_DIM / Math.max(srcW, srcH));
79+
const dstW = Math.max(1, Math.round(srcW * downscale));
80+
const dstH = Math.max(1, Math.round(srcH * downscale));
81+
82+
const pixels = new Uint8Array(dstW * dstH * 4);
83+
const r = color[0],
84+
g = color[1],
85+
b = color[2],
86+
a = color[3];
87+
88+
for (let dy = 0; dy < dstH; dy++) {
89+
const sy = Math.min(Math.floor(dy / downscale), srcH - 1);
90+
for (let dx = 0; dx < dstW; dx++) {
91+
const sx = Math.min(Math.floor(dx / downscale), srcW - 1);
92+
if (mask[sy * srcW + sx] > 0) {
93+
const idx = (dy * dstW + dx) * 4;
94+
pixels[idx] = r;
95+
pixels[idx + 1] = g;
96+
pixels[idx + 2] = b;
97+
pixels[idx + 3] = a;
98+
}
4999
}
50100
}
101+
51102
const data = Skia.Data.fromBytes(pixels);
52-
return Skia.Image.MakeImage(
103+
const image = Skia.Image.MakeImage(
53104
{
54-
width,
55-
height,
105+
width: dstW,
106+
height: dstH,
56107
alphaType: AlphaType.Premul,
57108
colorType: ColorType.RGBA_8888,
58109
},
59110
data,
60-
width * 4
111+
dstW * 4
61112
);
113+
data.dispose();
114+
return image;
115+
}
116+
117+
interface Props {
118+
imageUri: string;
119+
instances: DisplayInstance[];
120+
imageWidth: number;
121+
imageHeight: number;
62122
}
63123

64124
export default function ImageWithMasks({
@@ -75,17 +135,12 @@ export default function ImageWithMasks({
75135
const offsetX = (layout.width - imageWidth * scale) / 2;
76136
const offsetY = (layout.height - imageHeight * scale) / 2;
77137

78-
const maskImages = instances
79-
.map((instance, i) => {
80-
const color = INSTANCE_COLORS[i % INSTANCE_COLORS.length];
81-
return createMaskImage(
82-
instance.mask,
83-
instance.maskWidth,
84-
instance.maskHeight,
85-
color
86-
);
87-
})
88-
.filter((img): img is SkImage => img !== null);
138+
// Dispose Skia images when instances are replaced or on unmount
139+
useEffect(() => {
140+
return () => {
141+
instances.forEach((inst) => inst.maskImage.dispose());
142+
};
143+
}, [instances]);
89144

90145
return (
91146
<View
@@ -108,16 +163,15 @@ export default function ImageWithMasks({
108163
{instances.length > 0 && (
109164
<View style={styles.overlay}>
110165
<Canvas style={styles.canvas}>
111-
{maskImages.map((maskImg, idx) => {
112-
const inst = instances[idx];
166+
{instances.map((inst, idx) => {
113167
const mx = inst.bbox.x1 * scale + offsetX;
114168
const my = inst.bbox.y1 * scale + offsetY;
115169
const mw = (inst.bbox.x2 - inst.bbox.x1) * scale;
116170
const mh = (inst.bbox.y2 - inst.bbox.y1) * scale;
117171
return (
118172
<SkiaImage
119173
key={`mask-${idx}`}
120-
image={maskImg}
174+
image={inst.maskImage}
121175
fit="fill"
122176
x={mx}
123177
y={my}
@@ -127,12 +181,12 @@ export default function ImageWithMasks({
127181
);
128182
})}
129183

130-
{instances.map((instance, idx) => {
184+
{instances.map((inst, idx) => {
131185
const color = INSTANCE_COLORS[idx % INSTANCE_COLORS.length];
132-
const bboxX = instance.bbox.x1 * scale + offsetX;
133-
const bboxY = instance.bbox.y1 * scale + offsetY;
134-
const bboxW = (instance.bbox.x2 - instance.bbox.x1) * scale;
135-
const bboxH = (instance.bbox.y2 - instance.bbox.y1) * scale;
186+
const bboxX = inst.bbox.x1 * scale + offsetX;
187+
const bboxY = inst.bbox.y1 * scale + offsetY;
188+
const bboxW = (inst.bbox.x2 - inst.bbox.x1) * scale;
189+
const bboxH = (inst.bbox.y2 - inst.bbox.y1) * scale;
136190

137191
return (
138192
<Group key={`bbox-${idx}`}>
@@ -150,10 +204,10 @@ export default function ImageWithMasks({
150204
})}
151205
</Canvas>
152206

153-
{instances.map((instance, idx) => {
207+
{instances.map((inst, idx) => {
154208
const color = INSTANCE_COLORS[idx % INSTANCE_COLORS.length];
155-
const bboxX = instance.bbox.x1 * scale + offsetX;
156-
const bboxY = instance.bbox.y1 * scale + offsetY;
209+
const bboxX = inst.bbox.x1 * scale + offsetX;
210+
const bboxY = inst.bbox.y1 * scale + offsetY;
157211

158212
return (
159213
<View
@@ -168,8 +222,7 @@ export default function ImageWithMasks({
168222
]}
169223
>
170224
<Text style={styles.labelText}>
171-
{String(instance.label) || 'Unknown'}{' '}
172-
{(instance.score * 100).toFixed(0)}%
225+
{inst.label || 'Unknown'} {(inst.score * 100).toFixed(0)}%
173226
</Text>
174227
</View>
175228
);

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,7 @@ inline jsi::Value getJsiValue(
480480
instance.setProperty(runtime, "maskWidth", instances[i].maskWidth);
481481
instance.setProperty(runtime, "maskHeight", instances[i].maskHeight);
482482

483-
instance.setProperty(
484-
runtime, "label",
485-
jsi::String::createFromUtf8(runtime, instances[i].label));
483+
instance.setProperty(runtime, "classIndex", instances[i].classIndex);
486484

487485
instance.setProperty(runtime, "score", instances[i].score);
488486
instance.setProperty(runtime, "instanceId", instances[i].instanceId);

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@ namespace rnexecutorch::models::instance_segmentation {
1010
BaseInstanceSegmentation::BaseInstanceSegmentation(
1111
const std::string &modelSource, std::vector<float> normMean,
1212
std::vector<float> normStd, bool applyNMS,
13-
std::vector<std::string> labelNames,
1413
std::shared_ptr<react::CallInvoker> callInvoker)
15-
: BaseModel(modelSource, callInvoker), applyNMS_(applyNMS),
16-
labelNames_(std::move(labelNames)) {
14+
: BaseModel(modelSource, callInvoker), applyNMS_(applyNMS) {
1715
avalivableMethods_ = *module_->method_names();
1816
if (normMean.size() == 3) {
1917
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
@@ -66,7 +64,7 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::nonMaxSuppression(
6664
continue;
6765
}
6866

69-
if (instances[i].label == instances[j].label) {
67+
if (instances[i].classIndex == instances[j].classIndex) {
7068
float iou = intersectionOverUnion(instances[i], instances[j]);
7169
if (iou > iouThreshold) {
7270
suppressed[j] = true;
@@ -150,14 +148,6 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
150148
labelIdx)) == allowedClasses.end())
151149
continue;
152150

153-
if (labelIdx >= labelNames_.size()) {
154-
throw RnExecutorchError(
155-
RnExecutorchErrorCode::InvalidConfig,
156-
"Model output class index " + std::to_string(labelIdx) +
157-
" exceeds labelNames size " + std::to_string(labelNames_.size()) +
158-
". Ensure the labelMap covers all model output classes.");
159-
}
160-
161151
// Scale bbox to original image coordinates
162152
float origX1 = x1 * widthRatio;
163153
float origY1 = y1 * heightRatio;
@@ -246,15 +236,15 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
246236
instance.mask = std::move(finalMask);
247237
instance.maskWidth = finalMaskWidth;
248238
instance.maskHeight = finalMaskHeight;
249-
instance.label = labelNames_[labelIdx];
239+
instance.classIndex = static_cast<int32_t>(labelIdx);
250240
instance.score = score;
251241
instance.instanceId = i;
252242
instances.push_back(std::move(instance));
253243
++processed;
254244
}
255245

256246
// Finalize: NMS + limit + renumber
257-
if (applyNMS_ && iouThreshold < 0.45) {
247+
if (applyNMS_) {
258248
instances = nonMaxSuppression(instances, iouThreshold);
259249
}
260250

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class BaseInstanceSegmentation : public BaseModel {
2020
BaseInstanceSegmentation(const std::string &modelSource,
2121
std::vector<float> normMean,
2222
std::vector<float> normStd, bool applyNMS,
23-
std::vector<std::string> labelNames,
2423
std::shared_ptr<react::CallInvoker> callInvoker);
2524

2625
[[nodiscard("Registered non-void function")]] std::vector<types::InstanceMask>
@@ -48,7 +47,6 @@ class BaseInstanceSegmentation : public BaseModel {
4847
std::optional<cv::Scalar> normMean_;
4948
std::optional<cv::Scalar> normStd_;
5049
bool applyNMS_;
51-
std::vector<std::string> labelNames_;
5250
cv::Size modelImageSize{0, 0};
5351
std::unordered_set<std::string> avalivableMethods_;
5452
std::string currentlyLoadedMethod_;
@@ -57,6 +55,5 @@ class BaseInstanceSegmentation : public BaseModel {
5755

5856
REGISTER_CONSTRUCTOR(models::instance_segmentation::BaseInstanceSegmentation,
5957
std::string, std::vector<float>, std::vector<float>, bool,
60-
std::vector<std::string>,
6158
std::shared_ptr<react::CallInvoker>);
6259
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <string>
3+
#include <cstdint>
44
#include <vector>
55

66
namespace rnexecutorch::models::instance_segmentation::types {
@@ -19,7 +19,7 @@ struct InstanceMask {
1919
std::vector<uint8_t> mask; ///< Binary mask (0 or 1) for the instance
2020
int maskWidth; ///< Width of the mask array
2121
int maskHeight; ///< Height of the mask array
22-
std::string label; ///< Class label name
22+
int32_t classIndex; ///< Model output class index
2323
float score; ///< Confidence score [0, 1]
2424
int instanceId; ///< Unique identifier for this instance
2525
};

0 commit comments

Comments
 (0)