Skip to content

Commit 42aaf02

Browse files
committed
feat: Add multimethod handling to ObjectDetection
1 parent 992d04b commit 42aaf02

File tree

9 files changed

+504
-78
lines changed

9 files changed

+504
-78
lines changed

apps/computer-vision/app/object_detection/index.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
useObjectDetection,
88
RF_DETR_NANO,
99
SSDLITE_320_MOBILENET_V3_LARGE,
10+
YOLO26N,
1011
ObjectDetectionModelSources,
1112
} from 'react-native-executorch';
1213
import { View, StyleSheet, Image } from 'react-native';
@@ -18,6 +19,7 @@ import ScreenWrapper from '../../ScreenWrapper';
1819
const MODELS: ModelOption<ObjectDetectionModelSources>[] = [
1920
{ label: 'RF-DeTR Nano', value: RF_DETR_NANO },
2021
{ label: 'SSDLite MobileNet', value: SSDLITE_320_MOBILENET_V3_LARGE },
22+
{ label: 'YOLO26N', value: YOLO26N },
2123
];
2224

2325
export default function ObjectDetectionScreen() {

apps/computer-vision/app/vision_camera/index.tsx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type ModelId =
4646
| 'classification'
4747
| 'objectDetectionSsdlite'
4848
| 'objectDetectionRfdetr'
49+
| 'objectDetectionYolo26n'
4950
| 'segmentationDeeplabResnet50'
5051
| 'segmentationDeeplabResnet101'
5152
| 'segmentationDeeplabMobilenet'
@@ -95,6 +96,7 @@ const TASKS: Task[] = [
9596
variants: [
9697
{ id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' },
9798
{ id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' },
99+
{ id: 'objectDetectionYolo26n', label: 'YOLO26N' },
98100
],
99101
},
100102
{
@@ -241,7 +243,10 @@ export default function VisionCameraScreen() {
241243
<ObjectDetectionTask
242244
{...taskProps}
243245
activeModel={
244-
activeModel as 'objectDetectionSsdlite' | 'objectDetectionRfdetr'
246+
activeModel as
247+
| 'objectDetectionSsdlite'
248+
| 'objectDetectionRfdetr'
249+
| 'objectDetectionYolo26n'
245250
}
246251
/>
247252
)}

apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@ import {
66
Detection,
77
RF_DETR_NANO,
88
SSDLITE_320_MOBILENET_V3_LARGE,
9+
YOLO26N,
910
useObjectDetection,
1011
} from 'react-native-executorch';
1112
import { labelColor, labelColorBg } from '../utils/colors';
1213
import { TaskProps } from './types';
1314

14-
type ObjModelId = 'objectDetectionSsdlite' | 'objectDetectionRfdetr';
15+
type ObjModelId =
16+
| 'objectDetectionSsdlite'
17+
| 'objectDetectionRfdetr'
18+
| 'objectDetectionYolo26n';
1519

1620
type Props = TaskProps & { activeModel: ObjModelId };
1721

@@ -34,8 +38,17 @@ export default function ObjectDetectionTask({
3438
model: RF_DETR_NANO,
3539
preventLoad: activeModel !== 'objectDetectionRfdetr',
3640
});
41+
const yolo26n = useObjectDetection({
42+
model: YOLO26N,
43+
preventLoad: activeModel !== 'objectDetectionYolo26n',
44+
});
3745

38-
const active = activeModel === 'objectDetectionSsdlite' ? ssdlite : rfdetr;
46+
const active =
47+
activeModel === 'objectDetectionSsdlite'
48+
? ssdlite
49+
: activeModel === 'objectDetectionRfdetr'
50+
? rfdetr
51+
: yolo26n;
3952

4053
const [detections, setDetections] = useState<Detection[]>([]);
4154
const [imageSize, setImageSize] = useState({ width: 1, height: 1 });
Lines changed: 113 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "ObjectDetection.h"
22
#include "Constants.h"
33

4+
#include <set>
5+
46
#include <rnexecutorch/Error.h>
57
#include <rnexecutorch/ErrorCodes.h>
68
#include <rnexecutorch/Log.h>
@@ -18,21 +20,6 @@ ObjectDetection::ObjectDetection(
1820
std::shared_ptr<react::CallInvoker> callInvoker)
1921
: VisionModel(modelSource, callInvoker),
2022
labelNames_(std::move(labelNames)) {
21-
auto inputTensors = getAllInputShapes();
22-
if (inputTensors.empty()) {
23-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
24-
"Model seems to not take any input tensors.");
25-
}
26-
modelInputShape_ = inputTensors[0];
27-
if (modelInputShape_.size() < 2) {
28-
char errorMessage[100];
29-
std::snprintf(errorMessage, sizeof(errorMessage),
30-
"Unexpected model input size, expected at least 2 dimensions "
31-
"but got: %zu.",
32-
modelInputShape_.size());
33-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
34-
errorMessage);
35-
}
3623
if (normMean.size() == 3) {
3724
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
3825
} else if (!normMean.empty()) {
@@ -47,14 +34,65 @@ ObjectDetection::ObjectDetection(
4734
}
4835
}
4936

37+
/// Reports the spatial input size of the currently loaded method.
///
/// Falls back to the base-class notion of input size when no method has
/// been loaded yet, or when the method's first input shape cannot be
/// interpreted (missing, or fewer than 2 dimensions).
///
/// NOTE(review): this packs the last two dimensions as
/// cv::Size{dims[rank-2], dims[rank-1]}; cv::Size is (width, height), so for
/// an NCHW input this yields (H, W) — harmless for square inputs, but worth
/// confirming against VisionModel's convention for non-square models.
cv::Size ObjectDetection::modelInputSize() const {
  if (!currentlyLoadedMethod_.empty()) {
    const auto shapes = getAllInputShapes(currentlyLoadedMethod_);
    if (!shapes.empty() && shapes[0].size() >= 2) {
      const auto &dims = shapes[0];
      const auto rank = dims.size();
      return {static_cast<int>(dims[rank - 2]),
              static_cast<int>(dims[rank - 1])};
    }
  }
  // No method loaded, or shape unusable: defer to the base implementation.
  return VisionModel::modelInputSize();
}
49+
50+
void ObjectDetection::ensureMethodLoaded(const std::string &methodName) {
51+
if (methodName.empty()) {
52+
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
53+
"methodName cannot be empty");
54+
}
55+
if (currentlyLoadedMethod_ == methodName) {
56+
return;
57+
}
58+
if (!module_) {
59+
throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
60+
"Model module is not loaded");
61+
}
62+
if (!currentlyLoadedMethod_.empty()) {
63+
module_->unload_method(currentlyLoadedMethod_);
64+
}
65+
auto loadResult = module_->load_method(methodName);
66+
if (loadResult != executorch::runtime::Error::Ok) {
67+
throw RnExecutorchError(
68+
loadResult, "Failed to load method '" + methodName +
69+
"'. Ensure the method exists in the exported model.");
70+
}
71+
currentlyLoadedMethod_ = methodName;
72+
}
73+
74+
std::set<int32_t> ObjectDetection::prepareAllowedClasses(
75+
const std::vector<int32_t> &classIndices) const {
76+
std::set<int32_t> allowedClasses;
77+
if (!classIndices.empty()) {
78+
allowedClasses.insert(classIndices.begin(), classIndices.end());
79+
}
80+
return allowedClasses;
81+
}
82+
5083
std::vector<types::Detection>
5184
ObjectDetection::postprocess(const std::vector<EValue> &tensors,
52-
cv::Size originalSize, double detectionThreshold) {
85+
cv::Size originalSize, double detectionThreshold,
86+
double iouThreshold,
87+
const std::vector<int32_t> &classIndices) {
5388
const cv::Size inputSize = modelInputSize();
5489
float widthRatio = static_cast<float>(originalSize.width) / inputSize.width;
5590
float heightRatio =
5691
static_cast<float>(originalSize.height) / inputSize.height;
5792

93+
// Prepare allowed classes set for filtering
94+
auto allowedClasses = prepareAllowedClasses(classIndices);
95+
5896
std::vector<types::Detection> detections;
5997
auto bboxTensor = tensors.at(0).toTensor();
6098
std::span<const float> bboxes(
@@ -75,36 +113,62 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
75113
if (scores[i] < detectionThreshold) {
76114
continue;
77115
}
116+
117+
auto labelIdx = static_cast<int32_t>(labels[i]);
118+
119+
// Filter by class if classesOfInterest is specified
120+
if (!allowedClasses.empty() &&
121+
allowedClasses.find(labelIdx) == allowedClasses.end()) {
122+
continue;
123+
}
124+
78125
float x1 = bboxes[i * 4] * widthRatio;
79126
float y1 = bboxes[i * 4 + 1] * heightRatio;
80127
float x2 = bboxes[i * 4 + 2] * widthRatio;
81128
float y2 = bboxes[i * 4 + 3] * heightRatio;
82-
auto labelIdx = static_cast<std::size_t>(labels[i]);
83-
if (labelIdx >= labelNames_.size()) {
129+
130+
if (static_cast<std::size_t>(labelIdx) >= labelNames_.size()) {
84131
throw RnExecutorchError(
85132
RnExecutorchErrorCode::InvalidConfig,
86133
"Model output class index " + std::to_string(labelIdx) +
87134
" exceeds labelNames size " + std::to_string(labelNames_.size()) +
88135
". Ensure the labelMap covers all model output classes.");
89136
}
90137
detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2},
91-
labelNames_[labelIdx],
92-
static_cast<int32_t>(labelIdx), scores[i]);
138+
labelNames_[labelIdx], labelIdx, scores[i]);
93139
}
94140

