1212
1313import copy
1414from dataclasses import dataclass
15+ from enum import Enum
1516from pathlib import Path
1617from typing import Callable , Sequence , Tuple , Union
1718
@@ -82,7 +83,7 @@ def setup_runtime(self) -> None:
8283
8384 def predict (
8485 self , frames : np .ndarray
85- ) -> Tuple [np .ndarray , np .ndarray ]:
86+ ) -> Tuple [np .ndarray , np .ndarray , np . ndarray ]:
8687 """Estimate 2D keypoints from image frames and return in H36M format.
8788
8889 Parameters
@@ -96,6 +97,9 @@ def predict(
9697 H36M-format 2D keypoints, shape ``(num_persons, N, 17, 2)``.
9798 scores : ndarray
9899 Per-joint confidence scores, shape ``(num_persons, N, 17)``.
100+ valid_frames_mask : ndarray
101+ Boolean mask indicating which frames contain at least one
102+ valid detection, shape ``(N,)``.
99103 """
100104 from fmpose3d .lib .preprocess import h36m_coco_format , revise_kpts
101105
@@ -104,12 +108,70 @@ def predict(
104108 keypoints , scores = self ._model .predict (frames )
105109
106110 keypoints , scores , valid_frames = h36m_coco_format (keypoints , scores )
111+ keypoints , scores = self ._validate_predictions (
112+ keypoints , scores , num_frames = frames .shape [0 ],
113+ )
114+ valid_frames_mask = self ._compute_valid_frames_mask (keypoints , scores )
115+
107116 # NOTE: revise_kpts is computed for consistency but is NOT applied
108117 # to the returned keypoints, matching the demo script behaviour.
109118 _revised = revise_kpts (keypoints , scores , valid_frames ) # noqa: F841
119+ return keypoints , scores , valid_frames_mask
120+
121+ def _validate_predictions (
122+ self ,
123+ keypoints : np .ndarray ,
124+ scores : np .ndarray ,
125+ * ,
126+ num_frames : int ,
127+ ) -> Tuple [np .ndarray , np .ndarray ]:
128+ """Validate and normalise HRNet/H36M predictions."""
129+ num_joints = 17
130+
131+ keypoints = np .asarray (keypoints , dtype = np .float32 )
132+ scores = np .asarray (scores , dtype = np .float32 )
110133
134+ if keypoints .shape [0 ] == 0 :
135+ # h36m_coco_format can drop all persons when all frames are empty.
136+ return (
137+ np .zeros ((1 , num_frames , num_joints , 2 ), dtype = np .float32 ),
138+ np .zeros ((1 , num_frames , num_joints ), dtype = np .float32 ),
139+ )
140+
141+ if keypoints .ndim != 4 or keypoints .shape [- 2 :] != (num_joints , 2 ):
142+ raise ValueError (
143+ f"Invalid HRNet keypoints shape { keypoints .shape } ; "
144+ f"expected (num_persons, num_frames, { num_joints } , 2)."
145+ )
146+ if scores .ndim != 3 or scores .shape [- 1 ] != num_joints :
147+ raise ValueError (
148+ f"Invalid HRNet scores shape { scores .shape } ; "
149+ f"expected (num_persons, num_frames, { num_joints } )."
150+ )
151+ if keypoints .shape [:2 ] != scores .shape [:2 ]:
152+ raise ValueError (
153+ "HRNet keypoints/scores leading dimensions do not match: "
154+ f"{ keypoints .shape [:2 ]} vs { scores .shape [:2 ]} ."
155+ )
156+ if keypoints .shape [1 ] != num_frames :
157+ raise ValueError (
158+ f"HRNet frame count mismatch: got { keypoints .shape [1 ]} , "
159+ f"expected { num_frames } ."
160+ )
111161 return keypoints , scores
112162
163+ @staticmethod
164+ def _compute_valid_frames_mask (
165+ keypoints : np .ndarray , scores : np .ndarray
166+ ) -> np .ndarray :
167+ """Return frame-level validity mask from estimator outputs."""
168+ safe_scores = np .nan_to_num (scores , nan = 0.0 )
169+ has_score = np .any (safe_scores > 0 , axis = - 1 ) # (num_persons, num_frames)
170+
171+ safe_kpts = np .nan_to_num (np .abs (keypoints ), nan = 0.0 )
172+ has_kpt = np .any (safe_kpts > 0 , axis = (- 1 , - 2 )) # (num_persons, num_frames)
173+ return np .any (has_score | has_kpt , axis = 0 )
174+
113175
114176# Quadruped80K → Animal3D (26 keypoints) mapping table.
115177# -1 entries are filled by linear interpolation (see _INTERPOLATION_RULES).
@@ -148,7 +210,7 @@ def setup_runtime(self) -> None:
148210
149211 def predict (
150212 self , frames : np .ndarray
151- ) -> Tuple [np .ndarray , np .ndarray ]:
213+ ) -> Tuple [np .ndarray , np .ndarray , np . ndarray ]:
152214 """Estimate 2D keypoints from image frames in Animal3D format.
153215
154216 The method writes *frames* to a temporary directory, runs
@@ -166,8 +228,11 @@ def predict(
166228 Animal3D-format 2D keypoints, shape ``(1, N, 26, 2)``.
167229 The first axis is always 1 (single individual).
168230 scores : ndarray
169- Placeholder confidence scores (all ones) ,
231+ Mapped per-joint confidence scores,
170232 shape ``(1, N, 26)``.
233+ valid_frames_mask : ndarray
234+ Boolean mask indicating which frames contain at least one
235+ valid detection, shape ``(N,)``.
171236 """
172237 import cv2
173238 import tempfile
@@ -178,6 +243,7 @@ def predict(
178243 cfg = self .cfg
179244 num_frames = frames .shape [0 ]
180245 all_mapped : list [np .ndarray ] = []
246+ all_scores : list [np .ndarray ] = []
181247
182248 with tempfile .TemporaryDirectory () as tmpdir :
183249 # Write each frame as an image so DLC can read it.
@@ -187,8 +253,7 @@ def predict(
187253 cv2 .imwrite (p , frames [idx ])
188254 paths .append (p )
189255
190- # Run DeepLabCut on each frame individually (the API
191- # expects a single image path).
256+ # Run DeepLabCut on each frame individually.
192257 for img_path in paths :
193258 predictions = superanimal_analyze_images (
194259 superanimal_name = cfg .superanimal_name ,
@@ -199,21 +264,33 @@ def predict(
199264 out_folder = tmpdir ,
200265 )
201266 # predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
202- for _path , payload in predictions .items ():
203- bodyparts = payload .get ("bodyparts" )
204- if bodyparts is None :
205- # No detection -- fill with zeros.
206- all_mapped .append (np .zeros ((1 , 26 , 2 ), dtype = "float32" ))
207- continue
208- xy = bodyparts [..., :2 ] # (N_ind, K, 2)
209- mapped = self ._map_keypoints (xy )
210- # Take only the first individual.
211- all_mapped .append (mapped [:1 ])
267+ payload = predictions .get (img_path ) if isinstance (predictions , dict ) else None
268+ if payload is None and isinstance (predictions , dict ) and len (predictions ) == 1 :
269+ payload = next (iter (predictions .values ()))
270+
271+ bodyparts = None if payload is None else payload .get ("bodyparts" )
272+ bodyparts = None if bodyparts is None else np .asarray (bodyparts )
273+ if bodyparts is None or bodyparts .shape [0 ] == 0 :
274+ # No detection -- fill with zeros and zero confidence.
275+ all_mapped .append (np .zeros ((1 , 26 , 2 ), dtype = np .float32 ))
276+ all_scores .append (np .zeros ((1 , 26 ), dtype = np .float32 ))
277+ continue
278+
279+ xy = bodyparts [..., :2 ] # (N_ind, K, 2)
280+ conf = bodyparts [..., 2 ] # (N_ind, K)
281+ mapped = self ._map_keypoints (xy )
282+ mapped_scores = self ._map_scores (conf )
283+
284+ # Take only the first individual.
285+ all_mapped .append (mapped [:1 ])
286+ all_scores .append (mapped_scores [:1 ])
212287
213288 # Stack along frame axis → (1, N, 26, 2)
214289 kpts = np .stack (all_mapped , axis = 1 ) # (1, N, 26, 2)
215- scores = np .ones (kpts .shape [:3 ], dtype = "float32" ) # (1, N, 26)
216- return kpts , scores
290+ scores = np .stack (all_scores , axis = 1 ) # (1, N, 26)
291+ kpts , scores = self ._validate_predictions (kpts , scores , num_frames = num_frames )
292+ valid_frames_mask = self ._compute_valid_frames_mask (kpts , scores )
293+ return kpts , scores , valid_frames_mask
217294
218295 # ------------------------------------------------------------------ #
219296
@@ -247,6 +324,80 @@ def _map_keypoints(xy: np.ndarray) -> np.ndarray:
247324
248325 return mapped
249326
@staticmethod
def _map_scores(conf: np.ndarray) -> np.ndarray:
    """Re-index quadruped80K confidences into the Animal3D joint layout.

    Parameters
    ----------
    conf : ndarray
        Two-dimensional confidence array; the second axis indexes the
        source joints.

    Returns
    -------
    ndarray
        ``float32`` array with one column per entry of
        ``_QUADRUPED80K_TO_ANIMAL3D``.  Columns copied from a source joint
        keep that joint's confidence; columns marked ``-1`` that have an
        interpolation rule receive the mean of the two rule joints; every
        other column stays NaN.
    """
    num_ind, num_src = conf.shape
    out = np.full(
        (num_ind, len(_QUADRUPED80K_TO_ANIMAL3D)), np.nan, dtype=np.float32
    )

    for target, source in enumerate(_QUADRUPED80K_TO_ANIMAL3D):
        if source == -1:
            # No direct source joint: fall back to the interpolation table.
            if target not in _INTERPOLATION_RULES:
                continue
            first, second = _INTERPOLATION_RULES[target]
            if first < num_src and second < num_src:
                out[:, target] = (conf[:, first] + conf[:, second]) / 2.0
            continue

        # Direct copy, skipped when the source index is out of range.
        if source < num_src:
            out[:, target] = conf[:, source]

    return out
343+
344+ def _validate_predictions (
345+ self ,
346+ keypoints : np .ndarray ,
347+ scores : np .ndarray ,
348+ * ,
349+ num_frames : int ,
350+ ) -> Tuple [np .ndarray , np .ndarray ]:
351+ """Validate and normalise SuperAnimal predictions."""
352+ num_joints = 26
353+ keypoints = np .asarray (keypoints , dtype = np .float32 )
354+ scores = np .asarray (scores , dtype = np .float32 )
355+
356+ if keypoints .shape [0 ] == 0 :
357+ return (
358+ np .zeros ((1 , num_frames , num_joints , 2 ), dtype = np .float32 ),
359+ np .zeros ((1 , num_frames , num_joints ), dtype = np .float32 ),
360+ )
361+
362+ if keypoints .ndim != 4 or keypoints .shape [- 2 :] != (num_joints , 2 ):
363+ raise ValueError (
364+ f"Invalid SuperAnimal keypoints shape { keypoints .shape } ; "
365+ f"expected (num_individuals, num_frames, { num_joints } , 2)."
366+ )
367+ if scores .ndim != 3 or scores .shape [- 1 ] != num_joints :
368+ raise ValueError (
369+ f"Invalid SuperAnimal scores shape { scores .shape } ; "
370+ f"expected (num_individuals, num_frames, { num_joints } )."
371+ )
372+ if keypoints .shape [:2 ] != scores .shape [:2 ]:
373+ raise ValueError (
374+ "SuperAnimal keypoints/scores leading dimensions do not match: "
375+ f"{ keypoints .shape [:2 ]} vs { scores .shape [:2 ]} ."
376+ )
377+ if keypoints .shape [1 ] != num_frames :
378+ raise ValueError (
379+ f"SuperAnimal frame count mismatch: got { keypoints .shape [1 ]} , "
380+ f"expected { num_frames } ."
381+ )
382+
383+ # Normalise unknown values to zeros so downstream code can treat these
384+ # joints as invalid via score==0 while retaining shape stability.
385+ keypoints = np .nan_to_num (keypoints , nan = 0.0 )
386+ scores = np .nan_to_num (scores , nan = 0.0 )
387+ return keypoints , scores
388+
389+ @staticmethod
390+ def _compute_valid_frames_mask (
391+ keypoints : np .ndarray , scores : np .ndarray
392+ ) -> np .ndarray :
393+ """Return frame-level validity mask from estimator outputs."""
394+ safe_scores = np .nan_to_num (scores , nan = 0.0 )
395+ has_score = np .any (safe_scores > 0 , axis = - 1 ) # (num_persons, num_frames)
396+
397+ safe_kpts = np .nan_to_num (np .abs (keypoints ), nan = 0.0 )
398+ has_kpt = np .any (safe_kpts > 0 , axis = (- 1 , - 2 )) # (num_persons, num_frames)
399+ return np .any (has_score | has_kpt , axis = 0 )
400+
250401
251402# ---------------------------------------------------------------------------
252403# Limb regularisation (animal post-processing)
@@ -444,6 +595,16 @@ def _default_components(
444595# ---------------------------------------------------------------------------
445596
446597
class Pose2DStatus(str, Enum):
    """Coarse outcome of a 2D pose-estimation run.

    Members are also ``str`` instances, so each compares equal to its
    lowercase value.
    """

    # Outcome categories; the explanatory message for each is produced by
    # the code that assigns the status.
    SUCCESS = "success"
    PARTIAL = "partial"
    EMPTY = "empty"
    INVALID = "invalid"
    UNKNOWN = "unknown"
606+
607+
447608@dataclass
448609class Pose2DResult :
449610 """Container returned by :meth:`FMPose3DInference.prepare_2d`.
@@ -458,6 +619,43 @@ class Pose2DResult:
458619 """Per-joint confidence scores, shape ``(num_persons, num_frames, J)``."""
459620 image_size : tuple [int , int ] = (0 , 0 )
460621 """``(height, width)`` of the source frames."""
622+ valid_frames_mask : np .ndarray | None = None
623+ """Boolean mask of frames with at least one valid detection, shape ``(N,)``."""
624+
@property
def status(self) -> Pose2DStatus:
    """Status component of :meth:`get_status_info`."""
    info = self.get_status_info()
    return info[0]
629+
@property
def status_message(self) -> str:
    """Explanation text that accompanies :attr:`status`."""
    info = self.get_status_info()
    return info[1]
634+
def get_status_info(self) -> tuple[Pose2DStatus, str]:
    """Classify the 2D result from ``valid_frames_mask``.

    Returns
    -------
    status : Pose2DStatus
        ``UNKNOWN`` when no mask is present or the mask is not a 1D numpy
        array; ``INVALID`` when the frame count cannot be derived from
        ``keypoints``/``scores`` or the mask length differs from it;
        ``EMPTY`` when no frame is flagged; ``PARTIAL`` when only some
        frames are flagged; ``SUCCESS`` when every frame is flagged.
    message : str
        Human-readable explanation of the status.
    """
    mask = self.valid_frames_mask
    if mask is None:
        return Pose2DStatus.UNKNOWN, "No frame-validity mask provided by the estimator."
    # NOTE(review): a malformed mask is reported as UNKNOWN rather than
    # INVALID, so callers that reject only EMPTY/INVALID let it through --
    # confirm this is intended.
    if not isinstance(mask, np.ndarray) or mask.ndim != 1:
        return Pose2DStatus.UNKNOWN, "invalid valid_frames_mask: must be a 1D numpy array."

    # Derive expected frame count from canonical shapes.
    if self.keypoints.ndim == 4:
        num_frames = int(self.keypoints.shape[1])
    elif self.scores.ndim == 3:
        num_frames = int(self.scores.shape[1])
    else:
        return Pose2DStatus.INVALID, "Incorrect keypoints/scores dimensions."

    if mask.shape[0] != num_frames:
        return Pose2DStatus.INVALID, "valid_frames_mask mismatches the number of frames."

    # Count flagged frames by truthiness rather than summing, so a mask
    # with a non-boolean dtype (e.g. integer values above 1) cannot
    # inflate the count and mask missing frames.
    valid_count = int(np.count_nonzero(mask))
    if valid_count == 0:
        return Pose2DStatus.EMPTY, "No valid predictions in any frame."
    if valid_count < num_frames:
        return Pose2DStatus.PARTIAL, "Missing predictions in a subset of frames."
    return Pose2DStatus.SUCCESS, "Valid predictions for all frames."
461659
462660
463661@dataclass
@@ -690,6 +888,9 @@ def predict(
690888 Root-relative and world-coordinate 3D poses.
691889 """
692890 result_2d = self .prepare_2d (source )
891+ status , status_msg = result_2d .get_status_info ()
892+ if status in {Pose2DStatus .EMPTY , Pose2DStatus .INVALID }:
893+ raise ValueError (f"2D pose estimation is not usable for 3D lifting: { status .value } . { status_msg } " )
693894 return self .pose_3d (
694895 result_2d .keypoints ,
695896 result_2d .image_size ,
@@ -733,13 +934,16 @@ def prepare_2d(
733934 self .setup_runtime ()
734935 if progress :
735936 progress (0 , 1 )
736- keypoints , scores = self ._estimator_2d .predict (ingested .frames )
937+ keypoints , scores , valid_frames_mask = self ._estimator_2d .predict (
938+ ingested .frames
939+ )
737940 if progress :
738941 progress (1 , 1 )
739942 return Pose2DResult (
740943 keypoints = keypoints ,
741944 scores = scores ,
742945 image_size = ingested .image_size ,
946+ valid_frames_mask = valid_frames_mask ,
743947 )
744948
745949 @torch .no_grad ()
0 commit comments