@@ -43,6 +43,16 @@ cv::Size BaseInstanceSegmentation::modelInputSize() const {
4343 return {shape[shape.size () - 2 ], shape[shape.size () - 1 ]};
4444}
4545
46+ TensorPtr BaseInstanceSegmentation::buildInputTensor (const cv::Mat &image) {
47+ cv::Mat preprocessed = preprocess (image);
48+ return (normMean_.has_value () && normStd_.has_value ())
49+ ? image_processing::getTensorFromMatrix (
50+ modelInputShape_, preprocessed, normMean_.value (),
51+ normStd_.value ())
52+ : image_processing::getTensorFromMatrix (modelInputShape_,
53+ preprocessed);
54+ }
55+
4656std::vector<types::Instance> BaseInstanceSegmentation::runInference (
4757 const cv::Mat &image, double confidenceThreshold, double iouThreshold,
4858 int32_t maxInstances, const std::vector<int32_t > &classIndices,
@@ -64,16 +74,8 @@ std::vector<types::Instance> BaseInstanceSegmentation::runInference(
6474 cv::Size modelInputSize (shape[shape.size () - 2 ], shape[shape.size () - 1 ]);
6575 cv::Size originalSize (image.cols , image.rows );
6676
67- cv::Mat preprocessed = preprocess (image);
68-
69- auto inputTensor = (normMean_.has_value () && normStd_.has_value ())
70- ? image_processing::getTensorFromMatrix (
71- modelInputShape_, preprocessed,
72- normMean_.value (), normStd_.value ())
73- : image_processing::getTensorFromMatrix (
74- modelInputShape_, preprocessed);
75-
76- auto forwardResult = BaseModel::execute (methodName, {inputTensor});
77+ auto forwardResult =
78+ BaseModel::execute (methodName, {buildInputTensor (image)});
7779 if (!forwardResult.ok ()) {
7880 throw RnExecutorchError (
7981 forwardResult.error (),
@@ -82,9 +84,13 @@ std::vector<types::Instance> BaseInstanceSegmentation::runInference(
8284 methodName + " ' is valid." );
8385 }
8486
85- return postprocess (forwardResult.get (), originalSize, modelInputSize,
86- confidenceThreshold, iouThreshold, maxInstances,
87- classIndices, returnMaskAtOriginalResolution);
87+ validateThresholds (confidenceThreshold, iouThreshold);
88+ validateOutputTensors (forwardResult.get ());
89+
90+ auto instances = collectInstances (
91+ forwardResult.get (), originalSize, modelInputSize, confidenceThreshold,
92+ classIndices, returnMaskAtOriginalResolution);
93+ return finalizeInstances (std::move (instances), iouThreshold, maxInstances);
8894}
8995
9096std::vector<types::Instance> BaseInstanceSegmentation::generateFromString (
@@ -259,37 +265,16 @@ void BaseInstanceSegmentation::ensureMethodLoaded(
259265 methodName + " '" );
260266 }
261267
262- if (currentlyLoadedMethod_ != methodName) {
263- if (!currentlyLoadedMethod_.empty ()) {
264- module_->unload_method (currentlyLoadedMethod_);
265- }
266- currentlyLoadedMethod_ = methodName;
267- auto loadResult = module_->load_method (methodName);
268- if (loadResult != executorch::runtime::Error::Ok) {
269- throw RnExecutorchError (
270- loadResult, " Failed to load method '" + methodName +
271- " '. Ensure the method exists in the exported model." );
272- }
268+ if (!currentlyLoadedMethod_.empty ()) {
269+ module_->unload_method (currentlyLoadedMethod_);
273270 }
274- }
275-
276- cv::Size BaseInstanceSegmentation::getInputSize (const std::string &methodName) {
277- auto inputShapes = getAllInputShapes (methodName);
278- if (inputShapes.empty ()) {
279- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
280- " Method '" + methodName +
281- " ' has no input tensors." );
282- }
283-
284- const auto &inputShape = inputShapes[0 ];
285- if (inputShape.empty ()) {
286- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
287- " Method '" + methodName +
288- " ' input tensor has no dimensions." );
271+ currentlyLoadedMethod_ = methodName;
272+ auto loadResult = module_->load_method (methodName);
273+ if (loadResult != executorch::runtime::Error::Ok) {
274+ throw RnExecutorchError (
275+ loadResult, " Failed to load method '" + methodName +
276+ " '. Ensure the method exists in the exported model." );
289277 }
290-
291- int32_t inputSize = inputShape[inputShape.size () - 1 ];
292- return cv::Size (inputSize, inputSize);
293278}
294279
295280std::vector<types::Instance> BaseInstanceSegmentation::finalizeInstances (
@@ -308,15 +293,12 @@ std::vector<types::Instance> BaseInstanceSegmentation::finalizeInstances(
308293 return instances;
309294}
310295
311- std::vector<types::Instance> BaseInstanceSegmentation::postprocess (
296+ std::vector<types::Instance> BaseInstanceSegmentation::collectInstances (
312297 const std::vector<EValue> &tensors, cv::Size originalSize,
313- cv::Size modelInputSize, double confidenceThreshold, double iouThreshold,
314- int32_t maxInstances, const std::vector<int32_t > &classIndices,
298+ cv::Size modelInputSize, double confidenceThreshold,
299+ const std::vector<int32_t > &classIndices,
315300 bool returnMaskAtOriginalResolution) {
316301
317- validateThresholds (confidenceThreshold, iouThreshold);
318- validateOutputTensors (tensors);
319-
320302 float widthRatio =
321303 static_cast <float >(originalSize.width ) / modelInputSize.width ;
322304 float heightRatio =
@@ -371,7 +353,7 @@ std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
371353 binaryMask.cols , binaryMask.rows , labelIdx, score);
372354 }
373355
374- return finalizeInstances ( std::move ( instances), iouThreshold, maxInstances) ;
356+ return instances;
375357}
376358
377359} // namespace rnexecutorch::models::instance_segmentation
0 commit comments