44
55import torch
66
7- from dlclive .modelzoo .utils import load_super_animal_config , download_super_animal_snapshot
7+ from dlclive .modelzoo .utils import (
8+ load_super_animal_config ,
9+ download_super_animal_snapshot ,
10+ )
811
912
1013def export_modelzoo_model (
@@ -29,7 +32,9 @@ def export_modelzoo_model(
2932 """
3033 Path (export_path ).parent .mkdir (parents = True , exist_ok = True )
3134 if Path (export_path ).exists ():
32- warnings .warn (f"Export path { export_path } already exists, skipping export" , UserWarning )
35+ warnings .warn (
36+ f"Export path { export_path } already exists, skipping export" , UserWarning
37+ )
3338 return
3439
3540 model_cfg = load_super_animal_config (
@@ -38,30 +43,40 @@ def export_modelzoo_model(
3843 detector_name = detector_name ,
3944 )
4045
41- def _load_model_weights (model_name : str , super_animal : str = super_animal ) -> OrderedDict :
46+ def _load_model_weights (
47+ model_name : str , super_animal : str = super_animal
48+ ) -> OrderedDict :
4249 """Download the model weights from huggingface and load them in torch state dict"""
43- checkpoint : Path = download_super_animal_snapshot (dataset = super_animal , model_name = model_name )
50+ checkpoint : Path = download_super_animal_snapshot (
51+ dataset = super_animal , model_name = model_name
52+ )
4453 return torch .load (checkpoint , map_location = "cpu" , weights_only = True )["model" ]
45-
54+
4655 # Skip downloading the detector weights for humanbody models, as they are not on huggingface
47- skip_detector_download = (detector_name is None ) or (super_animal == "superanimal_humanbody" )
56+ skip_detector_download = (detector_name is None ) or (
57+ super_animal == "superanimal_humanbody"
58+ )
4859 export_dict = {
4960 "config" : model_cfg ,
5061 "pose" : _load_model_weights (model_name ),
51- "detector" : None if skip_detector_download else _load_model_weights (detector_name ),
62+ "detector" : None
63+ if skip_detector_download
64+ else _load_model_weights (detector_name ),
5265 }
5366 torch .save (export_dict , export_path )
5467
5568
5669if __name__ == "__main__" :
57- """Example usage"""
70+ """Example usage"""
5871 from utils import _MODELZOO_PATH
59-
72+
6073 model_name = "resnet_50"
6174 super_animal = "superanimal_quadruped"
6275
6376 export_modelzoo_model (
64- export_path = _MODELZOO_PATH / 'exported_models' / f'exported_{ super_animal } _{ model_name } .pt' ,
77+ export_path = _MODELZOO_PATH
78+ / "exported_models"
79+ / f"exported_{ super_animal } _{ model_name } .pt" ,
6580 super_animal = super_animal ,
6681 model_name = model_name ,
6782 )
0 commit comments