@@ -50,7 +50,8 @@ template <> struct ModelTraits<ObjectDetection> {
5050 }
5151
5252 static void callGenerate (ModelType &model) {
53- (void )model.generateFromString (kValidTestImagePath , 0.5 );
53+ (void )model.generateFromString (kValidTestImagePath , 0.5 , 0.55 , {},
54+ " forward" );
5455 }
5556};
5657} // namespace model_tests
@@ -67,57 +68,65 @@ INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest,
6768TEST (ObjectDetectionGenerateTests, InvalidImagePathThrows) {
6869 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
6970 nullptr );
70- EXPECT_THROW ((void )model.generateFromString (" nonexistent_image.jpg" , 0.5 ),
71+ EXPECT_THROW ((void )model.generateFromString (" nonexistent_image.jpg" , 0.5 ,
72+ 0.55 , {}, " forward" ),
7173 RnExecutorchError);
7274}
7375
7476TEST (ObjectDetectionGenerateTests, EmptyImagePathThrows) {
7577 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
7678 nullptr );
77- EXPECT_THROW ((void )model.generateFromString (" " , 0.5 ), RnExecutorchError);
79+ EXPECT_THROW ((void )model.generateFromString (" " , 0.5 , 0.55 , {}, " forward" ),
80+ RnExecutorchError);
7881}
7982
8083TEST (ObjectDetectionGenerateTests, MalformedURIThrows) {
8184 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
8285 nullptr );
83- EXPECT_THROW ((void )model.generateFromString (" not_a_valid_uri://bad" , 0.5 ),
86+ EXPECT_THROW ((void )model.generateFromString (" not_a_valid_uri://bad" , 0.5 ,
87+ 0.55 , {}, " forward" ),
8488 RnExecutorchError);
8589}
8690
8791TEST (ObjectDetectionGenerateTests, NegativeThresholdThrows) {
8892 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
8993 nullptr );
90- EXPECT_THROW ((void )model.generateFromString (kValidTestImagePath , -0.1 ),
94+ EXPECT_THROW ((void )model.generateFromString (kValidTestImagePath , -0.1 , 0.55 ,
95+ {}, " forward" ),
9196 RnExecutorchError);
9297}
9398
9499TEST (ObjectDetectionGenerateTests, ThresholdAboveOneThrows) {
95100 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
96101 nullptr );
97- EXPECT_THROW ((void )model.generateFromString (kValidTestImagePath , 1.1 ),
102+ EXPECT_THROW ((void )model.generateFromString (kValidTestImagePath , 1.1 , 0.55 ,
103+ {}, " forward" ),
98104 RnExecutorchError);
99105}
100106
101107TEST (ObjectDetectionGenerateTests, ValidImageReturnsResults) {
102108 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
103109 nullptr );
104- auto results = model.generateFromString (kValidTestImagePath , 0.3 );
110+ auto results =
111+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {}, " forward" );
105112 EXPECT_GE (results.size (), 0u );
106113}
107114
108115TEST (ObjectDetectionGenerateTests, HighThresholdReturnsFewerResults) {
109116 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
110117 nullptr );
111- auto lowThresholdResults = model.generateFromString (kValidTestImagePath , 0.1 );
118+ auto lowThresholdResults =
119+ model.generateFromString (kValidTestImagePath , 0.1 , 0.55 , {}, " forward" );
112120 auto highThresholdResults =
113- model.generateFromString (kValidTestImagePath , 0.9 );
121+ model.generateFromString (kValidTestImagePath , 0.9 , 0.55 , {}, " forward " );
114122 EXPECT_GE (lowThresholdResults.size (), highThresholdResults.size ());
115123}
116124
117125TEST (ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) {
118126 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
119127 nullptr );
120- auto results = model.generateFromString (kValidTestImagePath , 0.3 );
128+ auto results =
129+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {}, " forward" );
121130
122131 for (const auto &detection : results) {
123132 EXPECT_LE (detection.bbox .x1 , detection.bbox .x2 );
@@ -130,7 +139,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) {
130139TEST (ObjectDetectionGenerateTests, DetectionsHaveValidScores) {
131140 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
132141 nullptr );
133- auto results = model.generateFromString (kValidTestImagePath , 0.3 );
142+ auto results =
143+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {}, " forward" );
134144
135145 for (const auto &detection : results) {
136146 EXPECT_GE (detection.score , 0 .0f );
@@ -141,7 +151,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) {
141151TEST (ObjectDetectionGenerateTests, DetectionsHaveValidLabels) {
142152 ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
143153 nullptr );
144- auto results = model.generateFromString (kValidTestImagePath , 0.3 );
154+ auto results =
155+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {}, " forward" );
145156
146157 for (const auto &detection : results) {
147158 const auto &label = detection.label ;
@@ -162,7 +173,7 @@ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) {
162173 JSTensorViewIn tensorView{pixelData.data (),
163174 {height, width, channels},
164175 executorch::aten::ScalarType::Byte};
165- auto results = model.generateFromPixels (tensorView, 0.3 );
176+ auto results = model.generateFromPixels (tensorView, 0.3 , 0.55 , {}, " forward " );
166177 EXPECT_GE (results.size (), 0u );
167178}
168179
@@ -174,8 +185,9 @@ TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) {
174185 JSTensorViewIn tensorView{pixelData.data (),
175186 {height, width, channels},
176187 executorch::aten::ScalarType::Byte};
177- EXPECT_THROW ((void )model.generateFromPixels (tensorView, -0.1 ),
178- RnExecutorchError);
188+ EXPECT_THROW (
189+ (void )model.generateFromPixels (tensorView, -0.1 , 0.55 , {}, " forward" ),
190+ RnExecutorchError);
179191}
180192
181193TEST (ObjectDetectionPixelTests, ThresholdAboveOneThrows) {
@@ -186,8 +198,9 @@ TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) {
186198 JSTensorViewIn tensorView{pixelData.data (),
187199 {height, width, channels},
188200 executorch::aten::ScalarType::Byte};
189- EXPECT_THROW ((void )model.generateFromPixels (tensorView, 1.1 ),
190- RnExecutorchError);
201+ EXPECT_THROW (
202+ (void )model.generateFromPixels (tensorView, 1.1 , 0.55 , {}, " forward" ),
203+ RnExecutorchError);
191204}
192205
193206TEST (ObjectDetectionInheritedTests, GetInputShapeWorks) {
@@ -239,5 +252,67 @@ TEST(ObjectDetectionNormTests, ValidNormParamsGenerateSucceeds) {
239252 const std::vector<float > std = {0 .229f , 0 .224f , 0 .225f };
240253 ObjectDetection model (kValidObjectDetectionModelPath , mean, std, kCocoLabels ,
241254 nullptr );
242- EXPECT_NO_THROW ((void )model.generateFromString (kValidTestImagePath , 0.5 ));
255+ EXPECT_NO_THROW ((void )model.generateFromString (kValidTestImagePath , 0.5 , 0.55 ,
256+ {}, " forward" ));
257+ }
258+
259+ // ============================================================================
260+ // Method name tests
261+ // ============================================================================
262+ TEST (ObjectDetectionMethodTests, InvalidMethodNameThrows) {
263+ ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
264+ nullptr );
265+ EXPECT_THROW ((void )model.generateFromString (kValidTestImagePath , 0.5 , 0.55 ,
266+ {}, " forward_999" ),
267+ RnExecutorchError);
268+ }
269+
270+ TEST (ObjectDetectionMethodTests, EmptyMethodNameThrows) {
271+ ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
272+ nullptr );
273+ EXPECT_THROW (
274+ (void )model.generateFromString (kValidTestImagePath , 0.5 , 0.55 , {}, " " ),
275+ RnExecutorchError);
276+ }
277+
278+ // ============================================================================
279+ // Class indices filtering tests
280+ // ============================================================================
281+ TEST (ObjectDetectionClassFilterTests,
282+ FilteredResultsOnlyContainRequestedClasses) {
283+ ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
284+ nullptr );
285+ // Only request "person" class (index 0 in COCO)
286+ auto results =
287+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {0 }, " forward" );
288+ for (const auto &det : results) {
289+ EXPECT_EQ (det.label , " person" );
290+ }
291+ }
292+
293+ TEST (ObjectDetectionClassFilterTests,
294+ EmptyClassIndicesReturnsMoreOrEqualResults) {
295+ ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
296+ nullptr );
297+ auto allClasses =
298+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {}, " forward" );
299+ // person (0) only
300+ auto filtered =
301+ model.generateFromString (kValidTestImagePath , 0.3 , 0.55 , {0 }, " forward" );
302+ EXPECT_GE (allClasses.size (), filtered.size ());
303+ }
304+
305+ // ============================================================================
306+ // IoU threshold tests
307+ // ============================================================================
308+ TEST (ObjectDetectionIouTests, HigherIouThresholdReturnsSameOrMoreResults) {
309+ ObjectDetection model (kValidObjectDetectionModelPath , {}, {}, kCocoLabels ,
310+ nullptr );
311+ // High IoU threshold = less aggressive NMS = more boxes survive
312+ auto highIou =
313+ model.generateFromString (kValidTestImagePath , 0.3 , 0.9 , {}, " forward" );
314+ // Low IoU threshold = more aggressive NMS = fewer boxes survive
315+ auto lowIou =
316+ model.generateFromString (kValidTestImagePath , 0.3 , 0.1 , {}, " forward" );
317+ EXPECT_GE (highIou.size (), lowIou.size ());
243318}
0 commit comments