-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathpytorch_model_zoo_export.py
More file actions
56 lines (44 loc) · 1.85 KB
/
pytorch_model_zoo_export.py
File metadata and controls
56 lines (44 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import warnings
from pathlib import Path
from collections import OrderedDict
import torch
from dlclive.modelzoo.utils import load_super_animal_config, download_super_animal_snapshot
def export_modelzoo_model(
export_path: str | Path,
super_animal: str,
model_name: str,
detector_name: str | None = None,
) -> None:
"""
"""
Path(export_path).parent.mkdir(parents=True, exist_ok=True)
if Path(export_path).exists():
warnings.warn(f"Export path {export_path} already exists, skipping export", UserWarning)
return
model_cfg = load_super_animal_config(
super_animal=super_animal,
model_name=model_name,
detector_name=detector_name,
)
def _load_model_weights(model_name: str, super_animal: str = super_animal) -> OrderedDict:
"""Download the model weights from huggingface and load them in torch state dict"""
checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name)
return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"]
# Skip downloading the detector weights for humanbody models, as they are not on huggingface
skip_detector_download = (detector_name is None) or (super_animal == "superanimal_humanbody")
export_dict = {
"config": model_cfg,
"pose": _load_model_weights(model_name),
"detector": None if skip_detector_download else _load_model_weights(detector_name),
}
torch.save(export_dict, export_path)
if __name__ == "__main__":
    # Example usage: export the SuperAnimal-Quadruped ResNet-50 pose model
    # (without a detector) into the "exported_models" sub-directory.
    from utils import _MODELZOO_PATH

    dataset = "superanimal_quadruped"
    pose_model = "resnet_50"
    output_file = _MODELZOO_PATH / "exported_models" / f"exported_{dataset}_{pose_model}.pt"

    export_modelzoo_model(
        export_path=output_file,
        super_animal=dataset,
        model_name=pose_model,
    )