Skip to content

Commit ab4c33c

Browse files
committed
Add Mateusz's patch
1 parent b3a9b89 commit ab4c33c

File tree

3 files changed

+45
-63
lines changed

3 files changed

+45
-63
lines changed

apps/computer-vision/app/vision_camera/index.tsx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,6 @@ const TASKS: Task[] = [
7474
{ id: 'segmentationSelfie', label: 'Selfie' },
7575
],
7676
},
77-
{
78-
id: 'objectDetection',
79-
label: 'Detect',
80-
variants: [
81-
{ id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' },
82-
{ id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' },
83-
],
84-
},
8577
{
8678
id: 'instanceSegmentation',
8779
label: 'Inst Seg',
@@ -90,6 +82,14 @@ const TASKS: Task[] = [
9082
{ id: 'instanceSegmentation_rfdetr', label: 'RF-DETR Nano Seg' },
9183
],
9284
},
85+
{
86+
id: 'objectDetection',
87+
label: 'Detect',
88+
variants: [
89+
{ id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' },
90+
{ id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' },
91+
],
92+
},
9393
];
9494

9595
// Module-level const so worklets in task components can always reference the same stable object.

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 31 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ cv::Size BaseInstanceSegmentation::modelInputSize() const {
4343
return {shape[shape.size() - 2], shape[shape.size() - 1]};
4444
}
4545

46+
TensorPtr BaseInstanceSegmentation::buildInputTensor(const cv::Mat &image) {
47+
cv::Mat preprocessed = preprocess(image);
48+
return (normMean_.has_value() && normStd_.has_value())
49+
? image_processing::getTensorFromMatrix(
50+
modelInputShape_, preprocessed, normMean_.value(),
51+
normStd_.value())
52+
: image_processing::getTensorFromMatrix(modelInputShape_,
53+
preprocessed);
54+
}
55+
4656
std::vector<types::Instance> BaseInstanceSegmentation::runInference(
4757
const cv::Mat &image, double confidenceThreshold, double iouThreshold,
4858
int32_t maxInstances, const std::vector<int32_t> &classIndices,
@@ -64,16 +74,8 @@ std::vector<types::Instance> BaseInstanceSegmentation::runInference(
6474
cv::Size modelInputSize(shape[shape.size() - 2], shape[shape.size() - 1]);
6575
cv::Size originalSize(image.cols, image.rows);
6676

67-
cv::Mat preprocessed = preprocess(image);
68-
69-
auto inputTensor = (normMean_.has_value() && normStd_.has_value())
70-
? image_processing::getTensorFromMatrix(
71-
modelInputShape_, preprocessed,
72-
normMean_.value(), normStd_.value())
73-
: image_processing::getTensorFromMatrix(
74-
modelInputShape_, preprocessed);
75-
76-
auto forwardResult = BaseModel::execute(methodName, {inputTensor});
77+
auto forwardResult =
78+
BaseModel::execute(methodName, {buildInputTensor(image)});
7779
if (!forwardResult.ok()) {
7880
throw RnExecutorchError(
7981
forwardResult.error(),
@@ -82,9 +84,13 @@ std::vector<types::Instance> BaseInstanceSegmentation::runInference(
8284
methodName + "' is valid.");
8385
}
8486

85-
return postprocess(forwardResult.get(), originalSize, modelInputSize,
86-
confidenceThreshold, iouThreshold, maxInstances,
87-
classIndices, returnMaskAtOriginalResolution);
87+
validateThresholds(confidenceThreshold, iouThreshold);
88+
validateOutputTensors(forwardResult.get());
89+
90+
auto instances = collectInstances(
91+
forwardResult.get(), originalSize, modelInputSize, confidenceThreshold,
92+
classIndices, returnMaskAtOriginalResolution);
93+
return finalizeInstances(std::move(instances), iouThreshold, maxInstances);
8894
}
8995

9096
std::vector<types::Instance> BaseInstanceSegmentation::generateFromString(
@@ -259,37 +265,16 @@ void BaseInstanceSegmentation::ensureMethodLoaded(
259265
methodName + "'");
260266
}
261267

262-
if (currentlyLoadedMethod_ != methodName) {
263-
if (!currentlyLoadedMethod_.empty()) {
264-
module_->unload_method(currentlyLoadedMethod_);
265-
}
266-
currentlyLoadedMethod_ = methodName;
267-
auto loadResult = module_->load_method(methodName);
268-
if (loadResult != executorch::runtime::Error::Ok) {
269-
throw RnExecutorchError(
270-
loadResult, "Failed to load method '" + methodName +
271-
"'. Ensure the method exists in the exported model.");
272-
}
268+
if (!currentlyLoadedMethod_.empty()) {
269+
module_->unload_method(currentlyLoadedMethod_);
273270
}
274-
}
275-
276-
cv::Size BaseInstanceSegmentation::getInputSize(const std::string &methodName) {
277-
auto inputShapes = getAllInputShapes(methodName);
278-
if (inputShapes.empty()) {
279-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
280-
"Method '" + methodName +
281-
"' has no input tensors.");
282-
}
283-
284-
const auto &inputShape = inputShapes[0];
285-
if (inputShape.empty()) {
286-
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
287-
"Method '" + methodName +
288-
"' input tensor has no dimensions.");
271+
currentlyLoadedMethod_ = methodName;
272+
auto loadResult = module_->load_method(methodName);
273+
if (loadResult != executorch::runtime::Error::Ok) {
274+
throw RnExecutorchError(
275+
loadResult, "Failed to load method '" + methodName +
276+
"'. Ensure the method exists in the exported model.");
289277
}
290-
291-
int32_t inputSize = inputShape[inputShape.size() - 1];
292-
return cv::Size(inputSize, inputSize);
293278
}
294279

295280
std::vector<types::Instance> BaseInstanceSegmentation::finalizeInstances(
@@ -308,15 +293,12 @@ std::vector<types::Instance> BaseInstanceSegmentation::finalizeInstances(
308293
return instances;
309294
}
310295

311-
std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
296+
std::vector<types::Instance> BaseInstanceSegmentation::collectInstances(
312297
const std::vector<EValue> &tensors, cv::Size originalSize,
313-
cv::Size modelInputSize, double confidenceThreshold, double iouThreshold,
314-
int32_t maxInstances, const std::vector<int32_t> &classIndices,
298+
cv::Size modelInputSize, double confidenceThreshold,
299+
const std::vector<int32_t> &classIndices,
315300
bool returnMaskAtOriginalResolution) {
316301

317-
validateThresholds(confidenceThreshold, iouThreshold);
318-
validateOutputTensors(tensors);
319-
320302
float widthRatio =
321303
static_cast<float>(originalSize.width) / modelInputSize.width;
322304
float heightRatio =
@@ -371,7 +353,7 @@ std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
371353
binaryMask.cols, binaryMask.rows, labelIdx, score);
372354
}
373355

374-
return finalizeInstances(std::move(instances), iouThreshold, maxInstances);
356+
return instances;
375357
}
376358

377359
} // namespace rnexecutorch::models::instance_segmentation

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,13 @@ class BaseInstanceSegmentation : public VisionModel {
5353
int32_t maxInstances, const std::vector<int32_t> &classIndices,
5454
bool returnMaskAtOriginalResolution, const std::string &methodName);
5555

56+
TensorPtr buildInputTensor(const cv::Mat &image);
57+
5658
std::vector<types::Instance>
57-
postprocess(const std::vector<EValue> &tensors, cv::Size originalSize,
58-
cv::Size modelInputSize, double confidenceThreshold,
59-
double iouThreshold, int32_t maxInstances,
60-
const std::vector<int32_t> &classIndices,
61-
bool returnMaskAtOriginalResolution);
59+
collectInstances(const std::vector<EValue> &tensors, cv::Size originalSize,
60+
cv::Size modelInputSize, double confidenceThreshold,
61+
const std::vector<int32_t> &classIndices,
62+
bool returnMaskAtOriginalResolution);
6263

6364
void validateThresholds(double confidenceThreshold,
6465
double iouThreshold) const;
@@ -69,7 +70,6 @@ class BaseInstanceSegmentation : public VisionModel {
6970

7071
// Model loading and input helpers
7172
void ensureMethodLoaded(const std::string &methodName);
72-
cv::Size getInputSize(const std::string &methodName);
7373

7474
std::tuple<utils::computer_vision::BBox, float, int32_t>
7575
extractDetectionData(const float *bboxData, const float *scoresData,

0 commit comments

Comments
 (0)