Skip to content

Commit b9e1ded

Browse files
committed
Speed up mask processing
1 parent bbe6e17 commit b9e1ded

File tree

3 files changed

+135
-60
lines changed

3 files changed

+135
-60
lines changed

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 101 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
134134
static_cast<const float *>(scoresTensor.const_data_ptr());
135135
const float *maskData =
136136
static_cast<const float *>(maskTensor.const_data_ptr());
137+
int32_t processed = 0;
137138

138139
for (int i = 0; i < N; ++i) {
139140
float x1 = bboxData[i * 4 + 0];
@@ -157,50 +158,99 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
157158
". Ensure the labelMap covers all model output classes.");
158159
}
159160

160-
// Mask logits are pre-computed — just sigmoid + threshold
161+
// Scale bbox to original image coordinates
162+
float origX1 = x1 * widthRatio;
163+
float origY1 = y1 * heightRatio;
164+
float origX2 = x2 * widthRatio;
165+
float origY2 = y2 * heightRatio;
166+
167+
int bboxW = static_cast<int>(std::round(origX2 - origX1));
168+
int bboxH = static_cast<int>(std::round(origY2 - origY1));
169+
170+
if (bboxW <= 0 || bboxH <= 0)
171+
continue;
172+
173+
// Wrap logits in cv::Mat for vectorized operations
161174
const float *logits = maskData + (i * maskH * maskW);
162-
std::vector<uint8_t> binaryMask(maskH * maskW);
163-
for (int j = 0; j < maskH * maskW; j++) {
164-
float v = 1.0f / (1.0f + std::exp(-logits[j]));
165-
binaryMask[j] = (v > 0.5f) ? 1 : 0;
166-
}
175+
cv::Mat logitsMat(maskH, maskW, CV_32FC1, const_cast<float *>(logits));
167176

168-
x1 *= widthRatio;
169-
y1 *= heightRatio;
170-
x2 *= widthRatio;
171-
y2 *= heightRatio;
177+
// Float bounds in low-res mask space
178+
float mx1F = x1 * maskW / modelInputSize.width;
179+
float my1F = y1 * maskH / modelInputSize.height;
180+
float mx2F = x2 * maskW / modelInputSize.width;
181+
float my2F = y2 * maskH / modelInputSize.height;
172182

173-
int finalMaskWidth = maskW;
174-
int finalMaskHeight = maskH;
175-
std::vector<uint8_t> finalMask = binaryMask;
183+
// Exact integer bounds (bbox region in mask coordinates)
184+
int mx1 = std::max(0, static_cast<int>(std::floor(mx1F)));
185+
int my1 = std::max(0, static_cast<int>(std::floor(my1F)));
186+
int mx2 = std::min(maskW, static_cast<int>(std::ceil(mx2F)));
187+
int my2 = std::min(maskH, static_cast<int>(std::ceil(my2F)));
188+
189+
if (mx2 <= mx1 || my2 <= my1)
190+
continue;
191+
192+
cv::Mat finalBinaryMat;
193+
int finalMaskWidth = bboxW;
194+
int finalMaskHeight = bboxH;
176195

177196
if (returnMaskAtOriginalResolution) {
178-
cv::Mat maskMat(maskH, maskW, CV_8UC1, binaryMask.data());
179-
cv::Mat resizedMaskMat;
180-
cv::resize(maskMat, resizedMaskMat, originalSize, 0, 0,
181-
cv::INTER_NEAREST);
182-
finalMaskWidth = originalSize.width;
183-
finalMaskHeight = originalSize.height;
184-
for (int y = 0; y < finalMaskHeight; y++)
185-
for (int x = 0; x < finalMaskWidth; x++)
186-
if (x < x1 || x > x2 || y < y1 || y > y2)
187-
resizedMaskMat.data[y * finalMaskWidth + x] = 0;
188-
finalMask.assign(resizedMaskMat.data,
189-
resizedMaskMat.data + resizedMaskMat.total());
197+
// 1px padding for warpAffine interpolation (prevents edge artifacts)
198+
int pmx1 = std::max(0, mx1 - 1);
199+
int pmy1 = std::max(0, my1 - 1);
200+
int pmx2 = std::min(maskW, mx2 + 1);
201+
int pmy2 = std::min(maskH, my2 + 1);
202+
203+
cv::Mat croppedLogits =
204+
logitsMat(cv::Rect(pmx1, pmy1, pmx2 - pmx1, pmy2 - pmy1));
205+
cv::Mat probMat;
206+
cv::exp(-croppedLogits, probMat);
207+
probMat = 255.0f / (1.0f + probMat);
208+
probMat.convertTo(probMat, CV_8UC1);
209+
210+
// Affine matrix mapping padded crop -> bbox in original image space.
211+
// Padding pixels fall outside the output and are clipped naturally.
212+
float maskToOrigX = static_cast<float>(originalSize.width) / maskW;
213+
float maskToOrigY = static_cast<float>(originalSize.height) / maskH;
214+
215+
cv::Mat M = (cv::Mat_<float>(2, 3) << maskToOrigX, 0,
216+
(pmx1 * maskToOrigX - origX1), 0, maskToOrigY,
217+
(pmy1 * maskToOrigY - origY1));
218+
219+
cv::Mat warpedMat;
220+
cv::warpAffine(probMat, warpedMat, M, cv::Size(bboxW, bboxH),
221+
cv::INTER_LINEAR);
222+
223+
cv::threshold(warpedMat, finalBinaryMat, 127, 1, cv::THRESH_BINARY);
224+
} else {
225+
// No padding needed — no interpolation, just threshold
226+
cv::Mat croppedLogits =
227+
logitsMat(cv::Rect(mx1, my1, mx2 - mx1, my2 - my1));
228+
cv::Mat probMat;
229+
cv::exp(-croppedLogits, probMat);
230+
probMat = 255.0f / (1.0f + probMat);
231+
probMat.convertTo(probMat, CV_8UC1);
232+
233+
cv::threshold(probMat, finalBinaryMat, 127, 1, cv::THRESH_BINARY);
234+
finalMaskWidth = finalBinaryMat.cols;
235+
finalMaskHeight = finalBinaryMat.rows;
190236
}
191237

