@@ -107,9 +107,9 @@ cv::Mat BaseInstanceSegmentation::thresholdToBinary(const cv::Mat &probMat) {
107107cv::Mat BaseInstanceSegmentation::processMaskFromLogits (
108108 const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel,
109109 const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize,
110- cv::Size originalSize, cv::Size maskSize, bool warpToOriginal,
111- cv::Size &outSize) {
110+ cv::Size originalSize, bool warpToOriginal) {
112111
112+ cv::Size maskSize = logitsMat.size ();
113113 cv::Rect cropRect = computeMaskCropRect (bboxModel, modelInputSize, maskSize);
114114
115115 if (warpToOriginal) {
@@ -123,55 +123,87 @@ cv::Mat BaseInstanceSegmentation::processMaskFromLogits(
123123 probMat = warpToOriginalResolution (probMat, cropRect, originalSize,
124124 maskSize, bboxOriginal);
125125 }
126- cv::Mat binaryMask = thresholdToBinary (probMat);
127- outSize = cv::Size (binaryMask.cols , binaryMask.rows );
126+ return thresholdToBinary (probMat);
127+ }
128+
129+ void BaseInstanceSegmentation::validateThresholds (double confidenceThreshold,
130+ double iouThreshold) const {
131+ if (confidenceThreshold < 0 || confidenceThreshold > 1 ) {
132+ throw RnExecutorchError (
133+ RnExecutorchErrorCode::InvalidConfig,
134+ " Confidence threshold must be greater or equal to 0 "
135+ " and less than or equal to 1." );
136+ }
128137
129- return binaryMask;
138+ if (iouThreshold < 0 || iouThreshold > 1 ) {
139+ throw RnExecutorchError (RnExecutorchErrorCode::InvalidConfig,
140+ " IoU threshold must be greater or equal to 0 "
141+ " and less than or equal to 1." );
142+ }
130143}
131144
132- std::optional<types::Instance> BaseInstanceSegmentation::processDetection (
133- int32_t detectionIndex, const float *bboxData, const float *scoresData,
134- const cv::Mat &logitsMat, cv::Size modelInputSize, cv::Size originalSize,
135- float widthRatio, float heightRatio, double confidenceThreshold,
136- const std::set<int32_t > &allowedClasses,
137- bool returnMaskAtOriginalResolution) {
145+ void BaseInstanceSegmentation::validateOutputTensors (
146+ const std::vector<EValue> &tensors) const {
147+ if (tensors.size () != 3 ) {
148+ throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
149+ " Expected 3 output tensors ([1,N,4] + [1,N,2] + "
150+ " [1,N,H,W]), got " +
151+ std::to_string (tensors.size ()));
152+ }
153+ }
138154
139- // Extract detection data
140- auto [bboxModel, score, labelIdx] =
141- extractDetectionData (bboxData, scoresData, detectionIndex);
155+ std::set<int32_t > BaseInstanceSegmentation::prepareAllowedClasses (
156+ const std::vector<int32_t > &classIndices) const {
157+ std::set<int32_t > allowedClasses;
158+ if (!classIndices.empty ()) {
159+ allowedClasses.insert (classIndices.begin (), classIndices.end ());
160+ }
161+ return allowedClasses;
162+ }
142163
143- // Filter by confidence
144- if (score < confidenceThreshold) {
145- return std::nullopt ;
164+ void BaseInstanceSegmentation::ensureMethodLoaded (
165+ const std::string &methodName) {
166+ if (methodName.empty ()) {
167+ throw RnExecutorchError (
168+ RnExecutorchErrorCode::InvalidConfig,
169+ " methodName cannot be empty. Use 'forward' for single-method models "
170+ " or 'forward_{inputSize}' for multi-method models." );
146171 }
147172
148- // Filter by class
149- if (!allowedClasses.empty () &&
150- allowedClasses.find (labelIdx) == allowedClasses.end ()) {
151- return std::nullopt ;
173+ if (currentlyLoadedMethod_ != methodName) {
174+ if (!currentlyLoadedMethod_.empty ()) {
175+ module_->unload_method (currentlyLoadedMethod_);
176+ }
177+ currentlyLoadedMethod_ = methodName;
178+ module_->load_method (methodName);
152179 }
180+ }
153181
154- // Scale bbox to original image coordinates
155- utils::computer_vision::BBox bboxOriginal =
156- bboxModel.scale (widthRatio, heightRatio);
182+ cv::Size BaseInstanceSegmentation::getInputSize (const std::string &methodName) {
183+ auto inputShapes = getAllInputShapes (methodName);
184+ std::vector<int32_t > inputShape = inputShapes[0 ];
185+ int32_t inputSize = inputShape[inputShape.size () - 1 ];
186+ return cv::Size (inputSize, inputSize);
187+ }
188+
189+ std::vector<types::Instance> BaseInstanceSegmentation::finalizeInstances (
190+ std::vector<types::Instance> instances, double iouThreshold,
191+ int32_t maxInstances) const {
157192
158- if (!bboxOriginal.isValid ()) {
159- return std::nullopt ;
193+ if (applyNMS_) {
194+ instances =
195+ utils::computer_vision::nonMaxSuppression (instances, iouThreshold);
160196 }
161197
162- // Process mask
163- cv::Size maskSize (logitsMat.cols , logitsMat.rows );
164- cv::Size outSize;
165- cv::Mat finalBinaryMat = processMaskFromLogits (
166- logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize,
167- maskSize, returnMaskAtOriginalResolution, outSize);
198+ if (std::cmp_greater (instances.size (), maxInstances)) {
199+ instances.resize (maxInstances);
200+ }
168201
169- // Create instance
170- std::vector< uint8_t > finalMask (finalBinaryMat. data ,
171- finalBinaryMat. data + finalBinaryMat. total ());
202+ for ( int32_t i = 0 ; i < instances. size (); ++i) {
203+ instances[i]. instanceId = static_cast < int32_t >(i);
204+ }
172205
173- return types::Instance (bboxOriginal, std::move (finalMask), outSize.width ,
174- outSize.height , labelIdx, score, detectionIndex);
206+ return instances;
175207}
176208
177209std::vector<types::Instance> BaseInstanceSegmentation::postprocess (
@@ -180,116 +212,79 @@ std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
180212 int32_t maxInstances, const std::vector<int32_t > &classIndices,
181213 bool returnMaskAtOriginalResolution) {
182214
183- if (confidenceThreshold < 0 || confidenceThreshold > 1 ) {
184- throw RnExecutorchError (
185- RnExecutorchErrorCode::InvalidConfig,
186- " Confidence threshold must be greater or equal to 0 "
187- " and less than or equal to 1." );
188- }
189-
190- if (iouThreshold < 0 || iouThreshold > 1 ) {
191- throw RnExecutorchError (RnExecutorchErrorCode::InvalidConfig,
192- " IoU threshold must be greater or equal to 0 "
193- " and less than or equal to 1." );
194- }
215+ validateThresholds (confidenceThreshold, iouThreshold);
216+ validateOutputTensors (tensors);
195217
196218 float widthRatio =
197219 static_cast <float >(originalSize.width ) / modelInputSize.width ;
198220 float heightRatio =
199221 static_cast <float >(originalSize.height ) / modelInputSize.height ;
222+ std::set<int32_t > allowedClasses = prepareAllowedClasses (classIndices);
200223
201- std::set<int32_t > allowedClasses;
202- if (!classIndices.empty ()) {
203- allowedClasses.insert (classIndices.begin (), classIndices.end ());
204- }
205-
206- std::vector<types::Instance> instances;
207-
208- size_t numTensors = tensors.size ();
209- if (numTensors != 3 ) {
210- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
211- " Expected 3 output tensors ([1,N,4] + [1,N,2] + "
212- " [1,N,H,W]), got " +
213- std::to_string (numTensors));
214- }
215-
216- // CONTRACT: [1,N,4] + [1,N,2] + [1,N,H,W]
217- // bbox: [x1, y1, x2, y2] in model input coordinates
218- // scores: [max_score, class_id] — post-sigmoid
219- // mask_logits: pre-sigmoid, per-detection
224+ // CONTRACT
220225 auto bboxTensor = tensors[0 ].toTensor (); // [1, N, 4]
221226 auto scoresTensor = tensors[1 ].toTensor (); // [1, N, 2]
222227 auto maskTensor = tensors[2 ].toTensor (); // [1, N, H, W]
223228
224229 int32_t N = bboxTensor.size (1 );
225230 int32_t maskH = maskTensor.size (2 );
226231 int32_t maskW = maskTensor.size (3 );
232+
227233 const float *bboxData = bboxTensor.const_data_ptr <float >();
228234 const float *scoresData = scoresTensor.const_data_ptr <float >();
229235 const float *maskData = maskTensor.const_data_ptr <float >();
230236
231- int32_t processed = 0 ;
237+ auto isValidDetection =
238+ [&allowedClasses, &confidenceThreshold](float score, int32_t labelIdx) {
239+ if (score < confidenceThreshold)
240+ return false ;
241+ if (!allowedClasses.empty () && allowedClasses.count (labelIdx) == 0 )
242+ return false ;
243+ return true ;
244+ };
245+
246+ std::vector<types::Instance> instances;
232247
233248 for (int32_t i = 0 ; i < N; ++i) {
234- // Extract mask logits for this detection
235- const float *logits = maskData + (i * maskH * maskW);
236- cv::Mat logitsMat (maskH, maskW, CV_32FC1, const_cast <float *>(logits));
249+ auto [bboxModel, score, labelIdx] =
250+ extractDetectionData (bboxData, scoresData, i);
237251
238- auto instance = processDetection (
239- i, bboxData, scoresData, logitsMat, modelInputSize, originalSize,
240- widthRatio, heightRatio, confidenceThreshold, allowedClasses,
241- returnMaskAtOriginalResolution);
252+ if (!isValidDetection (score, labelIdx))
253+ continue ;
242254
243- if (instance.has_value ()) {
244- instances.push_back (std::move (*instance));
245- ++processed;
246- }
247- }
255+ utils::computer_vision::BBox bboxOriginal =
256+ bboxModel.scale (widthRatio, heightRatio);
257+ if (!bboxOriginal.isValid ())
258+ continue ;
248259
249- // Finalize: NMS + limit + renumber
250- if (applyNMS_) {
251- instances =
252- utils::computer_vision::nonMaxSuppression (instances, iouThreshold);
253- }
260+ cv::Mat logitsMat (maskH, maskW, CV_32FC1,
261+ const_cast <float *>(maskData + (i * maskH * maskW)));
254262
255- if ( std::cmp_greater (instances. size (), maxInstances)) {
256- instances. resize (maxInstances);
257- }
263+ cv::Mat binaryMask = processMaskFromLogits (
264+ logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize,
265+ returnMaskAtOriginalResolution);
258266
259- for (size_t i = 0 ; i < instances.size (); ++i) {
260- instances[i].instanceId = static_cast <int32_t >(i);
267+ instances.emplace_back (
268+ bboxOriginal,
269+ std::vector<uint8_t >(binaryMask.data ,
270+ binaryMask.data + binaryMask.total ()),
271+ binaryMask.cols , binaryMask.rows , labelIdx, score, i);
261272 }
262273
263- return instances;
274+ return finalizeInstances ( std::move ( instances), iouThreshold, maxInstances) ;
264275}
265276
266277std::vector<types::Instance> BaseInstanceSegmentation::generate (
267278 std::string imageSource, double confidenceThreshold, double iouThreshold,
268279 int32_t maxInstances, std::vector<int32_t > classIndices,
269280 bool returnMaskAtOriginalResolution, std::string methodName) {
270281
271- if (methodName.empty ()) {
272- throw RnExecutorchError (
273- RnExecutorchErrorCode::InvalidConfig,
274- " methodName cannot be empty. Use 'forward' for single-method models "
275- " or 'forward_{inputSize}' for multi-method models." );
276- }
277-
278- if (currentlyLoadedMethod_ != methodName) {
279- if (!currentlyLoadedMethod_.empty ()) {
280- module_->unload_method (currentlyLoadedMethod_);
281- }
282- currentlyLoadedMethod_ = methodName;
283- module_->load_method (methodName);
284- }
285-
286- auto inputShapes = getAllInputShapes (methodName);
287- std::vector<int32_t > inputShape = inputShapes[0 ];
288- int32_t inputSize = inputShape[inputShape.size () - 1 ];
289- cv::Size modelInputSize (inputSize, inputSize);
282+ ensureMethodLoaded (methodName);
283+ cv::Size modelInputSize = getInputSize (methodName);
290284
291285 auto [inputTensor, originalSize] = image_processing::readImageToTensor (
292- imageSource, inputShape, false , normMean_, normStd_);
286+ imageSource, getAllInputShapes (methodName)[0 ], false , normMean_,
287+ normStd_);
293288
294289 auto forwardResult = BaseModel::execute (methodName, {inputTensor});
295290 if (!forwardResult.ok ()) {
@@ -300,11 +295,9 @@ std::vector<types::Instance> BaseInstanceSegmentation::generate(
300295 methodName + " ' is valid." );
301296 }
302297
303- auto result = postprocess (forwardResult.get (), originalSize, modelInputSize,
304- confidenceThreshold, iouThreshold, maxInstances,
305- classIndices, returnMaskAtOriginalResolution);
306-
307- return result;
298+ return postprocess (forwardResult.get (), originalSize, modelInputSize,
299+ confidenceThreshold, iouThreshold, maxInstances,
300+ classIndices, returnMaskAtOriginalResolution);
308301}
309302
310303} // namespace rnexecutorch::models::instance_segmentation
0 commit comments