95-
return utils::computer_vision::nonMaxSuppression(detections,
96-
constants::IOU_THRESHOLD);
141+
return utils::computer_vision::nonMaxSuppression(detections, iouThreshold);
97142
}
98143

99-
std::vector<types::Detection>
100-
ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
144+
std::vector<types::Detection> ObjectDetection::runInference(
145+
cv::Mat image, double detectionThreshold, double iouThreshold,
146+
const std::vector<int32_t> &classIndices, const std::string &methodName) {
101147
if (detectionThreshold < 0.0 || detectionThreshold > 1.0) {
102148
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
103149
"detectionThreshold must be in range [0, 1]");
104150
}
151+
if (iouThreshold < 0.0 || iouThreshold > 1.0) {
152+
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
153+
"iouThreshold must be in range [0, 1]");
154+
}
155+
105156
std::scoped_lock lock(inference_mutex_);
106157

158+
// Ensure the correct method is loaded
159+
ensureMethodLoaded(methodName);
160+
107161
cv::Size originalSize = image.size();
162+
163+
// Query input shapes for the currently loaded method
164+
auto inputShapes = getAllInputShapes(methodName);
165+
if (inputShapes.empty() || inputShapes[0].size() < 2) {
166+
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
167+
"Could not determine input shape for method: " +
168+
methodName);
169+
}
170+
modelInputShape_ = inputShapes[0];
171+
108172
cv::Mat preprocessed = preprocess(image);
109173

