Skip to content

Commit 58d2397

Browse files
committed
update tests and readme: results-validation with valid_frames mask, ResultStatus
1 parent 6edcef8 commit 58d2397

2 files changed

Lines changed: 87 additions & 8 deletions

File tree

fmpose3d/inference_api/README.md

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ Convenience constructor for the **animal** pipeline. Sets `model_type="fmpose3d_
100100
#### `predict(source, *, camera_rotation, seed, progress)` → `Pose3DResult`
101101

102102
End-to-end prediction: 2D estimation followed by 3D lifting in a single call.
103+
Raises `ValueError` when 2D estimation is unusable for lifting
104+
(`Pose2DResult.status` is `empty` or `invalid`).
105+
For partial 2D detections, invalid frames are masked to `NaN` in
106+
`Pose3DResult.poses_3d` and `Pose3DResult.poses_3d_world`.
103107

104108
| Parameter | Type | Description |
105109
|---|---|---|
@@ -121,7 +125,9 @@ Runs only the 2D pose estimation step.
121125
| `source` | `Source` | Same flexible input as `predict()`. |
122126
| `progress` | `ProgressCallback \| None` | Optional progress callback. |
123127
124-
**Returns:** `Pose2DResult` containing `keypoints`, `scores`, and `image_size`.
128+
**Returns:** `Pose2DResult` containing `keypoints`, `scores`, `image_size`,
129+
and `valid_frames_mask`. The object also exposes derived properties
130+
`status` and `status_message`.
125131

126132
---
127133

@@ -168,13 +174,35 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
168174
| `keypoints` | `ndarray` | 2D keypoints, shape `(num_persons, num_frames, J, 2)`. |
169175
| `scores` | `ndarray` | Per-joint confidence, shape `(num_persons, num_frames, J)`. |
170176
| `image_size` | `tuple[int, int]` | `(height, width)` of source frames. |
177+
| `valid_frames_mask` | `ndarray \| None` | Boolean mask, shape `(num_frames,)`, indicating frames with valid detections. |
178+
179+
Computed properties:
180+
181+
- `status``ResultStatus`
182+
- `status_message``str`
183+
184+
#### `ResultStatus`
185+
186+
String enum values:
187+
188+
- `success` — valid detections in all frames
189+
- `partial` — valid detections in a subset of frames
190+
- `empty` — no valid detections in any frame
191+
- `invalid` — output predictions are unusable/malformed
192+
- `unknown` — validity metadata missing or malformed
171193

172194
#### `Pose3DResult`
173195

174196
| Field | Type | Description |
175197
|---|---|---|
176198
| `poses_3d` | `ndarray` | Root-relative 3D poses, shape `(num_frames, J, 3)`. |
177199
| `poses_3d_world` | `ndarray` | Post-processed 3D poses, shape `(num_frames, J, 3)`. For humans: world-coordinate poses. For animals: limb-regularized poses. |
200+
| `valid_frames_mask` | `ndarray \| None` | Boolean mask, shape `(num_frames,)`, indicating frames with valid 3D output. |
201+
202+
Computed properties:
203+
204+
- `status``ResultStatus`
205+
- `status_message``str`
178206

179207

180208

@@ -187,14 +215,14 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
187215
Default 2D estimator for the human pipeline. Wraps HRNet + YOLO with a COCOH36M keypoint conversion.
188216

189217
- `setup_runtime()` — Loads YOLO + HRNet models.
190-
- `predict(frames: ndarray)``(keypoints, scores)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)`.
218+
- `predict(frames: ndarray)``(keypoints, scores, valid_frames_mask)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)` plus a frame-level validity mask.
191219

192220
#### `SuperAnimalEstimator(cfg: SuperAnimalConfig | None)`
193221

194222
2D estimator for the animal pipeline. Uses DeepLabCut SuperAnimal and maps quadruped80K keypoints to the 26-joint Animal3D layout.
195223

196224
- `setup_runtime()` — No-op (DLC loads lazily).
197-
- `predict(frames: ndarray)``(keypoints, scores)` — Returns Animal3D-format 2D keypoints from BGR frames.
225+
- `predict(frames: ndarray)``(keypoints, scores, valid_frames_mask)` — Returns Animal3D-format 2D keypoints plus a frame-level validity mask.
198226

199227
---
200228

tests/fmpose3d_api/test_fmpose3d.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
HRNetEstimator,
2626
HumanPostProcessor,
2727
Pose2DResult,
28+
ResultStatus,
2829
Pose3DResult,
2930
SuperAnimalEstimator,
3031
_default_components,
@@ -615,8 +616,9 @@ def test_predict_end_to_end_with_mock_estimator(self):
615616

616617
mock_kpts = np.random.randn(1, 1, 17, 2).astype("float32")
617618
mock_scores = np.ones((1, 1, 17), dtype="float32")
619+
mock_mask = np.array([True], dtype=bool)
618620
api._estimator_2d = MagicMock()
619-
api._estimator_2d.predict.return_value = (mock_kpts, mock_scores)
621+
api._estimator_2d.predict.return_value = (mock_kpts, mock_scores, mock_mask)
620622
api._estimator_2d.setup_runtime = MagicMock()
621623

622624
frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
@@ -626,6 +628,24 @@ def test_predict_end_to_end_with_mock_estimator(self):
626628
assert result.poses_3d.shape == (1, 17, 3)
627629
api._estimator_2d.predict.assert_called_once()
628630

631+
def test_predict_applies_partial_2d_mask_to_3d(self):
632+
"""predict() masks invalid 2D frames to NaN in 3D outputs."""
633+
api = _make_ready_api("fmpose3d_humans", test_augmentation=False)
634+
mock_kpts = np.random.randn(1, 3, 17, 2).astype("float32")
635+
mock_scores = np.ones((1, 3, 17), dtype="float32")
636+
mask = np.array([True, False, True], dtype=bool)
637+
api._estimator_2d = MagicMock()
638+
api._estimator_2d.predict.return_value = (mock_kpts, mock_scores, mask)
639+
api._estimator_2d.setup_runtime = MagicMock()
640+
641+
frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
642+
result = api.predict([frame, frame, frame], seed=42)
643+
644+
np.testing.assert_array_equal(result.valid_frames_mask, mask)
645+
assert result.status == ResultStatus.PARTIAL
646+
assert np.all(np.isnan(result.poses_3d[1]))
647+
assert np.all(np.isnan(result.poses_3d_world[1]))
648+
629649

630650
# =========================================================================
631651
# Unit tests — dataclasses
@@ -648,12 +668,41 @@ def test_pose2d_result_default_image_size(self):
648668
)
649669
assert result.image_size == (0, 0)
650670

671+
def test_pose2d_status_success(self):
672+
result = Pose2DResult(
673+
keypoints=np.zeros((1, 2, 17, 2)),
674+
scores=np.zeros((1, 2, 17)),
675+
valid_frames_mask=np.array([True, True], dtype=bool),
676+
)
677+
assert result.status == ResultStatus.SUCCESS
678+
assert "all frames" in result.status_message
679+
680+
def test_pose2d_status_partial(self):
681+
result = Pose2DResult(
682+
keypoints=np.zeros((1, 2, 17, 2)),
683+
scores=np.zeros((1, 2, 17)),
684+
valid_frames_mask=np.array([True, False], dtype=bool),
685+
)
686+
assert result.status == ResultStatus.PARTIAL
687+
assert "subset" in result.status_message
688+
689+
def test_pose2d_status_invalid_mask_length(self):
690+
result = Pose2DResult(
691+
keypoints=np.zeros((1, 2, 17, 2)),
692+
scores=np.zeros((1, 2, 17)),
693+
valid_frames_mask=np.array([True], dtype=bool),
694+
)
695+
assert result.status == ResultStatus.INVALID
696+
assert "mismatches" in result.status_message
697+
651698
def test_pose3d_result(self):
652699
p3d = np.random.randn(10, 17, 3)
653700
pw = np.random.randn(10, 17, 3)
654-
result = Pose3DResult(poses_3d=p3d, poses_3d_world=pw)
701+
mask = np.ones((10,), dtype=bool)
702+
result = Pose3DResult(poses_3d=p3d, poses_3d_world=pw, valid_frames_mask=mask)
655703
assert result.poses_3d is p3d
656704
assert result.poses_3d_world is pw
705+
assert result.status == ResultStatus.SUCCESS
657706

658707

659708
# =========================================================================
@@ -672,12 +721,13 @@ def test_predict_returns_zeros_when_no_bodyparts(self):
672721
"deeplabcut.pose_estimation_pytorch.apis.superanimal_analyze_images",
673722
) as mock_fn:
674723
mock_fn.return_value = {"frame.png": {"bodyparts": None}}
675-
kpts, scores = estimator.predict(frames)
724+
kpts, scores, mask = estimator.predict(frames)
676725

677726
assert kpts.shape == (1, 2, 26, 2)
678727
np.testing.assert_allclose(kpts, 0.0)
679728
assert scores.shape == (1, 2, 26)
680-
np.testing.assert_allclose(scores, 1.0)
729+
np.testing.assert_allclose(scores, 0.0)
730+
np.testing.assert_array_equal(mask, np.array([False, False]))
681731

682732
def test_predict_maps_valid_bodyparts(self):
683733
"""Valid DLC bodyparts are mapped to Animal3D layout."""
@@ -692,9 +742,10 @@ def test_predict_maps_valid_bodyparts(self):
692742
"deeplabcut.pose_estimation_pytorch.apis.superanimal_analyze_images",
693743
) as mock_fn:
694744
mock_fn.return_value = {"frame.png": {"bodyparts": fake_bp}}
695-
kpts, scores = estimator.predict(frames)
745+
kpts, scores, mask = estimator.predict(frames)
696746

697747
assert kpts.shape == (1, 1, 26, 2)
698748
assert scores.shape == (1, 1, 26)
749+
np.testing.assert_array_equal(mask, np.array([True]))
699750
# target[24] ← source[0] → (0*3, 0*3+1) = (0.0, 1.0)
700751
np.testing.assert_allclose(kpts[0, 0, 24], fake_bp[0, 0, :2])

0 commit comments

Comments
 (0)