Skip to content

Commit 8557eba

Browse files
committed
Streamline postprocessing
1 parent b9794ea commit 8557eba

File tree

2 files changed

+126
-129
lines changed

2 files changed

+126
-129
lines changed

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 110 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ cv::Mat BaseInstanceSegmentation::thresholdToBinary(const cv::Mat &probMat) {
107107
cv::Mat BaseInstanceSegmentation::processMaskFromLogits(
108108
const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel,
109109
const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize,
110-
cv::Size originalSize, cv::Size maskSize, bool warpToOriginal,
111-
cv::Size &outSize) {
110+
cv::Size originalSize, bool warpToOriginal) {
112111

112+
cv::Size maskSize = logitsMat.size();
113113
cv::Rect cropRect = computeMaskCropRect(bboxModel, modelInputSize, maskSize);
114114

115115
if (warpToOriginal) {
@@ -123,55 +123,87 @@ cv::Mat BaseInstanceSegmentation::processMaskFromLogits(
123123
probMat = warpToOriginalResolution(probMat, cropRect, originalSize,
124124
maskSize, bboxOriginal);
125125
}
126-
cv::Mat binaryMask = thresholdToBinary(probMat);
127-
outSize = cv::Size(binaryMask.cols, binaryMask.rows);
126+
return thresholdToBinary(probMat);
127+
}
128+
129+
void BaseInstanceSegmentation::validateThresholds(double confidenceThreshold,
130+
double iouThreshold) const {
131+
if (confidenceThreshold < 0 || confidenceThreshold > 1) {
132+
throw RnExecutorchError(
133+
RnExecutorchErrorCode::InvalidConfig,
134+
"Confidence threshold must be greater or equal to 0 "
135+
"and less than or equal to 1.");
136+
}
128137

129-
return binaryMask;
138+
if (iouThreshold < 0 || iouThreshold > 1) {
139+
throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig,
140+
"IoU threshold must be greater or equal to 0 "
141+
"and less than or equal to 1.");
142+
}
130143
}
131144

132-
std::optional<types::Instance> BaseInstanceSegmentation::processDetection(
133-
int32_t detectionIndex, const float *bboxData, const float *scoresData,
134-
const cv::Mat &logitsMat, cv::Size modelInputSize, cv::Size originalSize,
135-
float widthRatio, float heightRatio, double confidenceThreshold,
136-
const std::set<int32_t> &allowedClasses,
137-
bool returnMaskAtOriginalResolution) {
145+
void BaseInstanceSegmentation::validateOutputTensors(
146+
const std::vector<EValue> &tensors) const {
147+
if (tensors.size() != 3) {
148+
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
149+
"Expected 3 output tensors ([1,N,4] + [1,N,2] + "
150+
"[1,N,H,W]), got " +
151+
std::to_string(tensors.size()));
152+
}
153+
}
138154

139-
// Extract detection data
140-
auto [bboxModel, score, labelIdx] =
141-
extractDetectionData(bboxData, scoresData, detectionIndex);
155+
std::set<int32_t> BaseInstanceSegmentation::prepareAllowedClasses(
156+
const std::vector<int32_t> &classIndices) const {
157+
std::set<int32_t> allowedClasses;
158+
if (!classIndices.empty()) {
159+
allowedClasses.insert(classIndices.begin(), classIndices.end());
160+
}
161+
return allowedClasses;
162+
}
142163

143-
// Filter by confidence
144-
if (score < confidenceThreshold) {
145-
return std::nullopt;
164+
void BaseInstanceSegmentation::ensureMethodLoaded(
165+
const std::string &methodName) {
166+
if (methodName.empty()) {
167+
throw RnExecutorchError(
168+
RnExecutorchErrorCode::InvalidConfig,
169+
"methodName cannot be empty. Use 'forward' for single-method models "
170+
"or 'forward_{inputSize}' for multi-method models.");
146171
}
147172

148-
// Filter by class
149-
if (!allowedClasses.empty() &&
150-
allowedClasses.find(labelIdx) == allowedClasses.end()) {
151-
return std::nullopt;
173+
if (currentlyLoadedMethod_ != methodName) {
174+
if (!currentlyLoadedMethod_.empty()) {
175+
module_->unload_method(currentlyLoadedMethod_);
176+
}
177+
currentlyLoadedMethod_ = methodName;
178+
module_->load_method(methodName);
152179
}
180+
}
153181

154-
// Scale bbox to original image coordinates
155-
utils::computer_vision::BBox bboxOriginal =
156-
bboxModel.scale(widthRatio, heightRatio);
182+
cv::Size BaseInstanceSegmentation::getInputSize(const std::string &methodName) {
183+
auto inputShapes = getAllInputShapes(methodName);
184+
std::vector<int32_t> inputShape = inputShapes[0];
185+
int32_t inputSize = inputShape[inputShape.size() - 1];
186+
return cv::Size(inputSize, inputSize);
187+
}
188+
189+
std::vector<types::Instance> BaseInstanceSegmentation::finalizeInstances(
190+
std::vector<types::Instance> instances, double iouThreshold,
191+
int32_t maxInstances) const {
157192

158-
if (!bboxOriginal.isValid()) {
159-
return std::nullopt;
193+
if (applyNMS_) {
194+
instances =
195+
utils::computer_vision::nonMaxSuppression(instances, iouThreshold);
160196
}
161197

162-
// Process mask
163-
cv::Size maskSize(logitsMat.cols, logitsMat.rows);
164-
cv::Size outSize;
165-
cv::Mat finalBinaryMat = processMaskFromLogits(
166-
logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize,
167-
maskSize, returnMaskAtOriginalResolution, outSize);
198+
if (std::cmp_greater(instances.size(), maxInstances)) {
199+
instances.resize(maxInstances);
200+
}
168201

169-
// Create instance
170-
std::vector<uint8_t> finalMask(finalBinaryMat.data,
171-
finalBinaryMat.data + finalBinaryMat.total());
202+
for (int32_t i = 0; i < instances.size(); ++i) {
203+
instances[i].instanceId = static_cast<int32_t>(i);
204+
}
172205

173-
return types::Instance(bboxOriginal, std::move(finalMask), outSize.width,
174-
outSize.height, labelIdx, score, detectionIndex);
206+
return instances;
175207
}
176208

177209
std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
@@ -180,116 +212,79 @@ std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
180212
int32_t maxInstances, const std::vector<int32_t> &classIndices,
181213
bool returnMaskAtOriginalResolution) {
182214

183-
if (confidenceThreshold < 0 || confidenceThreshold > 1) {
184-
throw RnExecutorchError(
185-
RnExecutorchErrorCode::InvalidConfig,
186-
"Confidence threshold must be greater or equal to 0 "
187-
"and less than or equal to 1.");
188-
}
189-
190-
if (iouThreshold < 0 || iouThreshold > 1) {
191-
throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig,
192-
"IoU threshold must be greater or equal to 0 "
193-
"and less than or equal to 1.");
194-
}
215+
validateThresholds(confidenceThreshold, iouThreshold);
216+
validateOutputTensors(tensors);
195217

196218
float widthRatio =
197219
static_cast<float>(originalSize.width) / modelInputSize.width;
198220
float heightRatio =
199221
static_cast<float>(originalSize.height) / modelInputSize.height;
222+
std::set<int32_t> allowedClasses = prepareAllowedClasses(classIndices);
200223

201-
std::set<int32_t> allowedClasses;
202-
if (!classIndices.empty()) {
203-
allowedClasses.insert(classIndices.begin(), classIndices.end());
204-
}
205-
206-
std::vector<types::Instance> instances;
207-
208-
size_t numTensors = tensors.size();
209-
if (numTensors != 3) {
210-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
211-
"Expected 3 output tensors ([1,N,4] + [1,N,2] + "
212-
"[1,N,H,W]), got " +
213-
std::to_string(numTensors));
214-
}
215-
216-
// CONTRACT: [1,N,4] + [1,N,2] + [1,N,H,W]
217-
// bbox: [x1, y1, x2, y2] in model input coordinates
218-
// scores: [max_score, class_id] — post-sigmoid
219-
// mask_logits: pre-sigmoid, per-detection
224+
// CONTRACT
220225
auto bboxTensor = tensors[0].toTensor(); // [1, N, 4]
221226
auto scoresTensor = tensors[1].toTensor(); // [1, N, 2]
222227
auto maskTensor = tensors[2].toTensor(); // [1, N, H, W]
223228

224229
int32_t N = bboxTensor.size(1);
225230
int32_t maskH = maskTensor.size(2);
226231
int32_t maskW = maskTensor.size(3);
232+
227233
const float *bboxData = bboxTensor.const_data_ptr<float>();
228234
const float *scoresData = scoresTensor.const_data_ptr<float>();
229235
const float *maskData = maskTensor.const_data_ptr<float>();
230236

231-
int32_t processed = 0;
237+
auto isValidDetection =
238+
[&allowedClasses, &confidenceThreshold](float score, int32_t labelIdx) {
239+
if (score < confidenceThreshold)
240+
return false;
241+
if (!allowedClasses.empty() && allowedClasses.count(labelIdx) == 0)
242+
return false;
243+
return true;
244+
};
245+
246+
std::vector<types::Instance> instances;
232247

233248
for (int32_t i = 0; i < N; ++i) {
234-
// Extract mask logits for this detection
235-
const float *logits = maskData + (i * maskH * maskW);
236-
cv::Mat logitsMat(maskH, maskW, CV_32FC1, const_cast<float *>(logits));
249+
auto [bboxModel, score, labelIdx] =
250+
extractDetectionData(bboxData, scoresData, i);
237251

238-
auto instance = processDetection(
239-
i, bboxData, scoresData, logitsMat, modelInputSize, originalSize,
240-
widthRatio, heightRatio, confidenceThreshold, allowedClasses,
241-
returnMaskAtOriginalResolution);
252+
if (!isValidDetection(score, labelIdx))
253+
continue;
242254

243-
if (instance.has_value()) {
244-
instances.push_back(std::move(*instance));
245-
++processed;
246-
}
247-
}
255+
utils::computer_vision::BBox bboxOriginal =
256+
bboxModel.scale(widthRatio, heightRatio);
257+
if (!bboxOriginal.isValid())
258+
continue;
248259

249-
// Finalize: NMS + limit + renumber
250-
if (applyNMS_) {
251-
instances =
252-
utils::computer_vision::nonMaxSuppression(instances, iouThreshold);
253-
}
260+
cv::Mat logitsMat(maskH, maskW, CV_32FC1,
261+
const_cast<float *>(maskData + (i * maskH * maskW)));
254262

255-
if (std::cmp_greater(instances.size(), maxInstances)) {
256-
instances.resize(maxInstances);
257-
}
263+
cv::Mat binaryMask = processMaskFromLogits(
264+
logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize,
265+
returnMaskAtOriginalResolution);
258266

259-
for (size_t i = 0; i < instances.size(); ++i) {
260-
instances[i].instanceId = static_cast<int32_t>(i);
267+
instances.emplace_back(
268+
bboxOriginal,
269+
std::vector<uint8_t>(binaryMask.data,
270+
binaryMask.data + binaryMask.total()),
271+
binaryMask.cols, binaryMask.rows, labelIdx, score, i);
261272
}
262273

263-
return instances;
274+
return finalizeInstances(std::move(instances), iouThreshold, maxInstances);
264275
}
265276

266277
std::vector<types::Instance> BaseInstanceSegmentation::generate(
267278
std::string imageSource, double confidenceThreshold, double iouThreshold,
268279
int32_t maxInstances, std::vector<int32_t> classIndices,
269280
bool returnMaskAtOriginalResolution, std::string methodName) {
270281

271-
if (methodName.empty()) {
272-
throw RnExecutorchError(
273-
RnExecutorchErrorCode::InvalidConfig,
274-
"methodName cannot be empty. Use 'forward' for single-method models "
275-
"or 'forward_{inputSize}' for multi-method models.");
276-
}
277-
278-
if (currentlyLoadedMethod_ != methodName) {
279-
if (!currentlyLoadedMethod_.empty()) {
280-
module_->unload_method(currentlyLoadedMethod_);
281-
}
282-
currentlyLoadedMethod_ = methodName;
283-
module_->load_method(methodName);
284-
}
285-
286-
auto inputShapes = getAllInputShapes(methodName);
287-
std::vector<int32_t> inputShape = inputShapes[0];
288-
int32_t inputSize = inputShape[inputShape.size() - 1];
289-
cv::Size modelInputSize(inputSize, inputSize);
282+
ensureMethodLoaded(methodName);
283+
cv::Size modelInputSize = getInputSize(methodName);
290284

291285
auto [inputTensor, originalSize] = image_processing::readImageToTensor(
292-
imageSource, inputShape, false, normMean_, normStd_);
286+
imageSource, getAllInputShapes(methodName)[0], false, normMean_,
287+
normStd_);
293288

294289
auto forwardResult = BaseModel::execute(methodName, {inputTensor});
295290
if (!forwardResult.ok()) {
@@ -300,11 +295,9 @@ std::vector<types::Instance> BaseInstanceSegmentation::generate(
300295
methodName + "' is valid.");
301296
}
302297

303-
auto result = postprocess(forwardResult.get(), originalSize, modelInputSize,
304-
confidenceThreshold, iouThreshold, maxInstances,
305-
classIndices, returnMaskAtOriginalResolution);
306-
307-
return result;
298+
return postprocess(forwardResult.get(), originalSize, modelInputSize,
299+
confidenceThreshold, iouThreshold, maxInstances,
300+
classIndices, returnMaskAtOriginalResolution);
308301
}
309302

