-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy pathtest_modelzoo.py
More file actions
80 lines (62 loc) · 2.75 KB
/
test_modelzoo.py
File metadata and controls
80 lines (62 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
# from deeplabcut/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py
import os
import dlclibrary
import pytest
from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS
from dlclive import modelzoo
@pytest.mark.parametrize(
    "super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"]
)
@pytest.mark.parametrize("model_name", ["hrnet_w32"])
@pytest.mark.parametrize("detector_name", [None, "fasterrcnn_resnet50_fpn_v2"])
def test_get_config_model_paths(super_animal, model_name, detector_name):
    """The loaded SuperAnimal config's method follows ``detector_name``.

    With no detector the config must report method ``"bu"`` and contain no
    ``"detector"`` section; with a detector it must report ``"td"`` and
    contain a ``"detector"`` section.
    """
    cfg = modelzoo.load_super_animal_config(
        super_animal=super_animal,
        model_name=model_name,
        detector_name=detector_name,
    )
    assert isinstance(cfg, dict)

    # Supplying a detector is what switches the expected method to top-down.
    wants_detector = detector_name is not None
    expected_method = "td" if wants_detector else "bu"
    assert cfg["method"].lower() == expected_method
    assert ("detector" in cfg) == wants_detector
def test_humanbody_requires_detector_name():
    """Loading ``superanimal_humanbody`` with ``detector_name=None`` raises ValueError."""
    request = {
        "super_animal": "superanimal_humanbody",
        "model_name": "hrnet_w32",
        "detector_name": None,
    }
    with pytest.raises(ValueError):
        modelzoo.load_super_animal_config(**request)
def test_humanbody_rejects_unsupported_detector():
    """``superanimal_humanbody`` with ``fasterrcnn_resnet50_fpn_v2`` raises ValueError."""
    request = {
        "super_animal": "superanimal_humanbody",
        "model_name": "hrnet_w32",
        "detector_name": "fasterrcnn_resnet50_fpn_v2",
    }
    with pytest.raises(ValueError):
        modelzoo.load_super_animal_config(**request)
def test_humanbody_uses_torchvision_detector_config():
    """HumanBody with ``fasterrcnn_mobilenet_v3_large_fpn`` yields a top-down config.

    The config's method must be ``"td"`` and its ``detector.model.type`` must
    be ``"TorchvisionDetectorAdaptor"``.
    """
    model_config = modelzoo.load_super_animal_config(
        super_animal="superanimal_humanbody",
        model_name="hrnet_w32",
        detector_name="fasterrcnn_mobilenet_v3_large_fpn",
    )
    assert model_config["method"].lower() == "td"

    # Check the nested sections exist before indexing into them, so a
    # malformed config fails with a descriptive assertion rather than a bare
    # KeyError raised from the test body.
    assert "detector" in model_config, "config has no 'detector' section"
    detector_cfg = model_config["detector"]
    assert "model" in detector_cfg, "detector config has no 'model' section"

    detector_model_cfg = detector_cfg["model"]
    assert detector_model_cfg.get("type") == "TorchvisionDetectorAdaptor"
def test_download_huggingface_model(tmp_path_factory, model="full_cat"):
    """Download ``model`` via dlclibrary and check the resulting folder layout.

    After the download the target folder must contain ``pose_cfg.yaml`` and at
    least one ``snapshot-*`` entry, and no ``models--*`` entry.

    ``model`` has a default value, so pytest does not request it as a fixture.
    ``test_download_all_models`` calls this function directly with other names.
    """
    target_dir = tmp_path_factory.mktemp("temp")
    dlclibrary.download_huggingface_model(model, str(target_dir))

    assert os.path.exists(target_dir / "pose_cfg.yaml")

    entries = os.listdir(target_dir)
    assert any(name.startswith("snapshot-") for name in entries)
    # The intermediate Hugging Face download folder ("models--...") must have
    # been removed from the target directory.
    assert not any(name.startswith("models--") for name in entries)
def test_download_huggingface_wrong_model():
    """Requesting an unknown model name raises ValueError."""
    bogus_name = "wrong_model_name"
    with pytest.raises(ValueError):
        dlclibrary.download_huggingface_model(bogus_name)
@pytest.mark.skip(reason="slow")
@pytest.mark.parametrize("model", MODELOPTIONS)
def test_download_all_models(tmp_path_factory, model):
    """Run the single-model download check for every entry in ``MODELOPTIONS``.

    Skipped by default because it downloads every model.
    """
    test_download_huggingface_model(tmp_path_factory, model=model)