238+
std::vector<uint8_t> finalMask(
239+
finalBinaryMat.data, finalBinaryMat.data + finalBinaryMat.total());
240+
192241
types::InstanceMask instance;
193-
instance.x1 = x1;
194-
instance.y1 = y1;
195-
instance.x2 = x2;
196-
instance.y2 = y2;
242+
instance.x1 = origX1;
243+
instance.y1 = origY1;
244+
instance.x2 = origX2;
245+
instance.y2 = origY2;
197246
instance.mask = std::move(finalMask);
198247
instance.maskWidth = finalMaskWidth;
199248
instance.maskHeight = finalMaskHeight;
200249
instance.label = labelNames_[labelIdx];
201250
instance.score = score;
202251
instance.instanceId = i;
203252
instances.push_back(std::move(instance));
253+
++processed;
204254
}
205255

206256
// Finalize: NMS + limit + renumber
@@ -222,18 +272,27 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
222272
std::vector<types::InstanceMask> BaseInstanceSegmentation::generate(
223273
std::string imageSource, double confidenceThreshold, double iouThreshold,
224274
int maxInstances, std::vector<int32_t> classIndices,
225-
bool returnMaskAtOriginalResolution, int32_t inputSize) {
275+
bool returnMaskAtOriginalResolution, std::string methodName) {
226276

227-
std::string methodName = getMethodName(inputSize);
228-
if (currentlyLoadedMethod_ == "") {
229-
currentlyLoadedMethod_ = methodName;
230-
} else {
231-
module_->unload_method(currentlyLoadedMethod_);
277+
if (methodName.empty()) {
278+
throw RnExecutorchError(
279+
RnExecutorchErrorCode::InvalidConfig,
280+
"methodName cannot be empty. Use 'forward' for single-method models "
281+
"or 'forward_{inputSize}' for multi-method models.");
282+
}
283+
284+
if (currentlyLoadedMethod_ != methodName) {
285+
if (!currentlyLoadedMethod_.empty()) {
286+
module_->unload_method(currentlyLoadedMethod_);
287+
}
232288
currentlyLoadedMethod_ = methodName;
289+
module_->load_method(methodName);
233290
}
234-
module_->load_method(methodName);
291+
292+
auto inputShapes = getAllInputShapes(methodName);
293+
std::vector<int32_t> inputShape = inputShapes[0];
294+
int32_t inputSize = inputShape[inputShape.size() - 1];
235295
cv::Size modelInputSize(inputSize, inputSize);
236-
std::vector<int32_t> inputShape = {1, 3, inputSize, inputSize};
237296

238297
auto [inputTensor, originalSize] = image_processing::readImageToTensor(
239298
imageSource, inputShape, false, normMean_, normStd_);
@@ -247,9 +306,11 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::generate(
247306
methodName + "' is valid.");
248307
}
249308

250-
return postprocess(forwardResult.get(), originalSize, modelInputSize,
251-
confidenceThreshold, iouThreshold, maxInstances,
252-
classIndices, returnMaskAtOriginalResolution);
309+
auto result = postprocess(forwardResult.get(), originalSize, modelInputSize,
310+
confidenceThreshold, iouThreshold, maxInstances,
311+
classIndices, returnMaskAtOriginalResolution);
312+
313+
return result;
253314
}
254315

