Skip to content

Commit 6edcef8

Browse files
committed
Inference api: add/propagate validation for 3d predictions
- Results for missing 2D predictions are set to NaN - Pose3DResult now contains status and valid_frames_mask
1 parent 4193b92 commit 6edcef8

1 file changed

Lines changed: 57 additions & 16 deletions

File tree

fmpose3d/inference_api/fmpose3d.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ def _default_components(
595595
# ---------------------------------------------------------------------------
596596

597597

598-
class Pose2DStatus(str, Enum):
599-
"""High-level status for 2D pose estimation."""
598+
class ResultStatus(str, Enum):
599+
"""High-level status for pose estimation outputs."""
600600

601601
SUCCESS = "success"
602602
PARTIAL = "partial"
@@ -623,7 +623,7 @@ class Pose2DResult:
623623
"""Boolean mask of frames with at least one valid detection, shape ``(N,)``."""
624624

625625
@property
626-
def status(self) -> Pose2DStatus:
626+
def status(self) -> ResultStatus:
627627
"""Prediction status derived from ``valid_frames_mask``."""
628628
return self.get_status_info()[0]
629629

@@ -632,30 +632,29 @@ def status_message(self) -> str:
632632
"""Human-readable explanation for :attr:`status`."""
633633
return self.get_status_info()[1]
634634

635-
def get_status_info(self) -> tuple[Pose2DStatus, str]:
635+
def get_status_info(self) -> tuple[ResultStatus, str]:
636636
"""Prediction status derived from ``valid_frames_mask``."""
637-
if self.valid_frames_mask is None:
638-
return Pose2DStatus.UNKNOWN, "No frame-validity mask provided by the estimator."
639-
elif not isinstance(self.valid_frames_mask, np.ndarray) or self.valid_frames_mask.ndim != 1:
640-
return Pose2DStatus.UNKNOWN, "invalid valid_frames_mask: must be a 1D numpy array."
641-
642637
# Derive expected frame count from canonical shapes.
643638
if self.keypoints.ndim == 4:
644639
num_frames = int(self.keypoints.shape[1])
645640
elif self.scores.ndim == 3:
646641
num_frames = int(self.scores.shape[1])
647642
else:
648-
return Pose2DStatus.INVALID, "Incorrect keypoints/scores dimensions."
643+
return ResultStatus.INVALID, "Incorrect 2D pose keypoints/scores dimensions."
649644

645+
if self.valid_frames_mask is None:
646+
return ResultStatus.UNKNOWN, "No frame-validity mask provided by the 2D pose."
647+
if not isinstance(self.valid_frames_mask, np.ndarray) or self.valid_frames_mask.ndim != 1:
648+
return ResultStatus.UNKNOWN, "invalid 2D pose valid_frames_mask: must be a 1D numpy array."
650649
if self.valid_frames_mask.shape[0] != num_frames:
651-
return Pose2DStatus.INVALID, "valid_frames_mask mismatches the number of frames."
650+
return ResultStatus.INVALID, "2D pose valid_frames_mask mismatches the number of frames."
652651

653652
valid_count = int(np.sum(self.valid_frames_mask))
654653
if valid_count == 0:
655-
return Pose2DStatus.EMPTY, "No valid predictions in any frame."
654+
return ResultStatus.EMPTY, "No valid 2D pose predictions in any frame."
656655
if valid_count < num_frames:
657-
return Pose2DStatus.PARTIAL, "Missing predictions in a subset of frames."
658-
return Pose2DStatus.SUCCESS, "Valid predictions for all frames."
656+
return ResultStatus.PARTIAL, "Missing 2D pose predictions in a subset of frames."
657+
return ResultStatus.SUCCESS, "Valid 2D pose predictions for all frames."
659658

660659

661660
@dataclass
@@ -675,6 +674,40 @@ class Pose3DResult:
675674
``camera_to_world``). For animal poses this contains the
676675
limb-regularised output.
677676
"""
677+
valid_frames_mask: np.ndarray | None = None
678+
"""Boolean mask of frames with valid 3D poses, shape ``(num_frames,)``."""
679+
680+
@property
def status(self) -> ResultStatus:
    """Return the status half of :meth:`get_status_info`.

    The value is recomputed on every access by calling
    ``get_status_info()`` and keeping only its first element.
    """
    result_status, _message = self.get_status_info()
    return result_status
684+
685+
@property
def status_message(self) -> str:
    """Return the human-readable message that accompanies :attr:`status`.

    The value is recomputed on every access by calling
    ``get_status_info()`` and keeping only its second element.
    """
    _status, message = self.get_status_info()
    return message
689+
690+
def get_status_info(self) -> tuple[ResultStatus, str]:
    """Classify this 3D result and explain the classification.

    Checks run in a fixed order and the first failing one decides the
    outcome:

    1. ``poses_3d`` and ``poses_3d_world`` must both be 3-dimensional and
       agree on their first-axis length, otherwise ``INVALID``.
    2. ``valid_frames_mask`` must be present and be a 1D ``np.ndarray``,
       otherwise ``UNKNOWN``; its length must equal the frame count,
       otherwise ``INVALID``.
    3. The sum of the mask selects ``EMPTY`` (zero), ``PARTIAL`` (fewer
       than the frame count) or ``SUCCESS``.

    Returns:
        A ``(status, message)`` pair, where ``message`` is a short
        human-readable explanation of ``status``.
    """
    camera_poses = self.poses_3d
    world_poses = self.poses_3d_world

    # Shape sanity comes first: the frame count is read from axis 0.
    if not (camera_poses.ndim == 3 and world_poses.ndim == 3):
        return ResultStatus.INVALID, "Incorrect 3D result dimensions."
    frame_total = int(camera_poses.shape[0])
    if world_poses.shape[0] != frame_total:
        return ResultStatus.INVALID, "poses_3d and poses_3d_world frame counts differ."

    mask = self.valid_frames_mask
    if mask is None:
        return ResultStatus.UNKNOWN, "No frame-validity mask provided by the 3D pose."
    is_flat_array = isinstance(mask, np.ndarray) and mask.ndim == 1
    if not is_flat_array:
        return ResultStatus.UNKNOWN, "invalid 3D pose valid_frames_mask: must be a 1D numpy array."
    if mask.shape[0] != frame_total:
        return ResultStatus.INVALID, "3D pose valid_frames_mask mismatches the number of frames."

    # Summing the mask counts the frames flagged as valid.
    frames_ok = int(np.sum(mask))
    if frames_ok == 0:
        return ResultStatus.EMPTY, "No valid 3D pose predictions in any frame."
    elif frames_ok < frame_total:
        return ResultStatus.PARTIAL, "Missing 3D pose predictions in a subset of frames."
    return ResultStatus.SUCCESS, "Valid 3D pose predictions for all frames."
678711

679712

680713
#: Accepted source types for :meth:`FMPose3DInference.predict`.
@@ -889,15 +922,23 @@ def predict(
889922
"""
890923
result_2d = self.prepare_2d(source)
891924
status, status_msg = result_2d.get_status_info()
892-
if status in {Pose2DStatus.EMPTY, Pose2DStatus.INVALID}:
925+
if status in {ResultStatus.EMPTY, ResultStatus.INVALID}:
893926
raise ValueError(f"2D pose estimation is not usable for 3D lifting: {status.value}. {status_msg}")
894-
return self.pose_3d(
927+
result_3d = self.pose_3d(
895928
result_2d.keypoints,
896929
result_2d.image_size,
897930
camera_rotation=camera_rotation,
898931
seed=seed,
899932
progress=progress,
900933
)
934+
mask = result_2d.valid_frames_mask
935+
if mask is not None:
936+
invalid = ~mask
937+
if np.any(invalid):
938+
result_3d.poses_3d[invalid] = np.nan
939+
result_3d.poses_3d_world[invalid] = np.nan
940+
result_3d.valid_frames_mask = mask
941+
return result_3d
901942

902943
@torch.no_grad()
903944
def prepare_2d(

0 commit comments

Comments
 (0)