Skip to content

Commit 030f124

Browse files
committed
chore: address suggestions from code review
1 parent 61d9548 commit 030f124

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ template <> struct ModelTraits<ObjectDetection> {
5050
}
5151

5252
static void callGenerate(ModelType &model) {
53-
(void)model.generateFromString(kValidTestImagePath, 0.5);
53+
(void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {},
54+
"forward");
5455
}
5556
};
5657
} // namespace model_tests
@@ -67,57 +68,65 @@ INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest,
6768
TEST(ObjectDetectionGenerateTests, InvalidImagePathThrows) {
6869
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
6970
nullptr);
70-
EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5),
71+
EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5,
72+
0.55, {}, "forward"),
7173
RnExecutorchError);
7274
}
7375

7476
TEST(ObjectDetectionGenerateTests, EmptyImagePathThrows) {
7577
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
7678
nullptr);
77-
EXPECT_THROW((void)model.generateFromString("", 0.5), RnExecutorchError);
79+
EXPECT_THROW((void)model.generateFromString("", 0.5, 0.55, {}, "forward"),
80+
RnExecutorchError);
7881
}
7982

8083
TEST(ObjectDetectionGenerateTests, MalformedURIThrows) {
8184
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
8285
nullptr);
83-
EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5),
86+
EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5,
87+
0.55, {}, "forward"),
8488
RnExecutorchError);
8589
}
8690

8791
TEST(ObjectDetectionGenerateTests, NegativeThresholdThrows) {
8892
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
8993
nullptr);
90-
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1),
94+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.55,
95+
{}, "forward"),
9196
RnExecutorchError);
9297
}
9398

9499
TEST(ObjectDetectionGenerateTests, ThresholdAboveOneThrows) {
95100
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
96101
nullptr);
97-
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1),
102+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.55,
103+
{}, "forward"),
98104
RnExecutorchError);
99105
}
100106

101107
TEST(ObjectDetectionGenerateTests, ValidImageReturnsResults) {
102108
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
103109
nullptr);
104-
auto results = model.generateFromString(kValidTestImagePath, 0.3);
110+
auto results =
111+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward");
105112
EXPECT_GE(results.size(), 0u);
106113
}
107114

108115
TEST(ObjectDetectionGenerateTests, HighThresholdReturnsFewerResults) {
109116
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
110117
nullptr);
111-
auto lowThresholdResults = model.generateFromString(kValidTestImagePath, 0.1);
118+
auto lowThresholdResults =
119+
model.generateFromString(kValidTestImagePath, 0.1, 0.55, {}, "forward");
112120
auto highThresholdResults =
113-
model.generateFromString(kValidTestImagePath, 0.9);
121+
model.generateFromString(kValidTestImagePath, 0.9, 0.55, {}, "forward");
114122
EXPECT_GE(lowThresholdResults.size(), highThresholdResults.size());
115123
}
116124

117125
TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) {
118126
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
119127
nullptr);
120-
auto results = model.generateFromString(kValidTestImagePath, 0.3);
128+
auto results =
129+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward");
121130

122131
for (const auto &detection : results) {
123132
EXPECT_LE(detection.bbox.x1, detection.bbox.x2);
@@ -130,7 +139,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) {
130139
TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) {
131140
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
132141
nullptr);
133-
auto results = model.generateFromString(kValidTestImagePath, 0.3);
142+
auto results =
143+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward");
134144

135145
for (const auto &detection : results) {
136146
EXPECT_GE(detection.score, 0.0f);
@@ -141,7 +151,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) {
141151
TEST(ObjectDetectionGenerateTests, DetectionsHaveValidLabels) {
142152
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
143153
nullptr);
144-
auto results = model.generateFromString(kValidTestImagePath, 0.3);
154+
auto results =
155+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward");
145156

146157
for (const auto &detection : results) {
147158
const auto &label = detection.label;
@@ -162,7 +173,7 @@ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) {
162173
JSTensorViewIn tensorView{pixelData.data(),
163174
{height, width, channels},
164175
executorch::aten::ScalarType::Byte};
165-
auto results = model.generateFromPixels(tensorView, 0.3);
176+
auto results = model.generateFromPixels(tensorView, 0.3, 0.55, {}, "forward");
166177
EXPECT_GE(results.size(), 0u);
167178
}
168179

@@ -174,8 +185,9 @@ TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) {
174185
JSTensorViewIn tensorView{pixelData.data(),
175186
{height, width, channels},
176187
executorch::aten::ScalarType::Byte};
177-
EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1),
178-
RnExecutorchError);
188+
EXPECT_THROW(
189+
(void)model.generateFromPixels(tensorView, -0.1, 0.55, {}, "forward"),
190+
RnExecutorchError);
179191
}
180192

