Skip to content

Commit b2d7e0f

Browse files
committed
Introduce SupportedModel Enum
1 parent d112fc8 commit b2d7e0f

4 files changed

Lines changed: 26 additions & 4 deletions

File tree

fmpose3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
HRNetConfig,
2525
InferenceConfig,
2626
ModelConfig,
27+
SupportedModel,
2728
PipelineConfig,
2829
)
2930

@@ -57,6 +58,7 @@
5758
"HRNetConfig",
5859
"InferenceConfig",
5960
"ModelConfig",
61+
"SupportedModel",
6062
"PipelineConfig",
6163
# Aggregation methods
6264
"average_aggregation",

fmpose3d/common/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .config import (
1616
PipelineConfig,
1717
ModelConfig,
18+
SupportedModel,
1819
FMPose3DConfig,
1920
HRNetConfig,
2021
Pose2DConfig,
@@ -48,6 +49,7 @@
4849
"HRNetConfig",
4950
"Pose2DConfig",
5051
"ModelConfig",
52+
"SupportedModel",
5153
"DatasetConfig",
5254
"TrainingConfig",
5355
"InferenceConfig",

fmpose3d/common/config.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,28 @@
99

1010
import math
1111
from dataclasses import dataclass, field, fields, asdict
12+
from enum import Enum
1213
from typing import Dict, List
1314

1415
# ---------------------------------------------------------------------------
1516
# Dataclass configuration groups
1617
# ---------------------------------------------------------------------------
1718

1819

20+
class SupportedModel(str, Enum):
21+
"""Supported FMPose3D pose-estimation model types."""
22+
FMPOSE3D_HUMANS = "fmpose3d_humans"
23+
FMPOSE3D_ANIMALS = "fmpose3d_animals"
24+
25+
@classmethod
26+
def _missing_(cls, value: str) -> "SupportedModel":
27+
valid = ", ".join(repr(m.value) for m in cls)
28+
raise ValueError(
29+
f"{value!r} is not a valid {cls.__name__}. "
30+
f"Valid values are: {valid}"
31+
)
32+
33+
1934
@dataclass
2035
class ModelConfig:
2136
"""Model architecture configuration."""
@@ -51,7 +66,7 @@ class ModelConfig:
5166

5267
@dataclass
5368
class FMPose3DConfig(ModelConfig):
54-
model_type: str = "fmpose3d_humans"
69+
model_type: SupportedModel = SupportedModel.FMPOSE3D_HUMANS
5570
model: str = ""
5671
layers: int = 5
5772
channel: int = 512
@@ -67,6 +82,8 @@ class FMPose3DConfig(ModelConfig):
6782
frames: int = 1
6883

6984
def __post_init__(self):
85+
if not isinstance(self.model_type, SupportedModel):
86+
self.model_type = SupportedModel(self.model_type)
7087
defaults = _FMPOSE3D_DEFAULTS.get(self.model_type)
7188
if defaults is None:
7289
supported = ", ".join(sorted(_FMPOSE3D_DEFAULTS))
@@ -321,7 +338,7 @@ def _pick(dc_class, src: dict):
321338

322339
kwargs = {}
323340
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
324-
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d_humans') in _FMPOSE3D_DEFAULTS:
341+
if group_name == "model_cfg" and raw.get("model_type", "fmpose3d_humans") in _FMPOSE3D_DEFAULTS:
325342
dc_class = FMPose3DConfig
326343
elif group_name == "pose2d_cfg":
327344
p2d = raw.get("pose2d_model", "hrnet")

fmpose3d/fmpose3d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
FMPose3DConfig,
2525
HRNetConfig,
2626
InferenceConfig,
27+
SupportedModel,
2728
SuperAnimalConfig,
2829
)
2930
from fmpose3d.models import get_model
@@ -430,7 +431,7 @@ def _default_components(
430431
inspected to choose pipeline components. Adding a third pipeline
431432
means adding one branch here (or turning this into a registry).
432433
"""
433-
if model_cfg.model_type == "fmpose3d_animals":
434+
if model_cfg.model_type == SupportedModel.FMPOSE3D_ANIMALS:
434435
return SuperAnimalEstimator(), AnimalPostProcessor()
435436
return HRNetEstimator(), HumanPostProcessor()
436437

@@ -624,7 +625,7 @@ def for_animals(
624625
if inference_cfg is None:
625626
inference_cfg = InferenceConfig(test_augmentation=False)
626627
return cls(
627-
model_cfg=FMPose3DConfig(model_type="fmpose3d_animals"),
628+
model_cfg=FMPose3DConfig(model_type=SupportedModel.FMPOSE3D_ANIMALS),
628629
inference_cfg=inference_cfg,
629630
model_weights_path=model_weights_path,
630631
device=device,

0 commit comments

Comments
 (0)