310303
} // namespace rnexecutorch::models::instance_segmentation

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,21 @@ class BaseInstanceSegmentation : public BaseModel {
3737
const std::vector<int32_t> &classIndices,
3838
bool returnMaskAtOriginalResolution);
3939

40-
// Data extraction helpers
40+
void validateThresholds(double confidenceThreshold,
41+
double iouThreshold) const;
42+
void validateOutputTensors(const std::vector<EValue> &tensors) const;
43+
44+
std::set<int32_t>
45+
prepareAllowedClasses(const std::vector<int32_t> &classIndices) const;
46+
47+
// Model loading and input helpers
48+
void ensureMethodLoaded(const std::string &methodName);
49+
cv::Size getInputSize(const std::string &methodName);
50+
4151
std::tuple<utils::computer_vision::BBox, float, int32_t>
4252
extractDetectionData(const float *bboxData, const float *scoresData,
4353
int32_t index);
4454

45-
// Helper functions for mask processing
4655
cv::Rect computeMaskCropRect(const utils::computer_vision::BBox &bboxModel,
4756
cv::Size modelInputSize, cv::Size maskSize);
4857

@@ -57,20 +66,15 @@ class BaseInstanceSegmentation : public BaseModel {
5766

5867
cv::Mat thresholdToBinary(const cv::Mat &probMat);
5968

69+
std::vector<types::Instance>
70+
finalizeInstances(std::vector<types::Instance> instances, double iouThreshold,
71+
int32_t maxInstances) const;
72+
6073
cv::Mat processMaskFromLogits(
6174
const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel,
6275
const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize,
63-
cv::Size originalSize, cv::Size maskSize, bool warpToOriginal,
64-
cv::Size &outSize);
65-
66-
std::optional<types::Instance> processDetection(
67-
int32_t detectionIndex, const float *bboxData, const float *scoresData,
68-
const cv::Mat &logitsMat, cv::Size modelInputSize, cv::Size originalSize,
69-
float widthRatio, float heightRatio, double confidenceThreshold,
70-
const std::set<int32_t> &allowedClasses,
71-
bool returnMaskAtOriginalResolution);
76+
cv::Size originalSize, bool warpToOriginal);
7277

73-
// Member variables
7478
std::optional<cv::Scalar> normMean_;
7579
std::optional<cv::Scalar> normStd_;
7680
bool applyNMS_;

0 commit comments

Comments
 (0)