Skip to content

Commit 8fe4e6b

Browse files
committed
Update test_modeldownload.py
1 parent 20cf022 commit 8fe4e6b

1 file changed

Lines changed: 157 additions & 14 deletions

File tree

tests/test_modeldownload.py

Lines changed: 157 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,125 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
11-
import dlclibrary
11+
from __future__ import annotations
12+
13+
import io
1214
import os
15+
import tarfile
16+
from pathlib import Path
17+
1318
import pytest
19+
20+
import dlclibrary
21+
import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo_download
1422
from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS
1523

1624

17-
def test_download_huggingface_model(tmp_path_factory, model="full_cat"):
18-
folder = tmp_path_factory.mktemp("temp")
25+
def _fake_model_names():
26+
"""
27+
Return a deterministic fake URL for each model.
28+
Alternate between tar.gz and .pt to test both branches.
29+
"""
30+
mapping = {}
31+
for i, model in enumerate(MODELOPTIONS):
32+
ext = ".tar.gz" if i % 2 == 0 else ".pt"
33+
mapping[model] = f"fakeorg/{model}-repo/{model}{ext}"
34+
return mapping
35+
36+
37+
def _write_fake_tar_gz(path: Path):
38+
"""
39+
Create a fake tar.gz archive with the files the downloader expects
40+
for archive-based DLC models.
41+
"""
42+
path.parent.mkdir(parents=True, exist_ok=True)
43+
44+
with tarfile.open(path, mode="w:gz") as tar:
45+
files = {
46+
"pose_cfg.yaml": b"all_joints: [0, 1]\n",
47+
"snapshot-103000.index": b"fake index",
48+
"snapshot-103000.data-00000-of-00001": b"fake weights",
49+
"snapshot-103000.meta": b"fake meta",
50+
}
51+
52+
for name, content in files.items():
53+
info = tarfile.TarInfo(name=name)
54+
info.size = len(content)
55+
tar.addfile(info, io.BytesIO(content))
56+
57+
58+
def _write_fake_pt(path: Path):
59+
"""
60+
Create a fake .pt / .pth weight file.
61+
"""
62+
path.parent.mkdir(parents=True, exist_ok=True)
63+
path.write_bytes(b"fake pytorch weights")
64+
65+
66+
@pytest.fixture
67+
def mock_modelzoo(monkeypatch):
68+
"""
69+
Patch both:
70+
- model name resolution
71+
- hf_hub_download network call
72+
73+
so all downloads are local and deterministic.
74+
"""
75+
fake_names = _fake_model_names()
76+
77+
monkeypatch.setattr(modelzoo_download, "_load_model_names", lambda: fake_names)
78+
79+
def fake_hf_hub_download(repo_id, filename, cache_dir):
80+
cache_dir = Path(cache_dir)
81+
hf_folder = cache_dir / f"models--{repo_id.replace('/', '--')}"
82+
snapshot_dir = hf_folder / "snapshots" / "fakecommit123"
83+
returned_file = snapshot_dir / filename
84+
85+
if filename.endswith(".tar.gz"):
86+
_write_fake_tar_gz(returned_file)
87+
elif filename.endswith(".pt") or filename.endswith(".pth"):
88+
_write_fake_pt(returned_file)
89+
else:
90+
raise AssertionError(f"Unexpected mocked filename: {filename}")
91+
92+
return str(returned_file)
93+
94+
monkeypatch.setattr(modelzoo_download, "hf_hub_download", fake_hf_hub_download)
95+
96+
return fake_names
97+
98+
99+
def _assert_download_success(folder: Path, model: str):
100+
"""
101+
Shared assertion helper for download_huggingface_model.
102+
"""
19103
dlclibrary.download_huggingface_model(model, str(folder))
20104

21-
try: # These are not created for .pt models
22-
assert os.path.exists(folder / "pose_cfg.yaml")
23-
assert any(f.startswith("snapshot-") for f in os.listdir(folder))
24-
except AssertionError:
25-
assert any(f.endswith(".pth") for f in os.listdir(folder))
105+
files = {p.name for p in folder.iterdir()}
106+
107+
# Archive-based DLC model
108+
if "pose_cfg.yaml" in files:
109+
assert "pose_cfg.yaml" in files
110+
assert any(name.startswith("snapshot-") for name in files)
111+
112+
# Direct PyTorch model
113+
else:
114+
assert any(name.endswith((".pt", ".pth")) for name in files)
26115

27-
# Verify that the Hugging Face folder was removed
28-
assert not any(f.startswith("models--") for f in os.listdir(folder))
116+
# Verify that the Hugging Face cache folder was removed
117+
assert not any(name.startswith("models--") for name in files)
29118

30119

31-
def test_download_huggingface_wrong_model():
120+
def test_download_huggingface_model_tar_or_pt(tmp_path, mock_modelzoo):
121+
folder = tmp_path / "download_one"
122+
folder.mkdir()
123+
124+
# "full_cat" may map to tar.gz or .pt depending on ordering;
125+
# this assertion helper supports both branches.
126+
_assert_download_success(folder, "full_cat")
127+
128+
129+
def test_download_huggingface_wrong_model(mock_modelzoo):
32130
with pytest.raises(ValueError):
33131
dlclibrary.download_huggingface_model("wrong_model_name")
34132

@@ -40,6 +138,51 @@ def test_parse_superanimal_models():
40138

41139

42140
@pytest.mark.parametrize("model", MODELOPTIONS)
43-
def test_download_all_models(tmp_path_factory, model):
44-
print("Downloading ...", model)
45-
test_download_huggingface_model(tmp_path_factory, model)
141+
def test_download_all_models(tmp_path, mock_modelzoo, model):
142+
folder = tmp_path / model
143+
folder.mkdir()
144+
_assert_download_success(folder, model)
145+
146+
147+
def test_download_with_rename_mapping_for_pt(tmp_path, mock_modelzoo):
148+
"""
149+
Explicitly test rename_mapping for a .pt model.
150+
"""
151+
# Pick one of the mocked .pt models
152+
pt_model = None
153+
for i, model in enumerate(MODELOPTIONS):
154+
if i % 2 == 1:
155+
pt_model = model
156+
break
157+
158+
assert pt_model is not None, "Expected at least one mocked .pt model"
159+
160+
folder = tmp_path / "rename_pt"
161+
folder.mkdir()
162+
163+
dlclibrary.download_huggingface_model(
164+
pt_model,
165+
str(folder),
166+
rename_mapping="renamed_weights.pt",
167+
)
168+
169+
files = {p.name for p in folder.iterdir()}
170+
assert "renamed_weights.pt" in files
171+
assert not any(name.startswith("models--") for name in files)
172+
173+
174+
def test_keep_hf_folder_when_requested(tmp_path, mock_modelzoo):
175+
"""
176+
If remove_hf_folder=False, the cache structure should still exist.
177+
"""
178+
folder = tmp_path / "keep_cache"
179+
folder.mkdir()
180+
181+
dlclibrary.download_huggingface_model(
182+
"full_cat",
183+
str(folder),
184+
remove_hf_folder=False,
185+
)
186+
187+
files = {p.name for p in folder.iterdir()}
188+
assert any(name.startswith("models--") for name in files)

0 commit comments

Comments
 (0)