@@ -595,8 +595,8 @@ def _default_components(
595595# ---------------------------------------------------------------------------
596596
597597
598- class Pose2DStatus (str , Enum ):
599- """High-level status for 2D pose estimation."""
598+ class ResultStatus (str , Enum ):
599+ """High-level status for pose estimation outputs ."""
600600
601601 SUCCESS = "success"
602602 PARTIAL = "partial"
@@ -623,7 +623,7 @@ class Pose2DResult:
623623 """Boolean mask of frames with at least one valid detection, shape ``(N,)``."""
624624
625625 @property
626- def status (self ) -> Pose2DStatus :
626+ def status (self ) -> ResultStatus :
627627 """Prediction status derived from ``valid_frames_mask``."""
628628 return self .get_status_info ()[0 ]
629629
@@ -632,30 +632,29 @@ def status_message(self) -> str:
632632 """Human-readable explanation for :attr:`status`."""
633633 return self .get_status_info ()[1 ]
634634
635- def get_status_info (self ) -> tuple [Pose2DStatus , str ]:
635+ def get_status_info (self ) -> tuple [ResultStatus , str ]:
636636 """Prediction status derived from ``valid_frames_mask``."""
637- if self .valid_frames_mask is None :
638- return Pose2DStatus .UNKNOWN , "No frame-validity mask provided by the estimator."
639- elif not isinstance (self .valid_frames_mask , np .ndarray ) or self .valid_frames_mask .ndim != 1 :
640- return Pose2DStatus .UNKNOWN , "invalid valid_frames_mask: must be a 1D numpy array."
641-
642637 # Derive expected frame count from canonical shapes.
643638 if self .keypoints .ndim == 4 :
644639 num_frames = int (self .keypoints .shape [1 ])
645640 elif self .scores .ndim == 3 :
646641 num_frames = int (self .scores .shape [1 ])
647642 else :
648- return Pose2DStatus .INVALID , "Incorrect keypoints/scores dimensions."
643+ return ResultStatus .INVALID , "Incorrect 2D pose keypoints/scores dimensions."
649644
645+ if self .valid_frames_mask is None :
646+ return ResultStatus .UNKNOWN , "No frame-validity mask provided by the 2D pose."
647+ if not isinstance (self .valid_frames_mask , np .ndarray ) or self .valid_frames_mask .ndim != 1 :
648+ return ResultStatus .UNKNOWN , "invalid 2D pose valid_frames_mask: must be a 1D numpy array."
650649 if self .valid_frames_mask .shape [0 ] != num_frames :
651- return Pose2DStatus .INVALID , "valid_frames_mask mismatches the number of frames."
650+ return ResultStatus .INVALID , "2D pose valid_frames_mask mismatches the number of frames."
652651
653652 valid_count = int (np .sum (self .valid_frames_mask ))
654653 if valid_count == 0 :
655- return Pose2DStatus .EMPTY , "No valid predictions in any frame."
654+ return ResultStatus .EMPTY , "No valid 2D pose predictions in any frame."
656655 if valid_count < num_frames :
657- return Pose2DStatus .PARTIAL , "Missing predictions in a subset of frames."
658- return Pose2DStatus .SUCCESS , "Valid predictions for all frames."
656+ return ResultStatus .PARTIAL , "Missing 2D pose predictions in a subset of frames."
657+ return ResultStatus .SUCCESS , "Valid 2D pose predictions for all frames."
659658
660659
661660@dataclass
@@ -675,6 +674,40 @@ class Pose3DResult:
675674 ``camera_to_world``). For animal poses this contains the
676675 limb-regularised output.
677676 """
677+ valid_frames_mask : np .ndarray | None = None
678+ """Boolean mask of frames with valid 3D poses, shape ``(num_frames,)``."""
679+
680+ @property
681+ def status (self ) -> ResultStatus :
682+ """Prediction status derived from ``valid_frames_mask``."""
683+ return self .get_status_info ()[0 ]
684+
685+ @property
686+ def status_message (self ) -> str :
687+ """Human-readable explanation for :attr:`status`."""
688+ return self .get_status_info ()[1 ]
689+
690+ def get_status_info (self ) -> tuple [ResultStatus , str ]:
691+ """Prediction status derived from ``valid_frames_mask``."""
692+ if self .poses_3d .ndim != 3 or self .poses_3d_world .ndim != 3 :
693+ return ResultStatus .INVALID , "Incorrect 3D result dimensions."
694+ num_frames = int (self .poses_3d .shape [0 ])
695+ if self .poses_3d_world .shape [0 ] != num_frames :
696+ return ResultStatus .INVALID , "poses_3d and poses_3d_world frame counts differ."
697+
698+ if self .valid_frames_mask is None :
699+ return ResultStatus .UNKNOWN , "No frame-validity mask provided by the 3D pose."
700+ if not isinstance (self .valid_frames_mask , np .ndarray ) or self .valid_frames_mask .ndim != 1 :
701+ return ResultStatus .UNKNOWN , "invalid 3D pose valid_frames_mask: must be a 1D numpy array."
702+ if self .valid_frames_mask .shape [0 ] != num_frames :
703+ return ResultStatus .INVALID , "3D pose valid_frames_mask mismatches the number of frames."
704+
705+ valid_count = int (np .sum (self .valid_frames_mask ))
706+ if valid_count == 0 :
707+ return ResultStatus .EMPTY , "No valid 3D pose predictions in any frame."
708+ if valid_count < num_frames :
709+ return ResultStatus .PARTIAL , "Missing 3D pose predictions in a subset of frames."
710+ return ResultStatus .SUCCESS , "Valid 3D pose predictions for all frames."
678711
679712
680713#: Accepted source types for :meth:`FMPose3DInference.predict`.
@@ -889,15 +922,23 @@ def predict(
889922 """
890923 result_2d = self .prepare_2d (source )
891924 status , status_msg = result_2d .get_status_info ()
892- if status in {Pose2DStatus .EMPTY , Pose2DStatus .INVALID }:
925+ if status in {ResultStatus .EMPTY , ResultStatus .INVALID }:
893926 raise ValueError (f"2D pose estimation is not usable for 3D lifting: { status .value } . { status_msg } " )
894- return self .pose_3d (
927+ result_3d = self .pose_3d (
895928 result_2d .keypoints ,
896929 result_2d .image_size ,
897930 camera_rotation = camera_rotation ,
898931 seed = seed ,
899932 progress = progress ,
900933 )
934+ mask = result_2d .valid_frames_mask
935+ if mask is not None :
936+ invalid = ~ mask
937+ if np .any (invalid ):
938+ result_3d .poses_3d [invalid ] = np .nan
939+ result_3d .poses_3d_world [invalid ] = np .nan
940+ result_3d .valid_frames_mask = mask
941+ return result_3d
901942
902943 @torch .no_grad ()
903944 def prepare_2d (
0 commit comments