Skip to content

Commit 22bdbfc

Browse files
committed
feat: Add multimethod handling to ObjectDetection
1 parent e86ec7b commit 22bdbfc

File tree

9 files changed

+500
-72
lines changed

9 files changed

+500
-72
lines changed

apps/computer-vision/app/object_detection/index.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
useObjectDetection,
88
RF_DETR_NANO,
99
SSDLITE_320_MOBILENET_V3_LARGE,
10+
YOLO26N,
1011
ObjectDetectionModelSources,
1112
} from 'react-native-executorch';
1213
import { View, StyleSheet, Image } from 'react-native';
@@ -18,6 +19,7 @@ import ScreenWrapper from '../../ScreenWrapper';
1819
const MODELS: ModelOption<ObjectDetectionModelSources>[] = [
1920
{ label: 'RF-DeTR Nano', value: RF_DETR_NANO },
2021
{ label: 'SSDLite MobileNet', value: SSDLITE_320_MOBILENET_V3_LARGE },
22+
{ label: 'YOLO26N', value: YOLO26N },
2123
];
2224

2325
export default function ObjectDetectionScreen() {

apps/computer-vision/app/vision_camera/index.tsx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type ModelId =
4242
| 'classification'
4343
| 'objectDetectionSsdlite'
4444
| 'objectDetectionRfdetr'
45+
| 'objectDetectionYolo26n'
4546
| 'segmentationDeeplabResnet50'
4647
| 'segmentationDeeplabResnet101'
4748
| 'segmentationDeeplabMobilenet'
@@ -88,6 +89,7 @@ const TASKS: Task[] = [
8889
variants: [
8990
{ id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' },
9091
{ id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' },
92+
{ id: 'objectDetectionYolo26n', label: 'YOLO26N' },
9193
],
9294
},
9395
];
@@ -216,7 +218,10 @@ export default function VisionCameraScreen() {
216218
<ObjectDetectionTask
217219
{...taskProps}
218220
activeModel={
219-
activeModel as 'objectDetectionSsdlite' | 'objectDetectionRfdetr'
221+
activeModel as
222+
| 'objectDetectionSsdlite'
223+
| 'objectDetectionRfdetr'
224+
| 'objectDetectionYolo26n'
220225
}
221226
/>
222227
)}

apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@ import {
66
Detection,
77
RF_DETR_NANO,
88
SSDLITE_320_MOBILENET_V3_LARGE,
9+
YOLO26N,
910
useObjectDetection,
1011
} from 'react-native-executorch';
1112
import { labelColor, labelColorBg } from '../utils/colors';
1213
import { TaskProps } from './types';
1314

14-
type ObjModelId = 'objectDetectionSsdlite' | 'objectDetectionRfdetr';
15+
type ObjModelId =
16+
| 'objectDetectionSsdlite'
17+
| 'objectDetectionRfdetr'
18+
| 'objectDetectionYolo26n';
1519

1620
type Props = TaskProps & { activeModel: ObjModelId };
1721

@@ -34,8 +38,17 @@ export default function ObjectDetectionTask({
3438
model: RF_DETR_NANO,
3539
preventLoad: activeModel !== 'objectDetectionRfdetr',
3640
});
41+
const yolo26n = useObjectDetection({
42+
model: YOLO26N,
43+
preventLoad: activeModel !== 'objectDetectionYolo26n',
44+
});
3745

38-
const active = activeModel === 'objectDetectionSsdlite' ? ssdlite : rfdetr;
46+
const active =
47+
activeModel === 'objectDetectionSsdlite'
48+
? ssdlite
49+
: activeModel === 'objectDetectionRfdetr'
50+
? rfdetr
51+
: yolo26n;
3952

4053
const [detections, setDetections] = useState<Detection[]>([]);
4154
const [imageSize, setImageSize] = useState({ width: 1, height: 1 });
@@ -81,10 +94,10 @@ export default function ObjectDetectionTask({
8194
if (!detRof) return;
8295
const iw = frame.width > frame.height ? frame.height : frame.width;
8396
const ih = frame.width > frame.height ? frame.width : frame.height;
84-
const result = detRof(frame, 0.5);
97+
const result = detRof(frame, { detectionThreshold: 0.5 });
8598
if (result) {
8699
scheduleOnRN(updateDetections, {
87-
results: result,
100+
results: result as Detection[],
88101
imageWidth: iw,
89102
imageHeight: ih,
90103
});
Lines changed: 112 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "ObjectDetection.h"
22
#include "Constants.h"
33

4+
#include <set>
5+
46
#include <rnexecutorch/Error.h>
57
#include <rnexecutorch/ErrorCodes.h>
68
#include <rnexecutorch/Log.h>
@@ -17,21 +19,6 @@ ObjectDetection::ObjectDetection(
1719
std::shared_ptr<react::CallInvoker> callInvoker)
1820
: VisionModel(modelSource, callInvoker),
1921
labelNames_(std::move(labelNames)) {
20-
auto inputTensors = getAllInputShapes();
21-
if (inputTensors.empty()) {
22-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
23-
"Model seems to not take any input tensors.");
24-
}
25-
modelInputShape_ = inputTensors[0];
26-
if (modelInputShape_.size() < 2) {
27-
char errorMessage[100];
28-
std::snprintf(errorMessage, sizeof(errorMessage),
29-
"Unexpected model input size, expected at least 2 dimensions "
30-
"but got: %zu.",
31-
modelInputShape_.size());
32-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
33-
errorMessage);
34-
}
3522
if (normMean.size() == 3) {
3623
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
3724
} else if (!normMean.empty()) {
@@ -46,14 +33,65 @@ ObjectDetection::ObjectDetection(
4633
}
4734
}
4835

36+
cv::Size ObjectDetection::modelInputSize() const {
37+
if (currentlyLoadedMethod_.empty()) {
38+
return VisionModel::modelInputSize();
39+
}
40+
auto inputShapes = getAllInputShapes(currentlyLoadedMethod_);
41+
if (inputShapes.empty() || inputShapes[0].size() < 2) {
42+
return VisionModel::modelInputSize();
43+
}
44+
const auto &shape = inputShapes[0];
45+
return {static_cast<int>(shape[shape.size() - 2]),
46+
static_cast<int>(shape[shape.size() - 1])};
47+
}
48+
49+
void ObjectDetection::ensureMethodLoaded(const std::string &methodName) {
50+
if (methodName.empty()) {
51+
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
52+
"methodName cannot be empty");
53+
}
54+
if (currentlyLoadedMethod_ == methodName) {
55+
return;
56+
}
57+
if (!module_) {
58+
throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
59+
"Model module is not loaded");
60+
}
61+
if (!currentlyLoadedMethod_.empty()) {
62+
module_->unload_method(currentlyLoadedMethod_);
63+
}
64+
auto loadResult = module_->load_method(methodName);
65+
if (loadResult != executorch::runtime::Error::Ok) {
66+
throw RnExecutorchError(
67+
loadResult, "Failed to load method '" + methodName +
68+
"'. Ensure the method exists in the exported model.");
69+
}
70+
currentlyLoadedMethod_ = methodName;
71+
}
72+
73+
std::set<int32_t> ObjectDetection::prepareAllowedClasses(
74+
const std::vector<int32_t> &classIndices) const {
75+
std::set<int32_t> allowedClasses;
76+
if (!classIndices.empty()) {
77+
allowedClasses.insert(classIndices.begin(), classIndices.end());
78+
}
79+
return allowedClasses;
80+
}
81+
4982
std::vector<types::Detection>
5083
ObjectDetection::postprocess(const std::vector<EValue> &tensors,
51-
cv::Size originalSize, double detectionThreshold) {
84+
cv::Size originalSize, double detectionThreshold,
85+
double iouThreshold,
86+
const std::vector<int32_t> &classIndices) {
5287
const cv::Size inputSize = modelInputSize();
5388
float widthRatio = static_cast<float>(originalSize.width) / inputSize.width;
5489
float heightRatio =
5590
static_cast<float>(originalSize.height) / inputSize.height;
5691

92+
// Prepare allowed classes set for filtering
93+
auto allowedClasses = prepareAllowedClasses(classIndices);
94+
5795
std::vector<types::Detection> detections;
5896
auto bboxTensor = tensors.at(0).toTensor();
5997
std::span<const float> bboxes(
@@ -74,36 +112,62 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
74112
if (scores[i] < detectionThreshold) {
75113
continue;
76114
}
115+
116+
auto labelIdx = static_cast<int32_t>(labels[i]);
117+
118+
// Filter by class if classesOfInterest is specified
119+
if (!allowedClasses.empty() &&
120+
allowedClasses.find(labelIdx) == allowedClasses.end()) {
121+
continue;
122+
}
123+
77124
float x1 = bboxes[i * 4] * widthRatio;
78125
float y1 = bboxes[i * 4 + 1] * heightRatio;
79126
float x2 = bboxes[i * 4 + 2] * widthRatio;
80127
float y2 = bboxes[i * 4 + 3] * heightRatio;
81-
auto labelIdx = static_cast<std::size_t>(labels[i]);
82-
if (labelIdx >= labelNames_.size()) {
128+
129+
if (static_cast<std::size_t>(labelIdx) >= labelNames_.size()) {
83130
throw RnExecutorchError(
84131
RnExecutorchErrorCode::InvalidConfig,
85132
"Model output class index " + std::to_string(labelIdx) +
86133
" exceeds labelNames size " + std::to_string(labelNames_.size()) +
87134
". Ensure the labelMap covers all model output classes.");
88135
}
89136
detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2},
90-
labelNames_[labelIdx],
91-
static_cast<int32_t>(labelIdx), scores[i]);
137+
labelNames_[labelIdx], labelIdx, scores[i]);
92138
}
93139