110174
auto inputTensor =
@@ -114,46 +178,50 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
114178
: image_processing::getTensorFromMatrix(modelInputShape_,
115179
preprocessed);
116180

117-
auto forwardResult = BaseModel::forward(inputTensor);
118-
if (!forwardResult.ok()) {
119-
throw RnExecutorchError(forwardResult.error(),
120-
"The model's forward function did not succeed. "
121-
"Ensure the model input is correct.");
181+
auto executeResult = execute(methodName, {inputTensor});
182+
if (!executeResult.ok()) {
183+
throw RnExecutorchError(executeResult.error(),
184+
"The model's " + methodName +
185+
" method did not succeed. "
186+
"Ensure the model input is correct.");
122187
}
123188

124-
return postprocess(forwardResult.get(), originalSize, detectionThreshold);
189+
return postprocess(executeResult.get(), originalSize, detectionThreshold,
190+
iouThreshold, classIndices);
125191
}
126192

127-
std::vector<types::Detection>
128-
ObjectDetection::generateFromString(std::string imageSource,
129-
double detectionThreshold) {
193+
std::vector<types::Detection> ObjectDetection::generateFromString(
194+
std::string imageSource, double detectionThreshold, double iouThreshold,
195+
std::vector<int32_t> classIndices, std::string methodName) {
130196
cv::Mat imageBGR = image_processing::readImage(imageSource);
131197

132198
cv::Mat imageRGB;
133199
cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
134200

135-
return runInference(imageRGB, detectionThreshold);
201+
return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices,
202+
methodName);
136203
}
137204