181193
TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) {
@@ -186,8 +198,9 @@ TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) {
186198
JSTensorViewIn tensorView{pixelData.data(),
187199
{height, width, channels},
188200
executorch::aten::ScalarType::Byte};
189-
EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1),
190-
RnExecutorchError);
201+
EXPECT_THROW(
202+
(void)model.generateFromPixels(tensorView, 1.1, 0.55, {}, "forward"),
203+
RnExecutorchError);
191204
}
192205

193206
TEST(ObjectDetectionInheritedTests, GetInputShapeWorks) {
@@ -239,5 +252,67 @@ TEST(ObjectDetectionNormTests, ValidNormParamsGenerateSucceeds) {
239252
const std::vector<float> std = {0.229f, 0.224f, 0.225f};
240253
ObjectDetection model(kValidObjectDetectionModelPath, mean, std, kCocoLabels,
241254
nullptr);
242-
EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5));
255+
EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55,
256+
{}, "forward"));
257+
}
258+
259+
// ============================================================================
260+
// Method name tests
261+
// ============================================================================
262+
TEST(ObjectDetectionMethodTests, InvalidMethodNameThrows) {
263+
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
264+
nullptr);
265+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55,
266+
{}, "forward_999"),
267+
RnExecutorchError);
268+
}
269+
270+
TEST(ObjectDetectionMethodTests, EmptyMethodNameThrows) {
271+
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
272+
nullptr);
273+
EXPECT_THROW(
274+
(void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, ""),
275+
RnExecutorchError);
276+
}
277+
278+
// ============================================================================
279+
// Class indices filtering tests
280+
// ============================================================================
281+
TEST(ObjectDetectionClassFilterTests,
282+
FilteredResultsOnlyContainRequestedClasses) {
283+
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
284+
nullptr);
285+
// Only request "person" class (index 0 in COCO)
286+
auto results =
287+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward");
288+
for (const auto &det : results) {
289+
EXPECT_EQ(det.label, "person");
290+
}
291+
}
292+
293+
TEST(ObjectDetectionClassFilterTests,
294+
EmptyClassIndicesReturnsMoreOrEqualResults) {
295+
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
296+
nullptr);
297+
auto allClasses =
298+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward");
299+
// person (0) only
300+
auto filtered =
301+
model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward");
302+
EXPECT_GE(allClasses.size(), filtered.size());
303+
}
304+
305+
// ============================================================================
306+
// IoU threshold tests
307+
// ============================================================================
308+
TEST(ObjectDetectionIouTests, HigherIouThresholdReturnsSameOrMoreResults) {
309+
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels,
310+
nullptr);
311+
// High IoU threshold = less aggressive NMS = more boxes survive
312+
auto highIou =
313+
model.generateFromString(kValidTestImagePath, 0.3, 0.9, {}, "forward");
314+
// Low IoU threshold = more aggressive NMS = fewer boxes survive
315+
auto lowIou =
316+
model.generateFromString(kValidTestImagePath, 0.3, 0.1, {}, "forward");
317+
EXPECT_GE(highIou.size(), lowIou.size());
243318
}

packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ export class ObjectDetectionModule<
168168
this.modelConfig.defaultDetectionThreshold ?? 0.7;
169169
const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.55;
170170
const defaultInputSize = this.modelConfig.defaultInputSize;
171+
const availableInputSizes = this.modelConfig.availableInputSizes;
171172

172173
return (
173174
frame: any,
@@ -180,6 +181,20 @@ export class ObjectDetectionModule<
180181
options?.detectionThreshold ?? defaultDetectionThreshold;
181182
const iouThreshold = options?.iouThreshold ?? defaultIouThreshold;
182183
const inputSize = options?.inputSize ?? defaultInputSize;
184+
185+
if (
186+
availableInputSizes &&
187+
inputSize !== undefined &&
188+
!availableInputSizes.includes(
189+
inputSize as (typeof availableInputSizes)[number]
190+
)
191+
) {
192+
throw new RnExecutorchError(
193+
RnExecutorchErrorCode.InvalidArgument,
194+
`Invalid inputSize: ${inputSize}. Available sizes: ${availableInputSizes.join(', ')}`
195+
);
196+
}
197+
183198
const methodName =
184199
inputSize !== undefined ? `forward_${inputSize}` : 'forward';
185200

0 commit comments

Comments
 (0)