94-
return utils::computer_vision::nonMaxSuppression(detections,
95-
constants::IOU_THRESHOLD);
140+
return utils::computer_vision::nonMaxSuppression(detections, iouThreshold);
96141
}
97142

98-
std::vector<types::Detection>
99-
ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
143+
std::vector<types::Detection> ObjectDetection::runInference(
144+
cv::Mat image, double detectionThreshold, double iouThreshold,
145+
const std::vector<int32_t> &classIndices, const std::string &methodName) {
100146
if (detectionThreshold < 0.0 || detectionThreshold > 1.0) {
101147
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
102148
"detectionThreshold must be in range [0, 1]");
103149
}
150+
if (iouThreshold < 0.0 || iouThreshold > 1.0) {
151+
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
152+
"iouThreshold must be in range [0, 1]");
153+
}
154+
104155
std::scoped_lock lock(inference_mutex_);
105156

157+
// Ensure the correct method is loaded
158+
ensureMethodLoaded(methodName);
159+
106160
cv::Size originalSize = image.size();
161+
162+
// Query input shapes for the currently loaded method
163+
auto inputShapes = getAllInputShapes(methodName);
164+
if (inputShapes.empty() || inputShapes[0].size() < 2) {
165+
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
166+
"Could not determine input shape for method: " +
167+
methodName);
168+
}
169+
modelInputShape_ = inputShapes[0];
170+
107171
cv::Mat preprocessed = preprocess(image);
108172

