Skip to content

Commit 4193b92

Browse files
committed
Update inference api: Add validation for 2d pose predictions.
- Add a _validate_predictions method for the Estimator2Ds. - Add return parameter valid_frames_mask for the Estimator2Ds. - Add status/valid_frames_mask fields to Pose2DResult. - Inference API: predict raises ValueError for invalid/empty pose2d results.
1 parent a51661e commit 4193b92

1 file changed

Lines changed: 222 additions & 18 deletions

File tree

fmpose3d/inference_api/fmpose3d.py

Lines changed: 222 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313
import copy
1414
from dataclasses import dataclass
15+
from enum import Enum
1516
from pathlib import Path
1617
from typing import Callable, Sequence, Tuple, Union
1718

@@ -82,7 +83,7 @@ def setup_runtime(self) -> None:
8283

8384
def predict(
8485
self, frames: np.ndarray
85-
) -> Tuple[np.ndarray, np.ndarray]:
86+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
8687
"""Estimate 2D keypoints from image frames and return in H36M format.
8788
8889
Parameters
@@ -96,6 +97,9 @@ def predict(
9697
H36M-format 2D keypoints, shape ``(num_persons, N, 17, 2)``.
9798
scores : ndarray
9899
Per-joint confidence scores, shape ``(num_persons, N, 17)``.
100+
valid_frames_mask : ndarray
101+
Boolean mask indicating which frames contain at least one
102+
valid detection, shape ``(N,)``.
99103
"""
100104
from fmpose3d.lib.preprocess import h36m_coco_format, revise_kpts
101105

