2525 HRNetEstimator ,
2626 HumanPostProcessor ,
2727 Pose2DResult ,
28+ ResultStatus ,
2829 Pose3DResult ,
2930 SuperAnimalEstimator ,
3031 _default_components ,
@@ -615,8 +616,9 @@ def test_predict_end_to_end_with_mock_estimator(self):
615616
616617 mock_kpts = np .random .randn (1 , 1 , 17 , 2 ).astype ("float32" )
617618 mock_scores = np .ones ((1 , 1 , 17 ), dtype = "float32" )
619+ mock_mask = np .array ([True ], dtype = bool )
618620 api ._estimator_2d = MagicMock ()
619- api ._estimator_2d .predict .return_value = (mock_kpts , mock_scores )
621+ api ._estimator_2d .predict .return_value = (mock_kpts , mock_scores , mock_mask )
620622 api ._estimator_2d .setup_runtime = MagicMock ()
621623
622624 frame = np .random .randint (0 , 255 , (480 , 640 , 3 ), dtype = np .uint8 )
@@ -626,6 +628,24 @@ def test_predict_end_to_end_with_mock_estimator(self):
626628 assert result .poses_3d .shape == (1 , 17 , 3 )
627629 api ._estimator_2d .predict .assert_called_once ()
628630
631+ def test_predict_applies_partial_2d_mask_to_3d (self ):
632+ """predict() masks invalid 2D frames to NaN in 3D outputs."""
633+ api = _make_ready_api ("fmpose3d_humans" , test_augmentation = False )
634+ mock_kpts = np .random .randn (1 , 3 , 17 , 2 ).astype ("float32" )
635+ mock_scores = np .ones ((1 , 3 , 17 ), dtype = "float32" )
636+ mask = np .array ([True , False , True ], dtype = bool )
637+ api ._estimator_2d = MagicMock ()
638+ api ._estimator_2d .predict .return_value = (mock_kpts , mock_scores , mask )
639+ api ._estimator_2d .setup_runtime = MagicMock ()
640+
641+ frame = np .random .randint (0 , 255 , (480 , 640 , 3 ), dtype = np .uint8 )
642+ result = api .predict ([frame , frame , frame ], seed = 42 )
643+
644+ np .testing .assert_array_equal (result .valid_frames_mask , mask )
645+ assert result .status == ResultStatus .PARTIAL
646+ assert np .all (np .isnan (result .poses_3d [1 ]))
647+ assert np .all (np .isnan (result .poses_3d_world [1 ]))
648+
629649
630650# =========================================================================
631651# Unit tests — dataclasses
@@ -648,12 +668,41 @@ def test_pose2d_result_default_image_size(self):
648668 )
649669 assert result .image_size == (0 , 0 )
650670
671+ def test_pose2d_status_success (self ):
672+ result = Pose2DResult (
673+ keypoints = np .zeros ((1 , 2 , 17 , 2 )),
674+ scores = np .zeros ((1 , 2 , 17 )),
675+ valid_frames_mask = np .array ([True , True ], dtype = bool ),
676+ )
677+ assert result .status == ResultStatus .SUCCESS
678+ assert "all frames" in result .status_message
679+
680+ def test_pose2d_status_partial (self ):
681+ result = Pose2DResult (
682+ keypoints = np .zeros ((1 , 2 , 17 , 2 )),
683+ scores = np .zeros ((1 , 2 , 17 )),
684+ valid_frames_mask = np .array ([True , False ], dtype = bool ),
685+ )
686+ assert result .status == ResultStatus .PARTIAL
687+ assert "subset" in result .status_message
688+
689+ def test_pose2d_status_invalid_mask_length (self ):
690+ result = Pose2DResult (
691+ keypoints = np .zeros ((1 , 2 , 17 , 2 )),
692+ scores = np .zeros ((1 , 2 , 17 )),
693+ valid_frames_mask = np .array ([True ], dtype = bool ),
694+ )
695+ assert result .status == ResultStatus .INVALID
696+ assert "mismatches" in result .status_message
697+
651698 def test_pose3d_result (self ):
652699 p3d = np .random .randn (10 , 17 , 3 )
653700 pw = np .random .randn (10 , 17 , 3 )
654- result = Pose3DResult (poses_3d = p3d , poses_3d_world = pw )
701+ mask = np .ones ((10 ,), dtype = bool )
702+ result = Pose3DResult (poses_3d = p3d , poses_3d_world = pw , valid_frames_mask = mask )
655703 assert result .poses_3d is p3d
656704 assert result .poses_3d_world is pw
705+ assert result .status == ResultStatus .SUCCESS
657706
658707
659708# =========================================================================
@@ -672,12 +721,13 @@ def test_predict_returns_zeros_when_no_bodyparts(self):
672721 "deeplabcut.pose_estimation_pytorch.apis.superanimal_analyze_images" ,
673722 ) as mock_fn :
674723 mock_fn .return_value = {"frame.png" : {"bodyparts" : None }}
675- kpts , scores = estimator .predict (frames )
724+ kpts , scores , mask = estimator .predict (frames )
676725
677726 assert kpts .shape == (1 , 2 , 26 , 2 )
678727 np .testing .assert_allclose (kpts , 0.0 )
679728 assert scores .shape == (1 , 2 , 26 )
680- np .testing .assert_allclose (scores , 1.0 )
729+ np .testing .assert_allclose (scores , 0.0 )
730+ np .testing .assert_array_equal (mask , np .array ([False , False ]))
681731
682732 def test_predict_maps_valid_bodyparts (self ):
683733 """Valid DLC bodyparts are mapped to Animal3D layout."""
@@ -692,9 +742,10 @@ def test_predict_maps_valid_bodyparts(self):
692742 "deeplabcut.pose_estimation_pytorch.apis.superanimal_analyze_images" ,
693743 ) as mock_fn :
694744 mock_fn .return_value = {"frame.png" : {"bodyparts" : fake_bp }}
695- kpts , scores = estimator .predict (frames )
745+ kpts , scores , mask = estimator .predict (frames )
696746
697747 assert kpts .shape == (1 , 1 , 26 , 2 )
698748 assert scores .shape == (1 , 1 , 26 )
749+ np .testing .assert_array_equal (mask , np .array ([True ]))
699750 # target[24] ← source[0] → (0*3, 0*3+1) = (0.0, 1.0)
700751 np .testing .assert_allclose (kpts [0 , 0 , 24 ], fake_bp [0 , 0 , :2 ])
0 commit comments