Skip to content

Commit b9794ea

Browse files
committed
Break up postprocessing helpers in Instance Segmentation
1 parent f87df42 commit b9794ea

File tree

3 files changed

+154
-99
lines changed

3 files changed

+154
-99
lines changed

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 116 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -30,120 +30,148 @@ BaseInstanceSegmentation::BaseInstanceSegmentation(
3030
}
3131
}
3232

33-
cv::Mat BaseInstanceSegmentation::processMaskFromLogits(
34-
const cv::Mat &logitsMat, float x1, float y1, float x2, float y2,
35-
cv::Size modelInputSize, cv::Size originalSize, int32_t maskW,
36-
int32_t maskH, int32_t bboxW, int32_t bboxH, float origX1, float origY1,
37-
bool warpToOriginal, int32_t &outWidth, int32_t &outHeight) {
33+
std::tuple<utils::computer_vision::BBox, float, int32_t>
34+
BaseInstanceSegmentation::extractDetectionData(const float *bboxData,
35+
const float *scoresData,
36+
int32_t index) {
37+
utils::computer_vision::BBox bbox{
38+
bboxData[index * 4], bboxData[index * 4 + 1], bboxData[index * 4 + 2],
39+
bboxData[index * 4 + 3]};
40+
float score = scoresData[index * 2];
41+
int32_t label = static_cast<int32_t>(scoresData[index * 2 + 1]);
42+
43+
return {bbox, score, label};
44+
}
3845

39-
float mx1F = x1 * maskW / modelInputSize.width;
40-
float my1F = y1 * maskH / modelInputSize.height;
41-
float mx2F = x2 * maskW / modelInputSize.width;
42-
float my2F = y2 * maskH / modelInputSize.height;
46+
cv::Rect BaseInstanceSegmentation::computeMaskCropRect(
47+
const utils::computer_vision::BBox &bboxModel, cv::Size modelInputSize,
48+
cv::Size maskSize) {
49+
50+
float mx1F = bboxModel.x1 * maskSize.width / modelInputSize.width;
51+
float my1F = bboxModel.y1 * maskSize.height / modelInputSize.height;
52+
float mx2F = bboxModel.x2 * maskSize.width / modelInputSize.width;
53+
float my2F = bboxModel.y2 * maskSize.height / modelInputSize.height;
4354

4455
int32_t mx1 = std::max(0, static_cast<int32_t>(std::floor(mx1F)));
4556
int32_t my1 = std::max(0, static_cast<int32_t>(std::floor(my1F)));
46-
int32_t mx2 = std::min(maskW, static_cast<int32_t>(std::ceil(mx2F)));
47-
int32_t my2 = std::min(maskH, static_cast<int32_t>(std::ceil(my2F)));
57+
int32_t mx2 = std::min(maskSize.width, static_cast<int32_t>(std::ceil(mx2F)));
58+
int32_t my2 =
59+
std::min(maskSize.height, static_cast<int32_t>(std::ceil(my2F)));
60+
61+
return cv::Rect(mx1, my1, mx2 - mx1, my2 - my1);
62+
}
63+
64+
cv::Rect BaseInstanceSegmentation::addPaddingToRect(const cv::Rect &rect,
65+
cv::Size maskSize) {
66+
int32_t x1 = std::max(0, rect.x - 1);
67+
int32_t y1 = std::max(0, rect.y - 1);
68+
int32_t x2 = std::min(maskSize.width, rect.x + rect.width + 1);
69+
int32_t y2 = std::min(maskSize.height, rect.y + rect.height + 1);
70+
71+
return cv::Rect(x1, y1, x2 - x1, y2 - y1);
72+
}
4873

49-
cv::Mat finalBinaryMat;
50-
outWidth = bboxW;
51-
outHeight = bboxH;
74+
cv::Mat BaseInstanceSegmentation::applySigmoid(const cv::Mat &logits) {
75+
cv::Mat probMat;
76+
cv::exp(-logits, probMat);
77+
probMat = 255.0f / (1.0f + probMat);
78+
probMat.convertTo(probMat, CV_8UC1);
79+
return probMat;
80+
}
81+
82+
cv::Mat BaseInstanceSegmentation::warpToOriginalResolution(
83+
const cv::Mat &probMat, const cv::Rect &maskRect, cv::Size originalSize,
84+
cv::Size maskSize, const utils::computer_vision::BBox &bboxOriginal) {
85+
86+
float scaleX = static_cast<float>(originalSize.width) / maskSize.width;
87+
float scaleY = static_cast<float>(originalSize.height) / maskSize.height;
88+
89+
cv::Mat M = (cv::Mat_<float>(2, 3) << scaleX, 0,
90+
(maskRect.x * scaleX - bboxOriginal.x1), 0, scaleY,
91+
(maskRect.y * scaleY - bboxOriginal.y1));
92+
93+
cv::Size bboxSize(static_cast<int32_t>(std::round(bboxOriginal.width())),
94+
static_cast<int32_t>(std::round(bboxOriginal.height())));
95+
96+
cv::Mat warped;
97+
cv::warpAffine(probMat, warped, M, bboxSize, cv::INTER_LINEAR);
98+
return warped;
99+
}
100+
101+
cv::Mat BaseInstanceSegmentation::thresholdToBinary(const cv::Mat &probMat) {
102+
cv::Mat binary;
103+
cv::threshold(probMat, binary, 127, 1, cv::THRESH_BINARY);
104+
return binary;
105+
}
106+
107+
cv::Mat BaseInstanceSegmentation::processMaskFromLogits(
108+
const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel,
109+
const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize,
110+
cv::Size originalSize, cv::Size maskSize, bool warpToOriginal,
111+
cv::Size &outSize) {
112+
113+
cv::Rect cropRect = computeMaskCropRect(bboxModel, modelInputSize, maskSize);
114+
115+
if (warpToOriginal) {
116+
cropRect = addPaddingToRect(cropRect, maskSize);
117+
}
118+
119+
cv::Mat cropped = logitsMat(cropRect);
120+
cv::Mat probMat = applySigmoid(cropped);
52121

53122
if (warpToOriginal) {
54-
int32_t pmx1 = std::max(0, mx1 - 1);
55-
int32_t pmy1 = std::max(0, my1 - 1);
56-
int32_t pmx2 = std::min(maskW, mx2 + 1);
57-
int32_t pmy2 = std::min(maskH, my2 + 1);
58-
59-
cv::Mat croppedLogits =
60-
logitsMat(cv::Rect(pmx1, pmy1, pmx2 - pmx1, pmy2 - pmy1));
61-
cv::Mat probMat;
62-
cv::exp(-croppedLogits, probMat);
63-
probMat = 255.0f / (1.0f + probMat);
64-
probMat.convertTo(probMat, CV_8UC1);
65-
66-
float maskToOrigX = static_cast<float>(originalSize.width) / maskW;
67-
float maskToOrigY = static_cast<float>(originalSize.height) / maskH;
68-
69-
cv::Mat M =
70-
(cv::Mat_<float>(2, 3) << maskToOrigX, 0, (pmx1 * maskToOrigX - origX1),
71-
0, maskToOrigY, (pmy1 * maskToOrigY - origY1));
72-
73-
cv::Mat warpedMat;
74-
cv::warpAffine(probMat, warpedMat, M, cv::Size(bboxW, bboxH),
75-
cv::INTER_LINEAR);
76-
77-
cv::threshold(warpedMat, finalBinaryMat, 127, 1, cv::THRESH_BINARY);
78-
} else {
79-
cv::Mat croppedLogits = logitsMat(cv::Rect(mx1, my1, mx2 - mx1, my2 - my1));
80-
cv::Mat probMat;
81-
cv::exp(-croppedLogits, probMat);
82-
probMat = 255.0f / (1.0f + probMat);
83-
probMat.convertTo(probMat, CV_8UC1);
84-
85-
cv::threshold(probMat, finalBinaryMat, 127, 1, cv::THRESH_BINARY);
86-
outWidth = finalBinaryMat.cols;
87-
outHeight = finalBinaryMat.rows;
123+
probMat = warpToOriginalResolution(probMat, cropRect, originalSize,
124+
maskSize, bboxOriginal);
88125
}
126+
cv::Mat binaryMask = thresholdToBinary(probMat);
127+
outSize = cv::Size(binaryMask.cols, binaryMask.rows);
89128

90-
return finalBinaryMat;
129+
return binaryMask;
91130
}
92131

93132
std::optional<types::Instance> BaseInstanceSegmentation::processDetection(
94133
int32_t detectionIndex, const float *bboxData, const float *scoresData,
95-
const float *maskData, int32_t maskH, int32_t maskW,
96-
cv::Size modelInputSize, cv::Size originalSize, float widthRatio,
97-
float heightRatio, double confidenceThreshold,
134+
const cv::Mat &logitsMat, cv::Size modelInputSize, cv::Size originalSize,
135+
float widthRatio, float heightRatio, double confidenceThreshold,
98136
const std::set<int32_t> &allowedClasses,
99137
bool returnMaskAtOriginalResolution) {
100138

101-
int32_t i = detectionIndex;
102-
103-
float x1 = bboxData[i * 4 + 0];
104-
float y1 = bboxData[i * 4 + 1];
105-
float x2 = bboxData[i * 4 + 2];
106-
float y2 = bboxData[i * 4 + 3];
107-
float score = scoresData[i * 2 + 0];
108-
auto labelIdx = static_cast<std::size_t>(scoresData[i * 2 + 1]);
139+
// Extract detection data
140+
auto [bboxModel, score, labelIdx] =
141+
extractDetectionData(bboxData, scoresData, detectionIndex);
109142

143+
// Filter by confidence
110144
if (score < confidenceThreshold) {
111145
return std::nullopt;
112146
}
113-
if (!allowedClasses.empty() && allowedClasses.find(static_cast<int32_t>(
114-
labelIdx)) == allowedClasses.end()) {
147+
148+
// Filter by class
149+
if (!allowedClasses.empty() &&
150+
allowedClasses.find(labelIdx) == allowedClasses.end()) {
115151
return std::nullopt;
116152
}
117153

118154
// Scale bbox to original image coordinates
119-
float origX1 = x1 * widthRatio;
120-
float origY1 = y1 * heightRatio;
121-
float origX2 = x2 * widthRatio;
122-
float origY2 = y2 * heightRatio;
155+
utils::computer_vision::BBox bboxOriginal =
156+
bboxModel.scale(widthRatio, heightRatio);
123157

124-
int32_t bboxW = static_cast<int32_t>(std::round(origX2 - origX1));
125-
int32_t bboxH = static_cast<int32_t>(std::round(origY2 - origY1));
126-
127-
if (bboxW <= 0 || bboxH <= 0) {
158+
if (!bboxOriginal.isValid()) {
128159
return std::nullopt;
129160
}
130161

131-
const float *logits = maskData + (i * maskH * maskW);
132-
cv::Mat logitsMat(maskH, maskW, CV_32FC1, const_cast<float *>(logits));
133-
134-
int32_t finalMaskWidth, finalMaskHeight;
162+
// Process mask
163+
cv::Size maskSize(logitsMat.cols, logitsMat.rows);
164+
cv::Size outSize;
135165
cv::Mat finalBinaryMat = processMaskFromLogits(
136-
logitsMat, x1, y1, x2, y2, modelInputSize, originalSize, maskW, maskH,
137-
bboxW, bboxH, origX1, origY1, returnMaskAtOriginalResolution,
138-
finalMaskWidth, finalMaskHeight);
166+
logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize,
167+
maskSize, returnMaskAtOriginalResolution, outSize);
139168

169+
// Create instance
140170
std::vector<uint8_t> finalMask(finalBinaryMat.data,
141171
finalBinaryMat.data + finalBinaryMat.total());
142172

143-
return types::Instance(
144-
utils::computer_vision::BBox{origX1, origY1, origX2, origY2},
145-
std::move(finalMask), finalMaskWidth, finalMaskHeight,
146-
static_cast<int32_t>(labelIdx), score, i);
173+
return types::Instance(bboxOriginal, std::move(finalMask), outSize.width,
174+
outSize.height, labelIdx, score, detectionIndex);
147175
}
148176

149177
std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
@@ -203,10 +231,14 @@ std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
203231
int32_t processed = 0;
204232

205233
for (int32_t i = 0; i < N; ++i) {
234+
// Extract mask logits for this detection
235+
const float *logits = maskData + (i * maskH * maskW);
236+
cv::Mat logitsMat(maskH, maskW, CV_32FC1, const_cast<float *>(logits));
237+
206238
auto instance = processDetection(
207-
i, bboxData, scoresData, maskData, maskH, maskW, modelInputSize,
208-
originalSize, widthRatio, heightRatio, confidenceThreshold,
209-
allowedClasses, returnMaskAtOriginalResolution);
239+
i, bboxData, scoresData, logitsMat, modelInputSize, originalSize,
240+
widthRatio, heightRatio, confidenceThreshold, allowedClasses,
241+
returnMaskAtOriginalResolution);
210242

211243
if (instance.has_value()) {
212244
instances.push_back(std::move(*instance));

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "Types.h"
1010
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
1111
#include <rnexecutorch/models/BaseModel.h>
12+
#include <rnexecutorch/utils/computer_vision/Types.h>
1213

1314
namespace rnexecutorch {
1415
namespace models::instance_segmentation {
@@ -36,21 +37,38 @@ class BaseInstanceSegmentation : public BaseModel {
3637
const std::vector<int32_t> &classIndices,
3738
bool returnMaskAtOriginalResolution);
3839

39-
cv::Mat processMaskFromLogits(const cv::Mat &logitsMat, float x1, float y1,
40-
float x2, float y2, cv::Size modelInputSize,
41-
cv::Size originalSize, int32_t maskW,
42-
int32_t maskH, int32_t bboxW, int32_t bboxH,
43-
float origX1, float origY1, bool warpToOriginal,
44-
int32_t &outWidth, int32_t &outHeight);
45-
46-
std::optional<types::Instance>
47-
processDetection(int32_t detectionIndex, const float *bboxData,
48-
const float *scoresData, const float *maskData,
49-
int32_t maskH, int32_t maskW, cv::Size modelInputSize,
50-
cv::Size originalSize, float widthRatio, float heightRatio,
51-
double confidenceThreshold,
52-
const std::set<int32_t> &allowedClasses,
53-
bool returnMaskAtOriginalResolution);
40+
// Data extraction helpers
41+
std::tuple<utils::computer_vision::BBox, float, int32_t>
42+
extractDetectionData(const float *bboxData, const float *scoresData,
43+
int32_t index);
44+
45+
// Helper functions for mask processing
46+
cv::Rect computeMaskCropRect(const utils::computer_vision::BBox &bboxModel,
47+
cv::Size modelInputSize, cv::Size maskSize);
48+
49+
cv::Rect addPaddingToRect(const cv::Rect &rect, cv::Size maskSize);
50+
51+
cv::Mat applySigmoid(const cv::Mat &logits);
52+
53+
cv::Mat
54+
warpToOriginalResolution(const cv::Mat &probMat, const cv::Rect &maskRect,
55+
cv::Size originalSize, cv::Size maskSize,
56+
const utils::computer_vision::BBox &bboxOriginal);
57+
58+
cv::Mat thresholdToBinary(const cv::Mat &probMat);
59+
60+
cv::Mat processMaskFromLogits(
61+
const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel,
62+
const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize,
63+
cv::Size originalSize, cv::Size maskSize, bool warpToOriginal,
64+
cv::Size &outSize);
65+
66+
std::optional<types::Instance> processDetection(
67+
int32_t detectionIndex, const float *bboxData, const float *scoresData,
68+
const cv::Mat &logitsMat, cv::Size modelInputSize, cv::Size originalSize,
69+
float widthRatio, float heightRatio, double confidenceThreshold,
70+
const std::set<int32_t> &allowedClasses,
71+
bool returnMaskAtOriginalResolution);
5472

5573
// Member variables
5674
std::optional<cv::Scalar> normMean_;

packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ struct BBox {
1616
bool isValid() const {
1717
return x2 > x1 && y2 > y1 && x1 >= 0.0f && y1 >= 0.0f;
1818
}
19+
20+
BBox scale(float widthRatio, float heightRatio) const {
21+
return {x1 * widthRatio, y1 * heightRatio, x2 * widthRatio,
22+
y2 * heightRatio};
23+
}
1924
};
2025

2126
template <typename T>

0 commit comments

Comments
 (0)