Merged
Changes from 6 commits
4 changes: 4 additions & 0 deletions README.md
@@ -92,6 +92,10 @@ To run inference on Human3.6M:
sh ./scripts/FMPose3D_test.sh
```

### Inference API

FMPose3D also ships a high-level Python API for end-to-end 3D pose estimation from images. See the [Inference API documentation](fmpose3d/inference_api/README.md) for the full reference.

## Experiments on non-human animals

For animal training/testing and demo scripts, see [animals/README.md](animals/README.md).
2 changes: 1 addition & 1 deletion fmpose3d/__init__.py
@@ -29,7 +29,7 @@
)

# High-level inference API
-from .fmpose3d import (
+from .inference_api.fmpose3d import (
FMPose3DInference,
HRNetEstimator,
Pose2DResult,
11 changes: 11 additions & 0 deletions fmpose3d/common/config.py
@@ -8,6 +8,7 @@
"""

import math
import json
from dataclasses import dataclass, field, fields, asdict
from enum import Enum
from typing import Dict, List
@@ -36,6 +37,16 @@ class ModelConfig:
"""Model architecture configuration."""
model_type: str = "fmpose3d_humans"

    def to_json(self, filename: str | None = None, **kwargs) -> str:
        json_str = json.dumps(asdict(self), **kwargs)
        if filename is not None:
            with open(filename, "w") as f:
                f.write(json_str)
        return json_str

    @classmethod
    def from_json(cls, filename: str, **kwargs) -> "ModelConfig":
        with open(filename, "r") as f:
            return cls(**json.loads(f.read(), **kwargs))


# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
# Also consumed by PipelineConfig.for_model_type to set cross-config
214 changes: 214 additions & 0 deletions fmpose3d/inference_api/README.md
@@ -0,0 +1,214 @@
# FMPose3D Inference API

## Overview
This inference API provides a high-level, end-to-end interface for monocular 3D pose estimation using flow matching. It wraps the full pipeline — input ingestion, 2D keypoint detection, and 3D lifting — behind a single `FMPose3DInference` class, supporting both **human** (17-joint H36M) and **animal** (26-joint Animal3D) skeletons. Model weights are downloaded automatically from HuggingFace when not provided locally.

---


## Quick Examples

**Human pose estimation (end-to-end):**

```python
from fmpose3d import FMPose3DInference, FMPose3DConfig

# Create a config (optional)
config = FMPose3DConfig(model_type="fmpose3d_humans") # or "fmpose3d_animals"

# Initialize the API
api = FMPose3DInference(config) # weights auto-downloaded

# Predict from source (path, or an image array)
result = api.predict("photo.jpg")
print(result.poses_3d.shape) # (1, 17, 3)
print(result.poses_3d_world.shape) # (1, 17, 3)
```

**Human pose estimation (two-step):**

```python
from fmpose3d import FMPose3DInference

api = FMPose3DInference(model_weights_path="weights.pth")

# The 2D and 3D inference step can be called separately
result_2d = api.prepare_2d("photo.jpg")
result_3d = api.pose_3d(result_2d.keypoints, result_2d.image_size)
```

**Animal pose estimation:**

```python
from fmpose3d import FMPose3DInference

# The API provides a convenience constructor preconfigured for the animal pipeline
api = FMPose3DInference.for_animals()
result = api.predict("dog.jpg")
print(result.poses_3d.shape) # (1, 26, 3)
```


## API Documentation

### `FMPose3DInference` — Main Inference Class

The high-level entry point. Manages the full pipeline: input ingestion, 2D estimation, and 3D lifting.

#### Constructor

```python
FMPose3DInference(
model_cfg: FMPose3DConfig | None = None,
inference_cfg: InferenceConfig | None = None,
model_weights_path: str | Path | None = None,
device: str | torch.device | None = None,
*,
estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
postprocessor: HumanPostProcessor | AnimalPostProcessor | None = None,
)
```

| Parameter | Description |
|---|---|
| `model_cfg` | Model architecture settings. Defaults to human (17 H36M joints). |
| `inference_cfg` | Inference settings (sample steps, test augmentation, etc.). |
| `model_weights_path` | Path to a `.pth` checkpoint. `None` triggers automatic download from HuggingFace. |
| `device` | Compute device. `None` auto-selects CUDA if available. |
| `estimator_2d` | Override the 2D pose estimator (auto-selected by default). |
| `postprocessor` | Override the post-processor (auto-selected by default). |

#### `FMPose3DInference.for_animals(...)` — Class Method

```python
@classmethod
def for_animals(
cls,
model_weights_path: str | None = None,
*,
device: str | torch.device | None = None,
inference_cfg: InferenceConfig | None = None,
) -> FMPose3DInference
```

Convenience constructor for the **animal** pipeline. Sets `model_type="fmpose3d_animals"`, loads the appropriate config (26-joint Animal3D skeleton) and disables flip augmentation by default.

---

### Public Methods

#### `predict(source, *, camera_rotation, seed, progress)` → `Pose3DResult`

End-to-end prediction: 2D estimation followed by 3D lifting in a single call.

| Parameter | Type | Description |
|---|---|---|
| `source` | `Source` | Image path, directory, numpy array `(H,W,C)` or `(N,H,W,C)`, or list thereof. Video files are not supported. |
| `camera_rotation` | `ndarray \| None` | Length-4 quaternion for camera-to-world rotation. Defaults to the official demo rotation. `None` skips the transform. Ignored for animals. |
| `seed` | `int \| None` | Seed for reproducible sampling. |
| `progress` | `ProgressCallback \| None` | Callback `(current_step, total_steps) -> None`. |

**Returns:** `Pose3DResult`
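A progress callback is any callable matching the documented `(current_step, total_steps) -> None` signature. A minimal sketch (the commented `api.predict` call at the end is hypothetical and assumes an installed package with available weights):

```python
def format_progress(current_step: int, total_steps: int) -> str:
    # Build the message printed once per sampling step.
    pct = 100 * current_step // total_steps
    return f"lifting: {current_step}/{total_steps} ({pct}%)"

def log_progress(current_step: int, total_steps: int) -> None:
    # Matches the ProgressCallback signature expected by predict().
    print(format_progress(current_step, total_steps))

# Hypothetical usage:
# api = FMPose3DInference()
# result = api.predict("photo.jpg", seed=42, progress=log_progress)
```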

---

#### `prepare_2d(source, progress)` → `Pose2DResult`

Runs only the 2D pose estimation step.

| Parameter | Type | Description |
|---|---|---|
| `source` | `Source` | Same flexible input as `predict()`. |
| `progress` | `ProgressCallback \| None` | Optional progress callback. |

**Returns:** `Pose2DResult` containing `keypoints`, `scores`, and `image_size`.

---

#### `pose_3d(keypoints_2d, image_size, *, camera_rotation, seed, progress)` → `Pose3DResult`

Lifts pre-computed 2D keypoints to 3D using the flow-matching model.

| Parameter | Type | Description |
|---|---|---|
| `keypoints_2d` | `ndarray` | Shape `(num_persons, num_frames, J, 2)` or `(num_frames, J, 2)`. First person is used if 4D. |
| `image_size` | `tuple[int, int]` | `(height, width)` of the source frames. |
| `camera_rotation` | `ndarray \| None` | Camera-to-world quaternion (human only). |
| `seed` | `int \| None` | Seed for reproducible results. |
| `progress` | `ProgressCallback \| None` | Per-frame progress callback. |

**Returns:** `Pose3DResult`
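As a sketch of the expected input layout, synthetic 2D keypoints in the 17-joint human format can be built like this (the commented `api.pose_3d` call assumes an initialized `FMPose3DInference` with loaded weights):

```python
import numpy as np

num_frames, num_joints = 10, 17          # H36M human layout
image_size = (720, 1280)                 # (height, width)

# Random normalized keypoints scaled into pixel coordinates.
keypoints_2d = np.random.rand(num_frames, num_joints, 2).astype(np.float32)
keypoints_2d[..., 0] *= image_size[1]    # x in [0, width)
keypoints_2d[..., 1] *= image_size[0]    # y in [0, height)

# Hypothetical call:
# result = api.pose_3d(keypoints_2d, image_size, seed=0)
```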

---

#### `setup_runtime()`

Manually initializes all runtime components (2D estimator, 3D model, weights). Called automatically on first use of `predict`, `prepare_2d`, or `pose_3d`.

---

### Types & Data Classes

#### `Source`

Accepted source types for `FMPose3DInference.predict` and `prepare_2d`:

- `str` or `Path` — path to an image file or a directory of images.
- `np.ndarray` — a single frame `(H, W, C)` or a batch `(N, H, W, C)`.
- `list` — a list of file paths or a list of `(H, W, C)` arrays.

```python
Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
```
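For instance, several frames can be stacked into a single `(N, H, W, C)` batch (sketch with synthetic frames; the commented calls assume an initialized `api`):

```python
import numpy as np

# Three synthetic 480x640 BGR frames.
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(3)]
batch = np.stack(frames)  # (N, H, W, C) == (3, 480, 640, 3)

# Any of the following are valid Source values:
# api.predict("photo.jpg")
# api.predict(frames)   # list of (H, W, C) arrays
# api.predict(batch)    # batched array
```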

#### `Pose2DResult`

| Field | Type | Description |
|---|---|---|
| `keypoints` | `ndarray` | 2D keypoints, shape `(num_persons, num_frames, J, 2)`. |
| `scores` | `ndarray` | Per-joint confidence, shape `(num_persons, num_frames, J)`. |
| `image_size` | `tuple[int, int]` | `(height, width)` of source frames. |
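Because `keypoints` carries a leading person axis, selecting the most confident person reduces to an argmax over mean scores (illustrative sketch with synthetic data):

```python
import numpy as np

num_persons, num_frames, J = 3, 4, 17
keypoints = np.random.rand(num_persons, num_frames, J, 2).astype(np.float32)
scores = np.random.rand(num_persons, num_frames, J).astype(np.float32)

# Mean confidence per person across frames and joints.
best = int(scores.mean(axis=(1, 2)).argmax())
person_kpts = keypoints[best]  # (num_frames, J, 2)
```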

#### `Pose3DResult`

| Field | Type | Description |
|---|---|---|
| `poses_3d` | `ndarray` | Root-relative 3D poses, shape `(num_frames, J, 3)`. |
| `poses_3d_world` | `ndarray` | Post-processed 3D poses, shape `(num_frames, J, 3)`. For humans: world-coordinate poses. For animals: limb-regularized poses. |



---

### 2D Estimators

#### `HRNetEstimator(cfg: HRNetConfig | None)`

Default 2D estimator for the human pipeline. Wraps HRNet + YOLO with a COCO → H36M keypoint conversion.

- `setup_runtime()` — Loads YOLO + HRNet models.
- `predict(frames: ndarray)` → `(keypoints, scores)` — Returns H36M-format 2D keypoints from BGR frames `(N, H, W, C)`.

#### `SuperAnimalEstimator(cfg: SuperAnimalConfig | None)`

2D estimator for the animal pipeline. Uses DeepLabCut SuperAnimal and maps quadruped80K keypoints to the 26-joint Animal3D layout.

- `setup_runtime()` — No-op (DLC loads lazily).
- `predict(frames: ndarray)` → `(keypoints, scores)` — Returns Animal3D-format 2D keypoints from BGR frames.

---

### Post-Processors

#### `HumanPostProcessor`

Zeros the root joint (root-relative) and applies `camera_to_world` rotation.
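The root-relative step can be sketched in plain NumPy (illustrative only; the actual post-processor may differ in details such as the root index):

```python
import numpy as np

def zero_root(poses_3d: np.ndarray, root_index: int = 0) -> np.ndarray:
    # Subtract the root joint from all joints in each frame,
    # so the root lands at the origin.
    return poses_3d - poses_3d[:, root_index : root_index + 1, :]

poses = np.random.rand(2, 17, 3).astype(np.float32)
rel = zero_root(poses)
```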

#### `AnimalPostProcessor`

Applies limb regularization (rotates the pose so that average limb direction is vertical). No root zeroing or camera-to-world transform.

---



24 changes: 24 additions & 0 deletions fmpose3d/inference_api/__init__.py
@@ -0,0 +1,24 @@
"""
FMPose3D: monocular 3D Pose Estimation via Flow Matching

Official implementation of the paper:
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
Licensed under Apache 2.0
"""

from .fmpose3d import (
FMPose3DInference,
HRNetEstimator,
Pose2DResult,
Pose3DResult,
Source,
)

__all__ = [
"FMPose3DInference",
"HRNetEstimator",
"Pose2DResult",
"Pose3DResult",
"Source",
]
45 changes: 30 additions & 15 deletions fmpose3d/fmpose3d.py → fmpose3d/inference_api/fmpose3d.py
Expand Up @@ -33,6 +33,9 @@
ProgressCallback = Callable[[int, int], None]


#: HuggingFace repository hosting the official FMPose3D checkpoints.
_HF_REPO_ID: str = "deruyter92/fmpose_temp"

# Default camera-to-world rotation quaternion (from the demo script).
_DEFAULT_CAM_ROTATION = np.array(
[0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
@@ -560,7 +563,7 @@ def __init__(
self,
model_cfg: FMPose3DConfig | None = None,
inference_cfg: InferenceConfig | None = None,
-        model_weights_path: str | Path | None = SKIP_WEIGHTS_VALIDATION,
+        model_weights_path: str | Path | None = None,
device: str | torch.device | None = None,
*,
estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
@@ -601,7 +604,7 @@ def __init__(
@classmethod
def for_animals(
cls,
-        model_weights_path: str = SKIP_WEIGHTS_VALIDATION,
+        model_weights_path: str | None = None,
*,
device: str | torch.device | None = None,
inference_cfg: InferenceConfig | None = None,
@@ -958,15 +961,11 @@ def _load_weights(self) -> None:
# Private helpers – input resolution
# ------------------------------------------------------------------

-    def _resolve_model_weights_path(self) -> None:
-        # TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
-        if self.model_weights_path is SKIP_WEIGHTS_VALIDATION:
-            return SKIP_WEIGHTS_VALIDATION
-
-        if not self.model_weights_path:
+    def _resolve_model_weights_path(self) -> None:
+        if self.model_weights_path is None:
             self._download_model_weights()
         self.model_weights_path = Path(self.model_weights_path).resolve()
-        if not self.model_weights_path.exists():
+        if not self.model_weights_path.is_file():
raise ValueError(
f"Model weights file not found: {self.model_weights_path}. "
"Please provide a valid path to a .pth checkpoint file in the "
@@ -976,12 +975,28 @@ def _resolve_model_weights_path(self) -> None:
return self.model_weights_path

def _download_model_weights(self) -> None:
"""Download model weights from huggingface."""
# TODO @deruyter92: Implement download from huggingface
raise NotImplementedError(
"Downloading model weights from huggingface is not implemented yet."
"Please provide a valid path to a .pth checkpoint file in the "
"FMPose3DInference constructor."
"""Download model weights from HuggingFace Hub.

The weight file is determined by the current ``model_cfg.model_type``
(e.g. ``"fmpose3d_humans"`` -> ``fmpose3d_humans.pth``). Files are
cached locally by :func:`huggingface_hub.hf_hub_download` so
subsequent calls are instant.

Sets ``self.model_weights_path`` to the local cached file path.
"""
try:
from huggingface_hub import hf_hub_download
except ImportError:
raise ImportError(
"huggingface_hub is required to download model weights. "
"Install it with: pip install huggingface_hub. Or download "
"the weights manually and set model_weights_path to the weights file."
) from None

filename = f"{self.model_cfg.model_type.value}.pth"
self.model_weights_path = hf_hub_download(
repo_id=_HF_REPO_ID,
filename=filename,
)

def _ingest_input(self, source: Source) -> _IngestedInput:
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"filterpy>=1.4.5",
"pandas>=1.0.1",
"deeplabcut==3.0.0rc13",
"huggingface_hub>=0.20.0",
]

[project.optional-dependencies]
@@ -69,6 +70,7 @@ line_length = 100
[tool.pytest.ini_options]
markers = [
"functional: marks tests that require pretrained weights (deselect with '-m \"not functional\"')",
"network: marks tests that may need internet access on first run (deselect with '-m \"not network\"')",
]

[tool.codespell]
Empty file added tests/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions tests/fmpose3d_api/__init__.py
@@ -0,0 +1,8 @@
"""
FMPose3D: monocular 3D Pose Estimation via Flow Matching

Official implementation of the paper:
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
Licensed under Apache 2.0
"""