Commit 769f3de

Merge pull request #18 from AdaptiveMotorControlLab/feat/add_huggingface_functionality

Add huggingface functionality

2 parents ec9d855 + c9d3971

13 files changed: 868 additions & 239 deletions

README.md

Lines changed: 4 additions & 0 deletions

````diff
@@ -92,6 +92,10 @@ To run inference on Human3.6M:
 sh ./scripts/FMPose3D_test.sh
 ```
 
+### Inference API
+
+FMPose3D also ships a high-level Python API for end-to-end 3D pose estimation from images. See the [Inference API documentation](fmpose3d/inference_api/README.md) for the full reference.
+
 ## Experiments on non-human animals
 
 For animal training/testing and demo scripts, see [animals/README.md](animals/README.md).
````

fmpose3d/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -29,7 +29,7 @@
 )
 
 # High-level inference API
-from .fmpose3d import (
+from .inference_api.fmpose3d import (
     FMPose3DInference,
     HRNetEstimator,
     Pose2DResult,
```

fmpose3d/common/config.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -8,6 +8,7 @@
 """
 
 import math
+import json
 from dataclasses import dataclass, field, fields, asdict
 from enum import Enum
 from typing import Dict, List
@@ -36,6 +37,16 @@ class ModelConfig:
     """Model architecture configuration."""
     model_type: str = "fmpose3d_humans"
 
+    def to_json(self, filename: str | None = None, **kwargs) -> str:
+        json_str = json.dumps(asdict(self), **kwargs)
+        with open(filename, "w") as f:
+            f.write(json_str)
+
+    @classmethod
+    def from_json(cls, filename: str, **kwargs) -> "ModelConfig":
+        with open(filename, "r") as f:
+            return cls(**json.loads(f.read(), **kwargs))
+
 
 # Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
 # Also consumed by PipelineConfig.for_model_type to set cross-config
```

fmpose3d/inference_api/README.md

Lines changed: 214 additions & 0 deletions (new file)

# FMPose3D Inference API

## Overview

This inference API provides a high-level, end-to-end interface for monocular 3D pose estimation using flow matching. It wraps the full pipeline — input ingestion, 2D keypoint detection, and 3D lifting — behind a single `FMPose3DInference` class, supporting both **human** (17-joint H36M) and **animal** (26-joint Animal3D) skeletons. Model weights are downloaded automatically from HuggingFace when not provided locally.

---

## Quick Examples

**Human pose estimation (end-to-end):**

```python
from fmpose3d import FMPose3DInference, FMPose3DConfig

# Create a config (optional)
config = FMPose3DConfig(model_type="fmpose3d_humans")  # or "fmpose3d_animals"

# Initialize the API
api = FMPose3DInference(config)  # weights auto-downloaded

# Predict from source (path, or an image array)
result = api.predict("photo.jpg")
print(result.poses_3d.shape)        # (1, 17, 3)
print(result.poses_3d_world.shape)  # (1, 17, 3)
```

**Human pose estimation (two-step):**

```python
from fmpose3d import FMPose3DInference

api = FMPose3DInference(model_weights_path="weights.pth")

# The 2D and 3D inference steps can be called separately
result_2d = api.prepare_2d("photo.jpg")
result_3d = api.pose_3d(result_2d.keypoints, result_2d.image_size)
```

**Animal pose estimation:**

```python
from fmpose3d import FMPose3DInference

# The API provides a convenience constructor that loads the animal config directly
api = FMPose3DInference.for_animals()
result = api.predict("dog.jpg")
print(result.poses_3d.shape)  # (1, 26, 3)
```
## API Documentation

### `FMPose3DInference` — Main Inference Class

The high-level entry point. Manages the full pipeline: input ingestion, 2D estimation, and 3D lifting.

#### Constructor

```python
FMPose3DInference(
    model_cfg: FMPose3DConfig | None = None,
    inference_cfg: InferenceConfig | None = None,
    model_weights_path: str | Path | None = None,
    device: str | torch.device | None = None,
    *,
    estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
    postprocessor: HumanPostProcessor | AnimalPostProcessor | None = None,
)
```

| Parameter | Description |
|---|---|
| `model_cfg` | Model architecture settings. Defaults to human (17 H36M joints). |
| `inference_cfg` | Inference settings (sample steps, test augmentation, etc.). |
| `model_weights_path` | Path to a `.pth` checkpoint. `None` triggers automatic download from HuggingFace. |
| `device` | Compute device. `None` auto-selects CUDA if available. |
| `estimator_2d` | Override the 2D pose estimator (auto-selected by default). |
| `postprocessor` | Override the post-processor (auto-selected by default). |

#### `FMPose3DInference.for_animals(...)` — Class Method

```python
@classmethod
def for_animals(
    cls,
    model_weights_path: str | None = None,
    *,
    device: str | torch.device | None = None,
    inference_cfg: InferenceConfig | None = None,
) -> FMPose3DInference
```

Convenience constructor for the **animal** pipeline. Sets `model_type="fmpose3d_animals"`, loads the appropriate config (26-joint Animal3D skeleton), and disables flip augmentation by default.

---

### Public Methods

#### `predict(source, *, camera_rotation, seed, progress)` → `Pose3DResult`

End-to-end prediction: 2D estimation followed by 3D lifting in a single call.

| Parameter | Type | Description |
|---|---|---|
| `source` | `Source` | Image path, directory, numpy array `(H,W,C)` or `(N,H,W,C)`, or list thereof. Video files are not supported. |
| `camera_rotation` | `ndarray \| None` | Length-4 quaternion for the camera-to-world rotation. Defaults to the official demo rotation; `None` skips the transform. Ignored for animals. |
| `seed` | `int \| None` | Seed for reproducible sampling. |
| `progress` | `ProgressCallback \| None` | Callback `(current_step, total_steps) -> None`. |

**Returns:** `Pose3DResult`

---

#### `prepare_2d(source, progress)` → `Pose2DResult`

Runs only the 2D pose estimation step.

| Parameter | Type | Description |
|---|---|---|
| `source` | `Source` | Same flexible input as `predict()`. |
| `progress` | `ProgressCallback \| None` | Optional progress callback. |

**Returns:** `Pose2DResult` containing `keypoints`, `scores`, and `image_size`.

---

#### `pose_3d(keypoints_2d, image_size, *, camera_rotation, seed, progress)` → `Pose3DResult`

Lifts pre-computed 2D keypoints to 3D using the flow-matching model.

| Parameter | Type | Description |
|---|---|---|
| `keypoints_2d` | `ndarray` | Shape `(num_persons, num_frames, J, 2)` or `(num_frames, J, 2)`. The first person is used if 4D. |
| `image_size` | `tuple[int, int]` | `(height, width)` of the source frames. |
| `camera_rotation` | `ndarray \| None` | Camera-to-world quaternion (human only). |
| `seed` | `int \| None` | Seed for reproducible results. |
| `progress` | `ProgressCallback \| None` | Per-frame progress callback. |

**Returns:** `Pose3DResult`
---

#### `setup_runtime()`

Manually initializes all runtime components (2D estimator, 3D model, weights). Called automatically on first use of `predict`, `prepare_2d`, or `pose_3d`.

---

### Types & Data Classes

#### `Source`

Accepted source types for `FMPose3DInference.predict` and `prepare_2d`:

- `str` or `Path` — path to an image file or a directory of images.
- `np.ndarray` — a single frame `(H, W, C)` or a batch `(N, H, W, C)`.
- `list` — a list of file paths or a list of `(H, W, C)` arrays.

```python
Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
```
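To make the union concrete, here is an illustrative helper that counts how many frames a `Source` value denotes. The `count_frames` function is hypothetical (the library resolves sources internally), but it follows the interpretation given in the list above:

```python
from pathlib import Path
from typing import Sequence, Union

import numpy as np

Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]


def count_frames(source: "Source") -> int:
    """Hypothetical helper: number of frames a Source value denotes."""
    if isinstance(source, np.ndarray):
        # (H, W, C) is one frame; (N, H, W, C) is a batch of N frames
        return 1 if source.ndim == 3 else source.shape[0]
    if isinstance(source, (str, Path)):
        p = Path(source)
        # A directory counts one frame per contained file; a file is one frame
        return len(list(p.iterdir())) if p.is_dir() else 1
    # A sequence mixes paths and arrays; count each element recursively
    return sum(count_frames(item) for item in source)
```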
#### `Pose2DResult`

| Field | Type | Description |
|---|---|---|
| `keypoints` | `ndarray` | 2D keypoints, shape `(num_persons, num_frames, J, 2)`. |
| `scores` | `ndarray` | Per-joint confidence, shape `(num_persons, num_frames, J)`. |
| `image_size` | `tuple[int, int]` | `(height, width)` of source frames. |

#### `Pose3DResult`

| Field | Type | Description |
|---|---|---|
| `poses_3d` | `ndarray` | Root-relative 3D poses, shape `(num_frames, J, 3)`. |
| `poses_3d_world` | `ndarray` | Post-processed 3D poses, shape `(num_frames, J, 3)`. For humans: world-coordinate poses. For animals: limb-regularized poses. |

---

### 2D Estimators

#### `HRNetEstimator(cfg: HRNetConfig | None)`

Default 2D estimator for the human pipeline. Wraps HRNet + YOLO with a COCO → H36M keypoint conversion.

- `setup_runtime()` — Loads YOLO + HRNet models.
- `predict(frames: ndarray)``(keypoints, scores)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)`.

#### `SuperAnimalEstimator(cfg: SuperAnimalConfig | None)`

2D estimator for the animal pipeline. Uses DeepLabCut SuperAnimal and maps quadruped80K keypoints to the 26-joint Animal3D layout.

- `setup_runtime()` — No-op (DLC loads lazily).
- `predict(frames: ndarray)``(keypoints, scores)` — Returns Animal3D-format 2D keypoints from BGR frames.

---

### Post-Processors

#### `HumanPostProcessor`

Zeros the root joint (root-relative output) and applies the `camera_to_world` rotation.

#### `AnimalPostProcessor`

Applies limb regularization (rotates the pose so that the average limb direction is vertical). No root zeroing or camera-to-world transform.

fmpose3d/inference_api/__init__.py

Lines changed: 24 additions & 0 deletions (new file)

```python
"""
FMPose3D: monocular 3D Pose Estimation via Flow Matching

Official implementation of the paper:
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis

Licensed under Apache 2.0
"""

from .fmpose3d import (
    FMPose3DInference,
    HRNetEstimator,
    Pose2DResult,
    Pose3DResult,
    Source,
)

__all__ = [
    "FMPose3DInference",
    "HRNetEstimator",
    "Pose2DResult",
    "Pose3DResult",
    "Source",
]
```
fmpose3d/inference_api/fmpose3d.py

Lines changed: 30 additions & 15 deletions

```diff
@@ -33,6 +33,9 @@
 ProgressCallback = Callable[[int, int], None]
 
 
+#: HuggingFace repository hosting the official FMPose3D checkpoints.
+_HF_REPO_ID: str = "deruyter92/fmpose_temp"
+
 # Default camera-to-world rotation quaternion (from the demo script).
 _DEFAULT_CAM_ROTATION = np.array(
     [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
@@ -560,7 +563,7 @@ def __init__(
         self,
         model_cfg: FMPose3DConfig | None = None,
         inference_cfg: InferenceConfig | None = None,
-        model_weights_path: str | Path | None = SKIP_WEIGHTS_VALIDATION,
+        model_weights_path: str | Path | None = None,
         device: str | torch.device | None = None,
         *,
         estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
@@ -601,7 +604,7 @@ def __init__(
     @classmethod
     def for_animals(
         cls,
-        model_weights_path: str = SKIP_WEIGHTS_VALIDATION,
+        model_weights_path: str | None = None,
         *,
         device: str | torch.device | None = None,
         inference_cfg: InferenceConfig | None = None,
@@ -958,15 +961,11 @@ def _load_weights(self) -> None:
     # Private helpers – input resolution
     # ------------------------------------------------------------------
 
-    def _resolve_model_weights_path(self) -> None:
-        # TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
-        if self.model_weights_path is SKIP_WEIGHTS_VALIDATION:
-            return SKIP_WEIGHTS_VALIDATION
-
-        if not self.model_weights_path:
+    def _resolve_model_weights_path(self) -> None:
+        if self.model_weights_path is None:
             self._download_model_weights()
         self.model_weights_path = Path(self.model_weights_path).resolve()
-        if not self.model_weights_path.exists():
+        if not self.model_weights_path.is_file():
             raise ValueError(
                 f"Model weights file not found: {self.model_weights_path}. "
                 "Please provide a valid path to a .pth checkpoint file in the "
@@ -976,12 +975,28 @@ def _resolve_model_weights_path(self) -> None:
         return self.model_weights_path
 
     def _download_model_weights(self) -> None:
-        """Download model weights from huggingface."""
-        # TODO @deruyter92: Implement download from huggingface
-        raise NotImplementedError(
-            "Downloading model weights from huggingface is not implemented yet."
-            "Please provide a valid path to a .pth checkpoint file in the "
-            "FMPose3DInference constructor."
+        """Download model weights from HuggingFace Hub.
+
+        The weight file is determined by the current ``model_cfg.model_type``
+        (e.g. ``"fmpose3d_humans"`` -> ``fmpose3d_humans.pth``). Files are
+        cached locally by :func:`huggingface_hub.hf_hub_download` so
+        subsequent calls are instant.
+
+        Sets ``self.model_weights_path`` to the local cached file path.
+        """
+        try:
+            from huggingface_hub import hf_hub_download
+        except ImportError:
+            raise ImportError(
+                "huggingface_hub is required to download model weights. "
+                "Install it with: pip install huggingface_hub. Or download "
+                "the weights manually and set model_weights_path to the weights file."
+            ) from None
+
+        filename = f"{self.model_cfg.model_type.value}.pth"
+        self.model_weights_path = hf_hub_download(
+            repo_id=_HF_REPO_ID,
+            filename=filename,
+        )
 
     def _ingest_input(self, source: Source) -> _IngestedInput:
```

pyproject.toml

Lines changed: 2 additions & 0 deletions

```diff
@@ -38,6 +38,7 @@ dependencies = [
     "filterpy>=1.4.5",
     "pandas>=1.0.1",
     "deeplabcut==3.0.0rc13",
+    "huggingface_hub>=0.20.0",
 ]
 
 [project.optional-dependencies]
@@ -69,6 +70,7 @@ line_length = 100
 [tool.pytest.ini_options]
 markers = [
     "functional: marks tests that require pretrained weights (deselect with '-m \"not functional\"')",
+    "network: marks tests that may need internet access on first run (deselect with '-m \"not network\"')",
 ]
 
 [tool.codespell]
```

tests/__init__.py

Whitespace-only changes.

tests/fmpose3d_api/__init__.py

Lines changed: 8 additions & 0 deletions (new file)

```python
"""
FMPose3D: monocular 3D Pose Estimation via Flow Matching

Official implementation of the paper:
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis

Licensed under Apache 2.0
"""
```