109173
auto inputTensor =
@@ -113,40 +177,45 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
113177
: image_processing::getTensorFromMatrix(modelInputShape_,
114178
preprocessed);
115179

116-
auto forwardResult = BaseModel::forward(inputTensor);
117-
if (!forwardResult.ok()) {
118-
throw RnExecutorchError(forwardResult.error(),
119-
"The model's forward function did not succeed. "
120-
"Ensure the model input is correct.");
180+
auto executeResult = execute(methodName, {inputTensor});
181+
if (!executeResult.ok()) {
182+
throw RnExecutorchError(executeResult.error(),
183+
"The model's " + methodName +
184+
" method did not succeed. "
185+
"Ensure the model input is correct.");
121186
}
122187

123-
return postprocess(forwardResult.get(), originalSize, detectionThreshold);
188+
return postprocess(executeResult.get(), originalSize, detectionThreshold,
189+
iouThreshold, classIndices);
124190
}
125191

126-
std::vector<types::Detection>
127-
ObjectDetection::generateFromString(std::string imageSource,
128-
double detectionThreshold) {
192+
std::vector<types::Detection> ObjectDetection::generateFromString(
193+
std::string imageSource, double detectionThreshold, double iouThreshold,
194+
std::vector<int32_t> classIndices, std::string methodName) {
129195
cv::Mat imageBGR = image_processing::readImage(imageSource);
130196

131197
cv::Mat imageRGB;
132198
cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
133199

134-
return runInference(imageRGB, detectionThreshold);
200+
return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices,
201+
methodName);
135202
}
136203

137-
std::vector<types::Detection>
138-
ObjectDetection::generateFromFrame(jsi::Runtime &runtime,
139-
const jsi::Value &frameData,
140-
double detectionThreshold) {
204+
std::vector<types::Detection> ObjectDetection::generateFromFrame(
205+
jsi::Runtime &runtime, const jsi::Value &frameData,
206+
double detectionThreshold, double iouThreshold,
207+
std::vector<int32_t> classIndices, std::string methodName) {
141208
cv::Mat frame = extractFromFrame(runtime, frameData);
142-
return runInference(frame, detectionThreshold);
209+
return runInference(frame, detectionThreshold, iouThreshold, classIndices,
210+
methodName);
143211
}
144212

145-
std::vector<types::Detection>
146-
ObjectDetection::generateFromPixels(JSTensorViewIn pixelData,
147-
double detectionThreshold) {
213+
std::vector<types::Detection> ObjectDetection::generateFromPixels(
214+
JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold,
215+
std::vector<int32_t> classIndices, std::string methodName) {
148216
cv::Mat image = extractFromPixels(pixelData);
149217

150-
return runInference(image, detectionThreshold);
218+
return runInference(image, detectionThreshold, iouThreshold, classIndices,
219+
methodName);
151220
}
152221
} // namespace rnexecutorch::models::object_detection

0 commit comments

Comments (0)