Skip to content

Commit 54a2e95

Browse files
committed
Migrate ObjectDetection to common CV utils
1 parent 8398ee9 commit 54a2e95

File tree

9 files changed

+37
-112
lines changed

9 files changed

+37
-112
lines changed

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -446,13 +446,8 @@ inline jsi::Value getJsiValue(
446446
jsi::Array array(runtime, detections.size());
447447
for (std::size_t i = 0; i < detections.size(); ++i) {
448448
jsi::Object detection(runtime);
449-
jsi::Object bbox(runtime);
450-
bbox.setProperty(runtime, "x1", detections[i].x1);
451-
bbox.setProperty(runtime, "y1", detections[i].y1);
452-
bbox.setProperty(runtime, "x2", detections[i].x2);
453-
bbox.setProperty(runtime, "y2", detections[i].y2);
454-
455-
detection.setProperty(runtime, "bbox", bbox);
449+
detection.setProperty(runtime, "bbox",
450+
getJsiValue(detections[i].bbox, runtime));
456451
detection.setProperty(
457452
runtime, "label",
458453
jsi::String::createFromUtf8(runtime, detections[i].label));
@@ -462,10 +457,10 @@ inline jsi::Value getJsiValue(
462457
return array;
463458
}
464459

465-
inline jsi::Value getJsiValue(
466-
const std::vector<models::instance_segmentation::types::InstanceMask>
467-
&instances,
468-
jsi::Runtime &runtime) {
460+
inline jsi::Value
461+
getJsiValue(const std::vector<models::instance_segmentation::types::Instance>
462+
&instances,
463+
jsi::Runtime &runtime) {
469464
jsi::Array array(runtime, instances.size());
470465
for (std::size_t i = 0; i < instances.size(); ++i) {
471466
jsi::Object instance(runtime);

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ cv::Mat BaseInstanceSegmentation::processMaskFromLogits(
9090
return finalBinaryMat;
9191
}
9292

93-
std::optional<types::InstanceMask> BaseInstanceSegmentation::processDetection(
93+
std::optional<types::Instance> BaseInstanceSegmentation::processDetection(
9494
int32_t detectionIndex, const float *bboxData, const float *scoresData,
9595
const float *maskData, int32_t maskH, int32_t maskW,
9696
cv::Size modelInputSize, cv::Size originalSize, float widthRatio,
@@ -140,13 +140,13 @@ std::optional<types::InstanceMask> BaseInstanceSegmentation::processDetection(
140140
std::vector<uint8_t> finalMask(finalBinaryMat.data,
141141
finalBinaryMat.data + finalBinaryMat.total());
142142

143-
return types::InstanceMask(
143+
return types::Instance(
144144
utils::computer_vision::BBox{origX1, origY1, origX2, origY2},
145145
std::move(finalMask), finalMaskWidth, finalMaskHeight,
146146
static_cast<int32_t>(labelIdx), score, i);
147147
}
148148

149-
std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
149+
std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
150150
const std::vector<EValue> &tensors, cv::Size originalSize,
151151
cv::Size modelInputSize, double confidenceThreshold, double iouThreshold,
152152
int32_t maxInstances, const std::vector<int32_t> &classIndices,
@@ -175,7 +175,7 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
175175
allowedClasses.insert(classIndices.begin(), classIndices.end());
176176
}
177177

178-
std::vector<types::InstanceMask> instances;
178+
std::vector<types::Instance> instances;
179179

180180
size_t numTensors = tensors.size();
181181
if (numTensors != 3) {
@@ -231,7 +231,7 @@ std::vector<types::InstanceMask> BaseInstanceSegmentation::postprocess(
231231
return instances;
232232
}
233233

234-
std::vector<types::InstanceMask> BaseInstanceSegmentation::generate(
234+
std::vector<types::Instance> BaseInstanceSegmentation::generate(
235235
std::string imageSource, double confidenceThreshold, double iouThreshold,
236236
int32_t maxInstances, std::vector<int32_t> classIndices,
237237
bool returnMaskAtOriginalResolution, std::string methodName) {

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ class BaseInstanceSegmentation : public BaseModel {
2222
std::vector<float> normStd, bool applyNMS,
2323
std::shared_ptr<react::CallInvoker> callInvoker);
2424

25-
[[nodiscard("Registered non-void function")]] std::vector<types::InstanceMask>
25+
[[nodiscard("Registered non-void function")]] std::vector<types::Instance>
2626
generate(std::string imageSource, double confidenceThreshold,
2727
double iouThreshold, int32_t maxInstances,
2828
std::vector<int32_t> classIndices,
2929
bool returnMaskAtOriginalResolution, std::string methodName);
3030

3131
private:
32-
std::vector<types::InstanceMask>
32+
std::vector<types::Instance>
3333
postprocess(const std::vector<EValue> &tensors, cv::Size originalSize,
3434
cv::Size modelInputSize, double confidenceThreshold,
3535
double iouThreshold, int32_t maxInstances,
@@ -43,7 +43,7 @@ class BaseInstanceSegmentation : public BaseModel {
4343
float origX1, float origY1, bool warpToOriginal,
4444
int32_t &outWidth, int32_t &outHeight);
4545

46-
std::optional<types::InstanceMask>
46+
std::optional<types::Instance>
4747
processDetection(int32_t detectionIndex, const float *bboxData,
4848
const float *scoresData, const float *maskData,
4949
int32_t maskH, int32_t maskW, cv::Size modelInputSize,

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace rnexecutorch::models::instance_segmentation::types {
1212
* Contains bounding box coordinates, binary segmentation mask, class label,
1313
* confidence score, and a unique instance identifier.
1414
*/
15-
struct InstanceMask {
15+
struct Instance {
1616
utils::computer_vision::BBox bbox; ///< Bounding box coordinates
1717
std::vector<uint8_t> mask; ///< Binary mask (0 or 1) for the instance
1818
int32_t maskWidth; ///< Width of the mask array
@@ -21,15 +21,10 @@ struct InstanceMask {
2121
float score; ///< Confidence score [0, 1]
2222
int32_t instanceId; ///< Unique identifier for this instance
2323

24-
InstanceMask() = default;
25-
InstanceMask(const InstanceMask &) = default;
26-
InstanceMask(InstanceMask &&) = default;
27-
InstanceMask &operator=(const InstanceMask &) = default;
28-
InstanceMask &operator=(InstanceMask &&) = default;
29-
30-
InstanceMask(utils::computer_vision::BBox bbox, std::vector<uint8_t> mask,
31-
int32_t maskWidth, int32_t maskHeight, int32_t classIndex,
32-
float score, int32_t instanceId)
24+
Instance() = default;
25+
Instance(utils::computer_vision::BBox bbox, std::vector<uint8_t> mask,
26+
int32_t maskWidth, int32_t maskHeight, int32_t classIndex,
27+
float score, int32_t instanceId)
3328
: bbox(bbox), mask(std::move(mask)), maskWidth(maskWidth),
3429
maskHeight(maskHeight), classIndex(classIndex), score(score),
3530
instanceId(instanceId) {}

packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#include "ObjectDetection.h"
2+
#include "Constants.h"
23

34
#include <rnexecutorch/Error.h>
45
#include <rnexecutorch/ErrorCodes.h>
56
#include <rnexecutorch/Log.h>
67
#include <rnexecutorch/data_processing/ImageProcessing.h>
78
#include <rnexecutorch/host_objects/JsiConversions.h>
89
#include <rnexecutorch/utils/FrameProcessor.h>
10+
#include <rnexecutorch/utils/computer_vision/Processing.h>
911

1012
namespace rnexecutorch::models::object_detection {
1113

@@ -119,10 +121,13 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
119121
" exceeds labelNames size " + std::to_string(labelNames_.size()) +
120122
". Ensure the labelMap covers all model output classes.");
121123
}
122-
detections.emplace_back(x1, y1, x2, y2, labelNames_[labelIdx], scores[i]);
124+
detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2},
125+
labelNames_[labelIdx],
126+
static_cast<int32_t>(labelIdx), scores[i]);
123127
}
124128

125-
return utils::nonMaxSuppression(detections);
129+
return utils::computer_vision::nonMaxSuppression(detections,
130+
constants::IOU_THRESHOLD);
126131
}
127132

128133
std::vector<types::Detection>

packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "Types.h"
99
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
1010
#include <rnexecutorch/models/VisionModel.h>
11-
#include <rnexecutorch/models/object_detection/Utils.h>
1211

1312
namespace rnexecutorch {
1413
namespace models::object_detection {
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
#pragma once
22

3+
#include <cstdint>
4+
#include <rnexecutorch/utils/computer_vision/Types.h>
35
#include <string>
46

57
namespace rnexecutorch::models::object_detection::types {
68
struct Detection {
7-
float x1;
8-
float y1;
9-
float x2;
10-
float y2;
9+
utils::computer_vision::BBox bbox;
1110
std::string label;
11+
int32_t classIndex;
1212
float score;
13+
14+
Detection() = default;
15+
Detection(utils::computer_vision::BBox bbox, std::string label,
16+
int32_t classIndex, float score)
17+
: bbox(bbox), label(std::move(label)), classIndex(classIndex),
18+
score(score) {}
1319
};
1420

15-
} // namespace rnexecutorch::models::object_detection::types
21+
} // namespace rnexecutorch::models::object_detection::types

packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.cpp

Lines changed: 0 additions & 64 deletions
This file was deleted.

packages/react-native-executorch/common/rnexecutorch/models/object_detection/Utils.h

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)