138-
std::vector<types::Detection>
139-
ObjectDetection::generateFromFrame(jsi::Runtime &runtime,
140-
const jsi::Value &frameData,
141-
double detectionThreshold) {
142-
auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData);
205+
/// Runs object detection on a camera frame delivered via JSI.
///
/// The frame is rotated to the model's expected orientation before
/// inference, and the resulting boxes are rotated back into the original
/// frame's coordinate space afterwards.
///
/// @param runtime JSI runtime the frame data lives in.
/// @param frameData JS value describing the camera frame.
/// @param detectionThreshold minimum score for a detection to be kept.
/// @param iouThreshold IoU threshold used for non-max suppression.
/// @param classIndices class indices to keep; empty means keep all.
/// @param methodName exported model method to run.
/// @return detections mapped back to the original frame orientation.
std::vector<types::Detection> ObjectDetection::generateFromFrame(
    jsi::Runtime &runtime, const jsi::Value &frameData,
    double detectionThreshold, double iouThreshold,
    std::vector<int32_t> classIndices, std::string methodName) {
  // Read the device orientation attached to the frame so the output boxes
  // can be mapped back after inference. These two locals were dropped in a
  // refactor while the inverseRotateBbox loop below still referenced them,
  // which does not compile; they are restored here.
  auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData);
  cv::Mat frame = extractFromFrame(runtime, frameData);
  cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient);

  auto detections = runInference(rotated, detectionThreshold, iouThreshold,
                                 classIndices, methodName);

  // Boxes are produced in the rotated (model-input) frame; rotate them back
  // into the original frame's coordinate space.
  for (auto &det : detections) {
    ::rnexecutorch::utils::inverseRotateBbox(det.bbox, orient, rotated.size());
  }
  return detections;
}
151218

152-
std::vector<types::Detection>
153-
ObjectDetection::generateFromPixels(JSTensorViewIn pixelData,
154-
double detectionThreshold) {
219+
/// Runs object detection on a raw pixel buffer handed over from JS.
///
/// @param pixelData tensor view over the caller's pixel buffer.
/// @param detectionThreshold minimum score for a detection to be kept.
/// @param iouThreshold IoU threshold used for non-max suppression.
/// @param classIndices class indices to keep; empty means keep all.
/// @param methodName exported model method to run.
/// @return detections in the buffer's original coordinate space.
std::vector<types::Detection> ObjectDetection::generateFromPixels(
    JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold,
    std::vector<int32_t> classIndices, std::string methodName) {
  // No colour conversion here, unlike generateFromString — presumably the
  // pixel buffer already arrives in the channel order the model expects.
  cv::Mat image = extractFromPixels(pixelData);
  return runInference(image, detectionThreshold, iouThreshold, classIndices,
                      methodName);
}
159227
} // namespace rnexecutorch::models::object_detection

0 commit comments

Comments (0)