11#include " ObjectDetection.h"
22#include " Constants.h"
33
4+ #include < set>
5+
46#include < rnexecutorch/Error.h>
57#include < rnexecutorch/ErrorCodes.h>
68#include < rnexecutorch/Log.h>
@@ -17,21 +19,6 @@ ObjectDetection::ObjectDetection(
1719 std::shared_ptr<react::CallInvoker> callInvoker)
1820 : VisionModel(modelSource, callInvoker),
1921 labelNames_ (std::move(labelNames)) {
20- auto inputTensors = getAllInputShapes ();
21- if (inputTensors.empty ()) {
22- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
23- " Model seems to not take any input tensors." );
24- }
25- modelInputShape_ = inputTensors[0 ];
26- if (modelInputShape_.size () < 2 ) {
27- char errorMessage[100 ];
28- std::snprintf (errorMessage, sizeof (errorMessage),
29- " Unexpected model input size, expected at least 2 dimensions "
30- " but got: %zu." ,
31- modelInputShape_.size ());
32- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
33- errorMessage);
34- }
3522 if (normMean.size () == 3 ) {
3623 normMean_ = cv::Scalar (normMean[0 ], normMean[1 ], normMean[2 ]);
3724 } else if (!normMean.empty ()) {
@@ -46,14 +33,65 @@ ObjectDetection::ObjectDetection(
4633 }
4734}
4835
36+ cv::Size ObjectDetection::modelInputSize () const {
37+ if (currentlyLoadedMethod_.empty ()) {
38+ return VisionModel::modelInputSize ();
39+ }
40+ auto inputShapes = getAllInputShapes (currentlyLoadedMethod_);
41+ if (inputShapes.empty () || inputShapes[0 ].size () < 2 ) {
42+ return VisionModel::modelInputSize ();
43+ }
44+ const auto &shape = inputShapes[0 ];
45+ return {static_cast <int >(shape[shape.size () - 2 ]),
46+ static_cast <int >(shape[shape.size () - 1 ])};
47+ }
48+
49+ void ObjectDetection::ensureMethodLoaded (const std::string &methodName) {
50+ if (methodName.empty ()) {
51+ throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
52+ " methodName cannot be empty" );
53+ }
54+ if (currentlyLoadedMethod_ == methodName) {
55+ return ;
56+ }
57+ if (!module_) {
58+ throw RnExecutorchError (RnExecutorchErrorCode::ModuleNotLoaded,
59+ " Model module is not loaded" );
60+ }
61+ if (!currentlyLoadedMethod_.empty ()) {
62+ module_->unload_method (currentlyLoadedMethod_);
63+ }
64+ auto loadResult = module_->load_method (methodName);
65+ if (loadResult != executorch::runtime::Error::Ok) {
66+ throw RnExecutorchError (
67+ loadResult, " Failed to load method '" + methodName +
68+ " '. Ensure the method exists in the exported model." );
69+ }
70+ currentlyLoadedMethod_ = methodName;
71+ }
72+
73+ std::set<int32_t > ObjectDetection::prepareAllowedClasses (
74+ const std::vector<int32_t > &classIndices) const {
75+ std::set<int32_t > allowedClasses;
76+ if (!classIndices.empty ()) {
77+ allowedClasses.insert (classIndices.begin (), classIndices.end ());
78+ }
79+ return allowedClasses;
80+ }
81+
4982std::vector<types::Detection>
5083ObjectDetection::postprocess (const std::vector<EValue> &tensors,
51- cv::Size originalSize, double detectionThreshold) {
84+ cv::Size originalSize, double detectionThreshold,
85+ double iouThreshold,
86+ const std::vector<int32_t > &classIndices) {
5287 const cv::Size inputSize = modelInputSize ();
5388 float widthRatio = static_cast <float >(originalSize.width ) / inputSize.width ;
5489 float heightRatio =
5590 static_cast <float >(originalSize.height ) / inputSize.height ;
5691
92+ // Prepare allowed classes set for filtering
93+ auto allowedClasses = prepareAllowedClasses (classIndices);
94+
5795 std::vector<types::Detection> detections;
5896 auto bboxTensor = tensors.at (0 ).toTensor ();
5997 std::span<const float > bboxes (
@@ -74,36 +112,62 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
74112 if (scores[i] < detectionThreshold) {
75113 continue ;
76114 }
115+
116+ auto labelIdx = static_cast <int32_t >(labels[i]);
117+
118+ // Filter by class if classesOfInterest is specified
119+ if (!allowedClasses.empty () &&
120+ allowedClasses.find (labelIdx) == allowedClasses.end ()) {
121+ continue ;
122+ }
123+
77124 float x1 = bboxes[i * 4 ] * widthRatio;
78125 float y1 = bboxes[i * 4 + 1 ] * heightRatio;
79126 float x2 = bboxes[i * 4 + 2 ] * widthRatio;
80127 float y2 = bboxes[i * 4 + 3 ] * heightRatio;
81- auto labelIdx = static_cast <std:: size_t >(labels[i]);
82- if (labelIdx >= labelNames_.size ()) {
128+
129+ if (static_cast <std:: size_t >( labelIdx) >= labelNames_.size ()) {
83130 throw RnExecutorchError (
84131 RnExecutorchErrorCode::InvalidConfig,
85132 " Model output class index " + std::to_string (labelIdx) +
86133 " exceeds labelNames size " + std::to_string (labelNames_.size ()) +
87134 " . Ensure the labelMap covers all model output classes." );
88135 }
89136 detections.emplace_back (utils::computer_vision::BBox{x1, y1, x2, y2},
90- labelNames_[labelIdx],
91- static_cast <int32_t >(labelIdx), scores[i]);
137+ labelNames_[labelIdx], labelIdx, scores[i]);
92138 }
93139
94- return utils::computer_vision::nonMaxSuppression (detections,
95- constants::IOU_THRESHOLD);
140+ return utils::computer_vision::nonMaxSuppression (detections, iouThreshold);
96141}
97142
98- std::vector<types::Detection>
99- ObjectDetection::runInference (cv::Mat image, double detectionThreshold) {
143+ std::vector<types::Detection> ObjectDetection::runInference (
144+ cv::Mat image, double detectionThreshold, double iouThreshold,
145+ const std::vector<int32_t > &classIndices, const std::string &methodName) {
100146 if (detectionThreshold < 0.0 || detectionThreshold > 1.0 ) {
101147 throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
102148 " detectionThreshold must be in range [0, 1]" );
103149 }
150+ if (iouThreshold < 0.0 || iouThreshold > 1.0 ) {
151+ throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
152+ " iouThreshold must be in range [0, 1]" );
153+ }
154+
104155 std::scoped_lock lock (inference_mutex_);
105156
157+ // Ensure the correct method is loaded
158+ ensureMethodLoaded (methodName);
159+
106160 cv::Size originalSize = image.size ();
161+
162+ // Query input shapes for the currently loaded method
163+ auto inputShapes = getAllInputShapes (methodName);
164+ if (inputShapes.empty () || inputShapes[0 ].size () < 2 ) {
165+ throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
166+ " Could not determine input shape for method: " +
167+ methodName);
168+ }
169+ modelInputShape_ = inputShapes[0 ];
170+
107171 cv::Mat preprocessed = preprocess (image);
108172
109173 auto inputTensor =
@@ -113,40 +177,45 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
113177 : image_processing::getTensorFromMatrix (modelInputShape_,
114178 preprocessed);
115179
116- auto forwardResult = BaseModel::forward (inputTensor);
117- if (!forwardResult.ok ()) {
118- throw RnExecutorchError (forwardResult.error (),
119- " The model's forward function did not succeed. "
120- " Ensure the model input is correct." );
180+ auto executeResult = execute (methodName, {inputTensor});
181+ if (!executeResult.ok ()) {
182+ throw RnExecutorchError (executeResult.error (),
183+ " The model's " + methodName +
184+ " method did not succeed. "
185+ " Ensure the model input is correct." );
121186 }
122187
123- return postprocess (forwardResult.get (), originalSize, detectionThreshold);
188+ return postprocess (executeResult.get (), originalSize, detectionThreshold,
189+ iouThreshold, classIndices);
124190}
125191
126- std::vector<types::Detection>
127- ObjectDetection::generateFromString ( std::string imageSource,
128- double detectionThreshold ) {
192+ std::vector<types::Detection> ObjectDetection::generateFromString (
193+ std::string imageSource, double detectionThreshold, double iouThreshold ,
194+ std::vector< int32_t > classIndices, std::string methodName ) {
129195 cv::Mat imageBGR = image_processing::readImage (imageSource);
130196
131197 cv::Mat imageRGB;
132198 cv::cvtColor (imageBGR, imageRGB, cv::COLOR_BGR2RGB);
133199
134- return runInference (imageRGB, detectionThreshold);
200+ return runInference (imageRGB, detectionThreshold, iouThreshold, classIndices,
201+ methodName);
135202}
136203
137- std::vector<types::Detection>
138- ObjectDetection::generateFromFrame ( jsi::Runtime &runtime,
139- const jsi::Value &frameData ,
140- double detectionThreshold ) {
204+ std::vector<types::Detection> ObjectDetection::generateFromFrame (
205+ jsi::Runtime &runtime, const jsi::Value &frameData ,
206+ double detectionThreshold, double iouThreshold ,
207+ std::vector< int32_t > classIndices, std::string methodName ) {
141208 cv::Mat frame = extractFromFrame (runtime, frameData);
142- return runInference (frame, detectionThreshold);
209+ return runInference (frame, detectionThreshold, iouThreshold, classIndices,
210+ methodName);
143211}
144212
145- std::vector<types::Detection>
146- ObjectDetection::generateFromPixels ( JSTensorViewIn pixelData,
147- double detectionThreshold ) {
213+ std::vector<types::Detection> ObjectDetection::generateFromPixels (
214+ JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold ,
215+ std::vector< int32_t > classIndices, std::string methodName ) {
148216 cv::Mat image = extractFromPixels (pixelData);
149217
150- return runInference (image, detectionThreshold);
218+ return runInference (image, detectionThreshold, iouThreshold, classIndices,
219+ methodName);
151220}
152221} // namespace rnexecutorch::models::object_detection
0 commit comments