11#include " ObjectDetection.h"
22#include " Constants.h"
33
4+ #include < set>
5+
46#include < rnexecutorch/Error.h>
57#include < rnexecutorch/ErrorCodes.h>
68#include < rnexecutorch/Log.h>
@@ -18,21 +20,6 @@ ObjectDetection::ObjectDetection(
1820 std::shared_ptr<react::CallInvoker> callInvoker)
1921 : VisionModel(modelSource, callInvoker),
2022 labelNames_ (std::move(labelNames)) {
21- auto inputTensors = getAllInputShapes ();
22- if (inputTensors.empty ()) {
23- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
24- " Model seems to not take any input tensors." );
25- }
26- modelInputShape_ = inputTensors[0 ];
27- if (modelInputShape_.size () < 2 ) {
28- char errorMessage[100 ];
29- std::snprintf (errorMessage, sizeof (errorMessage),
30- " Unexpected model input size, expected at least 2 dimensions "
31- " but got: %zu." ,
32- modelInputShape_.size ());
33- throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
34- errorMessage);
35- }
3623 if (normMean.size () == 3 ) {
3724 normMean_ = cv::Scalar (normMean[0 ], normMean[1 ], normMean[2 ]);
3825 } else if (!normMean.empty ()) {
@@ -47,14 +34,65 @@ ObjectDetection::ObjectDetection(
4734 }
4835}
4936
37+ cv::Size ObjectDetection::modelInputSize () const {
38+ if (currentlyLoadedMethod_.empty ()) {
39+ return VisionModel::modelInputSize ();
40+ }
41+ auto inputShapes = getAllInputShapes (currentlyLoadedMethod_);
42+ if (inputShapes.empty () || inputShapes[0 ].size () < 2 ) {
43+ return VisionModel::modelInputSize ();
44+ }
45+ const auto &shape = inputShapes[0 ];
46+ return {static_cast <int >(shape[shape.size () - 2 ]),
47+ static_cast <int >(shape[shape.size () - 1 ])};
48+ }
49+
50+ void ObjectDetection::ensureMethodLoaded (const std::string &methodName) {
51+ if (methodName.empty ()) {
52+ throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
53+ " methodName cannot be empty" );
54+ }
55+ if (currentlyLoadedMethod_ == methodName) {
56+ return ;
57+ }
58+ if (!module_) {
59+ throw RnExecutorchError (RnExecutorchErrorCode::ModuleNotLoaded,
60+ " Model module is not loaded" );
61+ }
62+ if (!currentlyLoadedMethod_.empty ()) {
63+ module_->unload_method (currentlyLoadedMethod_);
64+ }
65+ auto loadResult = module_->load_method (methodName);
66+ if (loadResult != executorch::runtime::Error::Ok) {
67+ throw RnExecutorchError (
68+ loadResult, " Failed to load method '" + methodName +
69+ " '. Ensure the method exists in the exported model." );
70+ }
71+ currentlyLoadedMethod_ = methodName;
72+ }
73+
74+ std::set<int32_t > ObjectDetection::prepareAllowedClasses (
75+ const std::vector<int32_t > &classIndices) const {
76+ std::set<int32_t > allowedClasses;
77+ if (!classIndices.empty ()) {
78+ allowedClasses.insert (classIndices.begin (), classIndices.end ());
79+ }
80+ return allowedClasses;
81+ }
82+
5083std::vector<types::Detection>
5184ObjectDetection::postprocess (const std::vector<EValue> &tensors,
52- cv::Size originalSize, double detectionThreshold) {
85+ cv::Size originalSize, double detectionThreshold,
86+ double iouThreshold,
87+ const std::vector<int32_t > &classIndices) {
5388 const cv::Size inputSize = modelInputSize ();
5489 float widthRatio = static_cast <float >(originalSize.width ) / inputSize.width ;
5590 float heightRatio =
5691 static_cast <float >(originalSize.height ) / inputSize.height ;
5792
93+ // Prepare allowed classes set for filtering
94+ auto allowedClasses = prepareAllowedClasses (classIndices);
95+
5896 std::vector<types::Detection> detections;
5997 auto bboxTensor = tensors.at (0 ).toTensor ();
6098 std::span<const float > bboxes (
@@ -75,36 +113,62 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
75113 if (scores[i] < detectionThreshold) {
76114 continue ;
77115 }
116+
117+ auto labelIdx = static_cast <int32_t >(labels[i]);
118+
119+ // Filter by class if classesOfInterest is specified
120+ if (!allowedClasses.empty () &&
121+ allowedClasses.find (labelIdx) == allowedClasses.end ()) {
122+ continue ;
123+ }
124+
78125 float x1 = bboxes[i * 4 ] * widthRatio;
79126 float y1 = bboxes[i * 4 + 1 ] * heightRatio;
80127 float x2 = bboxes[i * 4 + 2 ] * widthRatio;
81128 float y2 = bboxes[i * 4 + 3 ] * heightRatio;
82- auto labelIdx = static_cast <std:: size_t >(labels[i]);
83- if (labelIdx >= labelNames_.size ()) {
129+
130+ if (static_cast <std:: size_t >( labelIdx) >= labelNames_.size ()) {
84131 throw RnExecutorchError (
85132 RnExecutorchErrorCode::InvalidConfig,
86133 " Model output class index " + std::to_string (labelIdx) +
87134 " exceeds labelNames size " + std::to_string (labelNames_.size ()) +
88135 " . Ensure the labelMap covers all model output classes." );
89136 }
90137 detections.emplace_back (utils::computer_vision::BBox{x1, y1, x2, y2},
91- labelNames_[labelIdx],
92- static_cast <int32_t >(labelIdx), scores[i]);
138+ labelNames_[labelIdx], labelIdx, scores[i]);
93139 }
94140
95- return utils::computer_vision::nonMaxSuppression (detections,
96- constants::IOU_THRESHOLD);
141+ return utils::computer_vision::nonMaxSuppression (detections, iouThreshold);
97142}
98143
99- std::vector<types::Detection>
100- ObjectDetection::runInference (cv::Mat image, double detectionThreshold) {
144+ std::vector<types::Detection> ObjectDetection::runInference (
145+ cv::Mat image, double detectionThreshold, double iouThreshold,
146+ const std::vector<int32_t > &classIndices, const std::string &methodName) {
101147 if (detectionThreshold < 0.0 || detectionThreshold > 1.0 ) {
102148 throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
103149 " detectionThreshold must be in range [0, 1]" );
104150 }
151+ if (iouThreshold < 0.0 || iouThreshold > 1.0 ) {
152+ throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
153+ " iouThreshold must be in range [0, 1]" );
154+ }
155+
105156 std::scoped_lock lock (inference_mutex_);
106157
158+ // Ensure the correct method is loaded
159+ ensureMethodLoaded (methodName);
160+
107161 cv::Size originalSize = image.size ();
162+
163+ // Query input shapes for the currently loaded method
164+ auto inputShapes = getAllInputShapes (methodName);
165+ if (inputShapes.empty () || inputShapes[0 ].size () < 2 ) {
166+ throw RnExecutorchError (RnExecutorchErrorCode::UnexpectedNumInputs,
167+ " Could not determine input shape for method: " +
168+ methodName);
169+ }
170+ modelInputShape_ = inputShapes[0 ];
171+
108172 cv::Mat preprocessed = preprocess (image);
109173
110174 auto inputTensor =
@@ -114,46 +178,50 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
114178 : image_processing::getTensorFromMatrix (modelInputShape_,
115179 preprocessed);
116180
117- auto forwardResult = BaseModel::forward (inputTensor);
118- if (!forwardResult.ok ()) {
119- throw RnExecutorchError (forwardResult.error (),
120- " The model's forward function did not succeed. "
121- " Ensure the model input is correct." );
181+ auto executeResult = execute (methodName, {inputTensor});
182+ if (!executeResult.ok ()) {
183+ throw RnExecutorchError (executeResult.error (),
184+ " The model's " + methodName +
185+ " method did not succeed. "
186+ " Ensure the model input is correct." );
122187 }
123188
124- return postprocess (forwardResult.get (), originalSize, detectionThreshold);
189+ return postprocess (executeResult.get (), originalSize, detectionThreshold,
190+ iouThreshold, classIndices);
125191}
126192
127- std::vector<types::Detection>
128- ObjectDetection::generateFromString ( std::string imageSource,
129- double detectionThreshold ) {
193+ std::vector<types::Detection> ObjectDetection::generateFromString (
194+ std::string imageSource, double detectionThreshold, double iouThreshold ,
195+ std::vector< int32_t > classIndices, std::string methodName ) {
130196 cv::Mat imageBGR = image_processing::readImage (imageSource);
131197
132198 cv::Mat imageRGB;
133199 cv::cvtColor (imageBGR, imageRGB, cv::COLOR_BGR2RGB);
134200
135- return runInference (imageRGB, detectionThreshold);
201+ return runInference (imageRGB, detectionThreshold, iouThreshold, classIndices,
202+ methodName);
136203}
137204
138- std::vector<types::Detection>
139- ObjectDetection::generateFromFrame (jsi::Runtime &runtime,
140- const jsi::Value &frameData,
141- double detectionThreshold) {
142- auto orient = ::rnexecutorch::utils::readFrameOrientation (runtime, frameData);
205+ std::vector<types::Detection> ObjectDetection::generateFromFrame (
206+ jsi::Runtime &runtime, const jsi::Value &frameData,
207+ double detectionThreshold, double iouThreshold,
208+ std::vector<int32_t > classIndices, std::string methodName) {
143209 cv::Mat frame = extractFromFrame (runtime, frameData);
144- cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel (frame, orient);
145- auto detections = runInference (rotated, detectionThreshold);
210+ auto detections = runInference (frame, detectionThreshold, iouThreshold,
211+ classIndices, methodName);
212+
146213 for (auto &det : detections) {
147214 ::rnexecutorch::utils::inverseRotateBbox (det.bbox, orient, rotated.size());
148215 }
149216 return detections;
150217}
151218
152- std::vector<types::Detection>
153- ObjectDetection::generateFromPixels ( JSTensorViewIn pixelData,
154- double detectionThreshold ) {
219+ std::vector<types::Detection> ObjectDetection::generateFromPixels (
220+ JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold ,
221+ std::vector< int32_t > classIndices, std::string methodName ) {
155222 cv::Mat image = extractFromPixels (pixelData);
156223
157- return runInference (image, detectionThreshold);
224+ return runInference (image, detectionThreshold, iouThreshold, classIndices,
225+ methodName);
158226}
159227} // namespace rnexecutorch::models::object_detection
0 commit comments