@@ -104,12 +108,70 @@ def predict(
104108
keypoints, scores = self._model.predict(frames)
105109

106110
keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
111+
keypoints, scores = self._validate_predictions(
112+
keypoints, scores, num_frames=frames.shape[0],
113+
)
114+
valid_frames_mask = self._compute_valid_frames_mask(keypoints, scores)
115+
107116
# NOTE: revise_kpts is computed for consistency but is NOT applied
108117
# to the returned keypoints, matching the demo script behaviour.
109118
_revised = revise_kpts(keypoints, scores, valid_frames) # noqa: F841
119+
return keypoints, scores, valid_frames_mask
120+
121+
def _validate_predictions(
122+
self,
123+
keypoints: np.ndarray,
124+
scores: np.ndarray,
125+
*,
126+
num_frames: int,
127+
) -> Tuple[np.ndarray, np.ndarray]:
128+
"""Validate and normalise HRNet/H36M predictions."""
129+
num_joints = 17
130+
131+
keypoints = np.asarray(keypoints, dtype=np.float32)
132+
scores = np.asarray(scores, dtype=np.float32)
110133

134+
if keypoints.shape[0] == 0:
135+
# h36m_coco_format can drop all persons when all frames are empty.
136+
return (
137+
np.zeros((1, num_frames, num_joints, 2), dtype=np.float32),
138+
np.zeros((1, num_frames, num_joints), dtype=np.float32),
139+
)
140+
141+
if keypoints.ndim != 4 or keypoints.shape[-2:] != (num_joints, 2):
142+
raise ValueError(
143+
f"Invalid HRNet keypoints shape {keypoints.shape}; "
144+
f"expected (num_persons, num_frames, {num_joints}, 2)."
145+
)
146+
if scores.ndim != 3 or scores.shape[-1] != num_joints:
147+
raise ValueError(
148+
f"Invalid HRNet scores shape {scores.shape}; "
149+
f"expected (num_persons, num_frames, {num_joints})."
150+
)
151+
if keypoints.shape[:2] != scores.shape[:2]:
152+
raise ValueError(
153+
"HRNet keypoints/scores leading dimensions do not match: "
154+
f"{keypoints.shape[:2]} vs {scores.shape[:2]}."
155+
)
156+
if keypoints.shape[1] != num_frames:
157+
raise ValueError(
158+
f"HRNet frame count mismatch: got {keypoints.shape[1]}, "
159+
f"expected {num_frames}."
160+
)
111161
return keypoints, scores
112162

163+
@staticmethod
164+
def _compute_valid_frames_mask(
165+
keypoints: np.ndarray, scores: np.ndarray
166+
) -> np.ndarray:
167+
"""Return frame-level validity mask from estimator outputs."""
168+
safe_scores = np.nan_to_num(scores, nan=0.0)
169+
has_score = np.any(safe_scores > 0, axis=-1) # (num_persons, num_frames)
170+
171+
safe_kpts = np.nan_to_num(np.abs(keypoints), nan=0.0)
172+
has_kpt = np.any(safe_kpts > 0, axis=(-1, -2)) # (num_persons, num_frames)
173+
return np.any(has_score | has_kpt, axis=0)
174+
113175

114176
# Quadruped80K → Animal3D (26 keypoints) mapping table.
115177
# -1 entries are filled by linear interpolation (see _INTERPOLATION_RULES).
@@ -148,7 +210,7 @@ def setup_runtime(self) -> None:
148210

149211
def predict(
150212
self, frames: np.ndarray
151-
) -> Tuple[np.ndarray, np.ndarray]:
213+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
152214
"""Estimate 2D keypoints from image frames in Animal3D format.
153215
154216
The method writes *frames* to a temporary directory, runs
@@ -166,8 +228,11 @@ def predict(
166228
Animal3D-format 2D keypoints, shape ``(1, N, 26, 2)``.
167229
The first axis is always 1 (single individual).
168230
scores : ndarray
169-
Placeholder confidence scores (all ones),
231+
Mapped per-joint confidence scores,
170232
shape ``(1, N, 26)``.
233+
valid_frames_mask : ndarray
234+
Boolean mask indicating which frames contain at least one
235+
valid detection, shape ``(N,)``.
171236
"""
172237
import cv2
173238
import tempfile
@@ -178,6 +243,7 @@ def predict(
178243
cfg = self.cfg
179244
num_frames = frames.shape[0]
180245
all_mapped: list[np.ndarray] = []
246+
all_scores: list[np.ndarray] = []
181247

182248
with tempfile.TemporaryDirectory() as tmpdir:
183249
# Write each frame as an image so DLC can read it.
@@ -187,8 +253,7 @@ def predict(
187253
cv2.imwrite(p, frames[idx])
188254
paths.append(p)
189255

190-
# Run DeepLabCut on each frame individually (the API
191-
# expects a single image path).
256+
# Run DeepLabCut on each frame individually.
192257
for img_path in paths:
193258
predictions = superanimal_analyze_images(
194259
superanimal_name=cfg.superanimal_name,
@@ -199,21 +264,33 @@ def predict(
199264
out_folder=tmpdir,
200265
)
201266
# predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
202-
for _path, payload in predictions.items():
203-
bodyparts = payload.get("bodyparts")
204-
if bodyparts is None:
205-
# No detection -- fill with zeros.
206-
all_mapped.append(np.zeros((1, 26, 2), dtype="float32"))
207-
continue
208-
xy = bodyparts[..., :2] # (N_ind, K, 2)
209-
mapped = self._map_keypoints(xy)
210-
# Take only the first individual.
211-
all_mapped.append(mapped[:1])
267+
payload = predictions.get(img_path) if isinstance(predictions, dict) else None
268+
if payload is None and isinstance(predictions, dict) and len(predictions) == 1:
269+
payload = next(iter(predictions.values()))
270+
271+
bodyparts = None if payload is None else payload.get("bodyparts")
272+
bodyparts = None if bodyparts is None else np.asarray(bodyparts)
273+
if bodyparts is None or bodyparts.shape[0] == 0:
274+
# No detection -- fill with zeros and zero confidence.
275+
all_mapped.append(np.zeros((1, 26, 2), dtype=np.float32))
276+
all_scores.append(np.zeros((1, 26), dtype=np.float32))
277+
continue
278+
279+
xy = bodyparts[..., :2] # (N_ind, K, 2)
280+
conf = bodyparts[..., 2] # (N_ind, K)
281+
mapped = self._map_keypoints(xy)
282+
mapped_scores = self._map_scores(conf)
283+
284+
# Take only the first individual.
285+
all_mapped.append(mapped[:1])
286+
all_scores.append(mapped_scores[:1])
212287

213288
# Stack along frame axis → (1, N, 26, 2)
214289
kpts = np.stack(all_mapped, axis=1) # (1, N, 26, 2)
215-
scores = np.ones(kpts.shape[:3], dtype="float32") # (1, N, 26)
216-
return kpts, scores
290+
scores = np.stack(all_scores, axis=1) # (1, N, 26)
291+
kpts, scores = self._validate_predictions(kpts, scores, num_frames=num_frames)
292+
valid_frames_mask = self._compute_valid_frames_mask(kpts, scores)
293+
return kpts, scores, valid_frames_mask
217294

218295
# ------------------------------------------------------------------ #
219296

@@ -247,6 +324,80 @@ def _map_keypoints(xy: np.ndarray) -> np.ndarray:
247324

248325
return mapped
249326

327+
@staticmethod
328+
def _map_scores(conf: np.ndarray) -> np.ndarray:
329+
"""Map confidence scores from quadruped80K to Animal3D layout."""
330+
num_ind, num_src = conf.shape
331+
num_tgt = len(_QUADRUPED80K_TO_ANIMAL3D)
332+
mapped = np.full((num_ind, num_tgt), np.nan, dtype=np.float32)
333+
334+
for tgt_idx, src_idx in enumerate(_QUADRUPED80K_TO_ANIMAL3D):
335+
if src_idx != -1 and src_idx < num_src:
336+
mapped[:, tgt_idx] = conf[:, src_idx]
337+
elif src_idx == -1 and tgt_idx in _INTERPOLATION_RULES:
338+
s1, s2 = _INTERPOLATION_RULES[tgt_idx]
339+
if s1 < num_src and s2 < num_src:
340+
mapped[:, tgt_idx] = (conf[:, s1] + conf[:, s2]) / 2.0
341+
342+
return mapped
343+
344+
def _validate_predictions(
345+
self,
346+
keypoints: np.ndarray,
347+
scores: np.ndarray,
348+
*,
349+
num_frames: int,
350+
) -> Tuple[np.ndarray, np.ndarray]:
351+
"""Validate and normalise SuperAnimal predictions."""
352+
num_joints = 26
353+
keypoints = np.asarray(keypoints, dtype=np.float32)
354+
scores = np.asarray(scores, dtype=np.float32)
355+
356+
if keypoints.shape[0] == 0:
357+
return (
358+
np.zeros((1, num_frames, num_joints, 2), dtype=np.float32),
359+
np.zeros((1, num_frames, num_joints), dtype=np.float32),
360+
)
361+
362+
if keypoints.ndim != 4 or keypoints.shape[-2:] != (num_joints, 2):
363+
raise ValueError(
364+
f"Invalid SuperAnimal keypoints shape {keypoints.shape}; "
365+
f"expected (num_individuals, num_frames, {num_joints}, 2)."
366+
)
367+
if scores.ndim != 3 or scores.shape[-1] != num_joints:
368+
raise ValueError(
369+
f"Invalid SuperAnimal scores shape {scores.shape}; "
370+
f"expected (num_individuals, num_frames, {num_joints})."
371+
)
372+
if keypoints.shape[:2] != scores.shape[:2]:
373+
raise ValueError(
374+
"SuperAnimal keypoints/scores leading dimensions do not match: "
375+
f"{keypoints.shape[:2]} vs {scores.shape[:2]}."
376+
)
377+
if keypoints.shape[1] != num_frames:
378+
raise ValueError(
379+
f"SuperAnimal frame count mismatch: got {keypoints.shape[1]}, "
380+
f"expected {num_frames}."
381+
)
382+
383+
# Normalise unknown values to zeros so downstream code can treat these
384+
# joints as invalid via score==0 while retaining shape stability.
385+
keypoints = np.nan_to_num(keypoints, nan=0.0)
386+
scores = np.nan_to_num(scores, nan=0.0)
387+
return keypoints, scores
388+
389+
@staticmethod
390+
def _compute_valid_frames_mask(
391+
keypoints: np.ndarray, scores: np.ndarray
392+
) -> np.ndarray:
393+
"""Return frame-level validity mask from estimator outputs."""
394+
safe_scores = np.nan_to_num(scores, nan=0.0)
395+
has_score = np.any(safe_scores > 0, axis=-1) # (num_persons, num_frames)
396+
397+
safe_kpts = np.nan_to_num(np.abs(keypoints), nan=0.0)
398+
has_kpt = np.any(safe_kpts > 0, axis=(-1, -2)) # (num_persons, num_frames)
399+
return np.any(has_score | has_kpt, axis=0)
400+
250401

251402
# ---------------------------------------------------------------------------
252403
# Limb regularisation (animal post-processing)
@@ -444,6 +595,16 @@ def _default_components(
444595
# ---------------------------------------------------------------------------
445596

446597

598+
class Pose2DStatus(str, Enum):
599+
"""High-level status for 2D pose estimation."""
600+
601+
SUCCESS = "success"
602+
PARTIAL = "partial"
603+
EMPTY = "empty"
604+
INVALID = "invalid"
605+
UNKNOWN = "unknown"
606+
607+
447608
@dataclass
448609
class Pose2DResult:
449610
"""Container returned by :meth:`FMPose3DInference.prepare_2d`.
@@ -458,6 +619,43 @@ class Pose2DResult:
458619
"""Per-joint confidence scores, shape ``(num_persons, num_frames, J)``."""
459620
image_size: tuple[int, int] = (0, 0)
460621
"""``(height, width)`` of the source frames."""
622+
valid_frames_mask: np.ndarray | None = None
623+
"""Boolean mask of frames with at least one valid detection, shape ``(N,)``."""
624+
625+
@property
626+
def status(self) -> Pose2DStatus:
627+
"""Prediction status derived from ``valid_frames_mask``."""
628+
return self.get_status_info()[0]
629+
630+
@property
631+
def status_message(self) -> str:
632+
"""Human-readable explanation for :attr:`status`."""
633+
return self.get_status_info()[1]
634+
635+
def get_status_info(self) -> tuple[Pose2DStatus, str]:
636+
"""Prediction status derived from ``valid_frames_mask``."""
637+
if self.valid_frames_mask is None:
638+
return Pose2DStatus.UNKNOWN, "No frame-validity mask provided by the estimator."
639+
elif not isinstance(self.valid_frames_mask, np.ndarray) or self.valid_frames_mask.ndim != 1:
640+
return Pose2DStatus.UNKNOWN, "invalid valid_frames_mask: must be a 1D numpy array."
641+
642+
# Derive expected frame count from canonical shapes.
643+
if self.keypoints.ndim == 4:
644+
num_frames = int(self.keypoints.shape[1])
645+
elif self.scores.ndim == 3:
646+
num_frames = int(self.scores.shape[1])
647+
else:
648+
return Pose2DStatus.INVALID, "Incorrect keypoints/scores dimensions."
649+
650+
if self.valid_frames_mask.shape[0] != num_frames:
651+
return Pose2DStatus.INVALID, "valid_frames_mask mismatches the number of frames."
652+
653+
valid_count = int(np.sum(self.valid_frames_mask))
654+
if valid_count == 0:
655+
return Pose2DStatus.EMPTY, "No valid predictions in any frame."
656+
if valid_count < num_frames:
657+
return Pose2DStatus.PARTIAL, "Missing predictions in a subset of frames."
658+
return Pose2DStatus.SUCCESS, "Valid predictions for all frames."
461659

462660

463661
@dataclass
@@ -690,6 +888,9 @@ def predict(
690888
Root-relative and world-coordinate 3D poses.
691889
"""
692890
result_2d = self.prepare_2d(source)
891+
status, status_msg = result_2d.get_status_info()
892+
if status in {Pose2DStatus.EMPTY, Pose2DStatus.INVALID}:
893+
raise ValueError(f"2D pose estimation is not usable for 3D lifting: {status.value}. {status_msg}")
693894
return self.pose_3d(
694895
result_2d.keypoints,
695896
result_2d.image_size,
@@ -733,13 +934,16 @@ def prepare_2d(
733934
self.setup_runtime()
734935
if progress:
735936
progress(0, 1)
736-
keypoints, scores = self._estimator_2d.predict(ingested.frames)
937+
keypoints, scores, valid_frames_mask = self._estimator_2d.predict(
938+
ingested.frames
939+
)
737940
if progress:
738941
progress(1, 1)
739942
return Pose2DResult(
740943
keypoints=keypoints,
741944
scores=scores,
742945
image_size=ingested.image_size,
946+
valid_frames_mask=valid_frames_mask,
743947
)
744948

745949
@torch.no_grad()

0 commit comments

Comments (0)