255316
} // namespace rnexecutorch::models::instance_segmentation

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class BaseInstanceSegmentation : public BaseModel {
2727
generate(std::string imageSource, double confidenceThreshold,
2828
double iouThreshold, int maxInstances,
2929
std::vector<int32_t> classIndices,
30-
bool returnMaskAtOriginalResolution, int32_t inputSize);
30+
bool returnMaskAtOriginalResolution, std::string methodName);
3131

3232
private:
3333
std::vector<types::InstanceMask>
@@ -44,10 +44,6 @@ class BaseInstanceSegmentation : public BaseModel {
4444
nonMaxSuppression(std::vector<types::InstanceMask> instances,
4545
double iouThreshold);
4646

47-
std::string getMethodName(int32_t inputSize) const {
48-
return "forward_" + std::to_string(inputSize);
49-
}
50-
5147
// Member variables
5248
std::optional<cv::Scalar> normMean_;
5349
std::optional<cv::Scalar> normStd_;

packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,25 @@ const YOLO_SEG_CONFIG: InstanceSegmentationConfig<typeof CocoLabel> = {
1919
defaultConfidenceThreshold: 0.5,
2020
defaultIouThreshold: 0.5,
2121
postprocessorConfig: {
22-
applyNMS: false,
22+
applyNMS: true,
2323
},
2424
};
2525

2626
/**
27-
* Builds an ordered label name array from a label map, indexed by class ID.
28-
* Index i corresponds to class index i produced by the model.
27+
* Builds an ordered label name array from a label map, indexed by model output
28+
* class ID. Subtracts the minimum label value so that 0-indexed model outputs
29+
* map correctly (e.g. COCO labels start at 1, but models output 0 for the
30+
* first class).
2931
*/
3032
function buildLabelNames(labelMap: LabelEnum): string[] {
31-
const allLabelNames: string[] = [];
33+
const entries: [string, number][] = [];
3234
for (const [name, value] of Object.entries(labelMap)) {
33-
if (typeof value === 'number') allLabelNames[value] = name;
35+
if (typeof value === 'number') entries.push([name, value]);
36+
}
37+
const minValue = Math.min(...entries.map(([, v]) => v));
38+
const allLabelNames: string[] = [];
39+
for (const [name, value] of entries) {
40+
allLabelNames[value - minValue] = name;
3441
}
3542
for (let i = 0; i < allLabelNames.length; i++) {
3643
if (allLabelNames[i] == null) allLabelNames[i] = '';
@@ -142,12 +149,19 @@ export class InstanceSegmentationModule<
142149
);
143150
}
144151

152+
const labelNames = buildLabelNames(modelConfig.labelMap);
153+
console.log(
154+
'[InstanceSegmentation] Label names:',
155+
labelNames.length,
156+
'classes'
157+
);
158+
145159
const nativeModule = global.loadInstanceSegmentation(
146160
path,
147161
modelConfig.preprocessorConfig?.normMean || [],
148162
modelConfig.preprocessorConfig?.normStd || [],
149163
modelConfig.postprocessorConfig?.applyNMS ?? true,
150-
buildLabelNames(modelConfig.labelMap)
164+
labelNames
151165
);
152166

153167
return new InstanceSegmentationModule<InstanceModelNameOf<C>>(
@@ -258,15 +272,9 @@ export class InstanceSegmentationModule<
258272

259273
const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize;
260274

261-
if (inputSize === undefined) {
262-
throw new RnExecutorchError(
263-
RnExecutorchErrorCode.InvalidArgument,
264-
'inputSize must be specified in options when the model config does not define availableInputSizes'
265-
);
266-
}
267-
268275
if (
269276
this.modelConfig.availableInputSizes &&
277+
inputSize !== undefined &&
270278
!this.modelConfig.availableInputSizes.includes(
271279
inputSize as (typeof this.modelConfig.availableInputSizes)[number]
272280
)
@@ -277,6 +285,9 @@ export class InstanceSegmentationModule<
277285
);
278286
}
279287

288+
const methodName =
289+
inputSize !== undefined ? `forward_${inputSize}` : 'forward';
290+
280291
const classIndices = options?.classesOfInterest
281292
? options.classesOfInterest.map((label) => {
282293
const labelStr = String(label);
@@ -285,14 +296,21 @@ export class InstanceSegmentationModule<
285296
})
286297
: [];
287298

288-
return await this.nativeModule.generate(
299+
const startTime = performance.now();
300+
const result = await this.nativeModule.generate(
289301
imageSource,
290302
confidenceThreshold,
291303
iouThreshold,
292304
maxInstances,
293305
classIndices,
294306
returnMaskAtOriginalResolution,
295-
inputSize
307+
methodName
308+
);
309+
const inferenceTime = performance.now() - startTime;
310+
console.log(
311+
`[InstanceSegmentation] Inference: ${inferenceTime.toFixed(2)}ms | Method: ${methodName} | Detected: ${result.length} instances`
296312
);
313+
314+
return result;
297315
}
298316
}

0 commit comments

Comments (0)