@@ -24,6 +24,8 @@ def test_classes_replacement_when_object_detection_object_is_none() -> None:
2424 result = step .run (
2525 object_detection_predictions = None ,
2626 classification_predictions = None ,
27+ fallback_class_name = None ,
28+ fallback_class_id = None ,
2729 )
2830
2931 # then
@@ -43,6 +45,8 @@ def test_classes_replacement_when_there_are_no_predictions_is_none() -> None:
4345 result = step .run (
4446 object_detection_predictions = detections ,
4547 classification_predictions = None ,
48+ fallback_class_name = None ,
49+ fallback_class_id = None ,
4650 )
4751
4852 # then
@@ -100,6 +104,8 @@ def test_classes_replacement_when_replacement_to_happen_without_filtering_for_mu
100104 result = step .run (
101105 object_detection_predictions = detections ,
102106 classification_predictions = classification_predictions ,
107+ fallback_class_name = None ,
108+ fallback_class_id = None ,
103109 )
104110
105111 # then
@@ -183,6 +189,8 @@ def test_classes_replacement_when_replacement_to_happen_without_filtering_for_mu
183189 result = step .run (
184190 object_detection_predictions = detections ,
185191 classification_predictions = classification_predictions ,
192+ fallback_class_name = None ,
193+ fallback_class_id = None ,
186194 )
187195
188196 # then
@@ -245,6 +253,8 @@ def test_classes_replacement_when_replacement_to_happen_and_one_result_to_be_fil
245253 result = step .run (
246254 object_detection_predictions = detections ,
247255 classification_predictions = classification_predictions ,
256+ fallback_class_name = None ,
257+ fallback_class_id = None ,
248258 )
249259
250260 # then
@@ -271,7 +281,7 @@ def test_classes_replacement_when_replacement_to_happen_and_one_result_to_be_fil
271281 ], "Expected to generate new detection id"
272282
273283
274- def test_classes_replacement_when_empty_classification_predictions ():
284+ def test_classes_replacement_when_empty_classification_predictions_no_fallback_class ():
275285 # given
276286 step = DetectionsClassesReplacementBlockV1 ()
277287 detections = sv .Detections (
@@ -305,6 +315,8 @@ def test_classes_replacement_when_empty_classification_predictions():
305315 result = step .run (
306316 object_detection_predictions = detections ,
307317 classification_predictions = classification_predictions ,
318+ fallback_class_name = None ,
319+ fallback_class_id = None ,
308320 )
309321
310322 # then
@@ -313,6 +325,72 @@ def test_classes_replacement_when_empty_classification_predictions():
313325 ), "Expected sv.Detections.empty(), as empty classification was passed"
314326
315327
328+ def test_classes_replacement_when_empty_classification_predictions_fallback_class_provided ():
329+ # given
330+ step = DetectionsClassesReplacementBlockV1 ()
331+ detections = sv .Detections (
332+ xyxy = np .array (
333+ [
334+ [10 , 20 , 30 , 40 ],
335+ [11 , 21 , 31 , 41 ],
336+ ]
337+ ),
338+ class_id = np .array ([7 , 7 ]),
339+ confidence = np .array ([0.36 , 0.91 ]),
340+ data = {
341+ "class_name" : np .array (["animal" , "animal" ]),
342+ "detection_id" : np .array (["zero" , "one" ]),
343+ },
344+ )
345+ first_cls_prediction = ClassificationInferenceResponse (
346+ image = InferenceResponseImage (width = 128 , height = 256 ),
347+ predictions = [
348+ ClassificationPrediction (
349+ ** {"class" : "cat" , "class_id" : 0 , "confidence" : 0.6 }
350+ ),
351+ ClassificationPrediction (
352+ ** {"class" : "dog" , "class_id" : 1 , "confidence" : 0.4 }
353+ ),
354+ ],
355+ top = "cat" ,
356+ confidence = 0.6 ,
357+ parent_id = "some" ,
358+ ).dict (by_alias = True , exclude_none = True )
359+ first_cls_prediction ["parent_id" ] = "zero"
360+ second_cls_prediction = ClassificationInferenceResponse (
361+ image = InferenceResponseImage (width = 128 , height = 256 ),
362+ predictions = [],
363+ top = "cat" ,
364+ confidence = 0.6 ,
365+ parent_id = "some" ,
366+ ).dict (by_alias = True , exclude_none = True )
367+ second_cls_prediction ["parent_id" ] = "one"
368+ classification_predictions = Batch (
369+ content = [
370+ first_cls_prediction ,
371+ second_cls_prediction ,
372+ ],
373+ indices = [(0 , 0 ), (0 , 1 )],
374+ )
375+
376+ # when
377+ result = step .run (
378+ object_detection_predictions = detections ,
379+ classification_predictions = classification_predictions ,
380+ fallback_class_name = "unknown" ,
381+ fallback_class_id = 123 ,
382+ )
383+
384+ # then
385+ assert (
386+ len (result ["predictions" ]) == 2
387+ ), "Expected sv.Detections.empty(), as empty classification was passed"
388+ detections = result ["predictions" ]
389+ assert detections .confidence [1 ] == 0 , "Fallback class confidence expected to be set to 0"
390+ assert detections .class_id [1 ] == 123 , "class id expected to be set to value passed with fallback_class_id parameter"
391+ assert detections .data ["class_name" ][1 ] == "unknown" , "class name expected to be set to value passed with fallback_class_name parameter"
392+
393+
316394def test_extract_leading_class_from_prediction_when_prediction_is_multi_label () -> None :
317395 # given
318396 prediction = ClassificationInferenceResponse (
0 commit comments