Skip to content

Commit 61b4f11

Browse files
authored
Merge pull request #44 from C-Achard/cy/fix-modelzoo-download
Fix path reconstruction in ModelZoo download
2 parents 8b1af3a + 9919240 commit 61b4f11

2 files changed

Lines changed: 190 additions & 53 deletions

File tree

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import json
1414
import os
1515
import tarfile
16+
import shutil
17+
import tempfile
1618
from pathlib import Path
1719

1820
from huggingface_hub import hf_hub_download
@@ -111,34 +113,50 @@ def get_available_models(dataset: str) -> list[str]:
111113
return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())
112114

113115

116+
114117
def _handle_downloaded_file(
115118
file_path: str, target_dir: str, rename_mapping: dict | None = None
116119
):
117-
"""Handle the downloaded file from HuggingFace"""
120+
"""Handle the downloaded file from HuggingFace cache and place the final artifact in target_dir."""
118121
file_name = os.path.basename(file_path)
122+
119123
try:
120-
with tarfile.open(file_path, mode="r:gz") as tar:
121-
for member in tar:
122-
if not member.isdir():
123-
fname = Path(member.name).name
124-
tar.makefile(member, os.path.join(target_dir, fname))
125-
except tarfile.ReadError: # The model is a .pt file
124+
# Be permissive about compression type
125+
with tarfile.open(file_path, mode="r:*") as tar:
126+
extracted_any = False
127+
for member in tar.getmembers():
128+
# Only extract regular files
129+
if not member.isfile():
130+
continue
131+
132+
fname = Path(member.name).name
133+
if not fname:
134+
continue
135+
136+
src = tar.extractfile(member)
137+
if src is None:
138+
continue
139+
140+
extracted_path = os.path.join(target_dir, fname)
141+
with src, open(extracted_path, "wb") as dst:
142+
shutil.copyfileobj(src, dst)
143+
144+
extracted_any = True
145+
146+
# If it opened as a tar but contained nothing useful, fail loudly
147+
if not extracted_any:
148+
raise tarfile.ReadError(f"No regular files extracted from archive: {file_path}")
149+
150+
except tarfile.ReadError:
151+
# Not an archive -> treat as a direct model file (.pt/.pth/etc.)
126152
if rename_mapping is not None:
127153
file_name = rename_mapping.get(file_name, file_name)
128-
if os.path.islink(file_path):
129-
file_path_ = os.readlink(file_path)
130-
if not os.path.isabs(file_path_):
131-
file_path_ = os.path.abspath(
132-
os.path.join(os.path.dirname(file_path), file_path_)
133-
)
134-
file_path = file_path_
135-
os.rename(file_path, os.path.join(target_dir, file_name))
154+
shutil.copy2(file_path, os.path.join(target_dir, file_name))
136155

137156

138157
def download_huggingface_model(
139158
model_name: str,
140159
target_dir: str = ".",
141-
remove_hf_folder: bool = True,
142160
rename_mapping: str | dict | None = None,
143161
):
144162
"""
@@ -151,10 +169,6 @@ def download_huggingface_model(
151169
target_dir (str, optional):
152170
Target directory where the model weights will be stored.
153171
Defaults to the current directory.
154-
remove_hf_folder (bool, optional):
155-
Whether to remove the directory structure created by HuggingFace
156-
after downloading and decompressing the data into DeepLabCut format.
157-
Defaults to True.
158172
rename_mapping (dict | str | None, optional):
159173
- If a dictionary, it should map the original Hugging Face filenames
160174
to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}).
@@ -164,7 +178,7 @@ def download_huggingface_model(
164178
165179
Examples:
166180
>>> # Download without renaming, keep original filename
167-
download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False)
181+
download_huggingface_model("superanimal_bird_resnet_50")
168182
169183
>>> # Download and rename by specifying the new name directly
170184
download_huggingface_model(
@@ -188,25 +202,22 @@ def download_huggingface_model(
188202

189203
if not os.path.isabs(target_dir):
190204
target_dir = os.path.abspath(target_dir)
205+
os.makedirs(target_dir, exist_ok=True)
191206

192-
for url in urls:
193-
url = url.split("/")
194-
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
195-
196-
hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
197-
198-
# Create a new subfolder as indicated below, unzipping from there and deleting this folder
199-
hf_folder = f"models--{url[0]}--{url[1]}"
200-
path_ = os.path.join(target_dir, hf_folder, "snapshots")
201-
commit = os.listdir(path_)[0]
202-
file_name = os.path.join(path_, commit, targzfn)
203-
204-
if isinstance(rename_mapping, str):
205-
rename_mapping = {targzfn: rename_mapping}
207+
with tempfile.TemporaryDirectory(prefix="dlc_hf_") as hf_cache_dir:
208+
for url in urls:
209+
url = url.split("/")
210+
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
206211

207-
_handle_downloaded_file(file_name, target_dir, rename_mapping)
212+
downloaded = hf_hub_download(
213+
repo_id=repo_id,
214+
filename=targzfn,
215+
cache_dir=hf_cache_dir,
216+
)
208217

209-
if remove_hf_folder:
210-
import shutil
218+
if isinstance(rename_mapping, str):
219+
mapping = {targzfn: rename_mapping}
220+
else:
221+
mapping = rename_mapping
211222

212-
shutil.rmtree(os.path.join(target_dir, hf_folder))
223+
_handle_downloaded_file(downloaded, target_dir, mapping)

tests/test_modeldownload.py

Lines changed: 140 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,125 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
11-
import dlclibrary
11+
from __future__ import annotations
12+
13+
import io
1214
import os
15+
import tarfile
16+
from pathlib import Path
17+
1318
import pytest
19+
20+
import dlclibrary
21+
import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo_download
1422
from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS
1523

1624

17-
def test_download_huggingface_model(tmp_path_factory, model="full_cat"):
18-
folder = tmp_path_factory.mktemp("temp")
25+
def _fake_model_names():
26+
"""
27+
Return a deterministic fake URL for each model.
28+
Alternate between tar.gz and .pt to test both branches.
29+
"""
30+
mapping = {}
31+
for i, model in enumerate(MODELOPTIONS):
32+
ext = ".tar.gz" if i % 2 == 0 else ".pt"
33+
mapping[model] = f"fakeorg/{model}-repo/{model}{ext}"
34+
return mapping
35+
36+
37+
def _write_fake_tar_gz(path: Path):
38+
"""
39+
Create a fake tar.gz archive with the files the downloader expects
40+
for archive-based DLC models.
41+
"""
42+
path.parent.mkdir(parents=True, exist_ok=True)
43+
44+
with tarfile.open(path, mode="w:gz") as tar:
45+
files = {
46+
"pose_cfg.yaml": b"all_joints: [0, 1]\n",
47+
"snapshot-103000.index": b"fake index",
48+
"snapshot-103000.data-00000-of-00001": b"fake weights",
49+
"snapshot-103000.meta": b"fake meta",
50+
}
51+
52+
for name, content in files.items():
53+
info = tarfile.TarInfo(name=name)
54+
info.size = len(content)
55+
tar.addfile(info, io.BytesIO(content))
56+
57+
58+
def _write_fake_pt(path: Path):
59+
"""
60+
Create a fake .pt / .pth weight file.
61+
"""
62+
path.parent.mkdir(parents=True, exist_ok=True)
63+
path.write_bytes(b"fake pytorch weights")
64+
65+
66+
@pytest.fixture
67+
def mock_modelzoo(monkeypatch):
68+
"""
69+
Patch both:
70+
- model name resolution
71+
- hf_hub_download network call
72+
73+
so all downloads are local and deterministic.
74+
"""
75+
fake_names = _fake_model_names()
76+
77+
monkeypatch.setattr(modelzoo_download, "_load_model_names", lambda: fake_names)
78+
79+
def fake_hf_hub_download(repo_id, filename, cache_dir):
80+
cache_dir = Path(cache_dir)
81+
hf_folder = cache_dir / f"models--{repo_id.replace('/', '--')}"
82+
snapshot_dir = hf_folder / "snapshots" / "fakecommit123"
83+
returned_file = snapshot_dir / filename
84+
85+
if filename.endswith(".tar.gz"):
86+
_write_fake_tar_gz(returned_file)
87+
elif filename.endswith(".pt") or filename.endswith(".pth"):
88+
_write_fake_pt(returned_file)
89+
else:
90+
raise AssertionError(f"Unexpected mocked filename: {filename}")
91+
92+
return str(returned_file)
93+
94+
monkeypatch.setattr(modelzoo_download, "hf_hub_download", fake_hf_hub_download)
95+
96+
return fake_names
97+
98+
99+
def _assert_download_success(folder: Path, model: str):
100+
"""
101+
Shared assertion helper for download_huggingface_model.
102+
"""
19103
dlclibrary.download_huggingface_model(model, str(folder))
20104

21-
try: # These are not created for .pt models
22-
assert os.path.exists(folder / "pose_cfg.yaml")
23-
assert any(f.startswith("snapshot-") for f in os.listdir(folder))
24-
except AssertionError:
25-
assert any(f.endswith(".pth") for f in os.listdir(folder))
105+
files = {p.name for p in folder.iterdir()}
106+
107+
# Archive-based DLC model
108+
if "pose_cfg.yaml" in files:
109+
assert "pose_cfg.yaml" in files
110+
assert any(name.startswith("snapshot-") for name in files)
111+
112+
# Direct PyTorch model
113+
else:
114+
assert any(name.endswith((".pt", ".pth")) for name in files)
115+
116+
# Verify that the Hugging Face cache folder was removed
117+
assert not any(name.startswith("models--") for name in files)
26118

27-
# Verify that the Hugging Face folder was removed
28-
assert not any(f.startswith("models--") for f in os.listdir(folder))
29119

120+
def test_download_huggingface_model_tar_or_pt(tmp_path, mock_modelzoo):
121+
folder = tmp_path / "download_one"
122+
folder.mkdir()
30123

31-
def test_download_huggingface_wrong_model():
124+
# "full_cat" may map to tar.gz or .pt depending on ordering;
125+
# this assertion helper supports both branches.
126+
_assert_download_success(folder, "full_cat")
127+
128+
129+
def test_download_huggingface_wrong_model(mock_modelzoo):
32130
with pytest.raises(ValueError):
33131
dlclibrary.download_huggingface_model("wrong_model_name")
34132

@@ -40,6 +138,34 @@ def test_parse_superanimal_models():
40138

41139

42140
@pytest.mark.parametrize("model", MODELOPTIONS)
43-
def test_download_all_models(tmp_path_factory, model):
44-
print("Downloading ...", model)
45-
test_download_huggingface_model(tmp_path_factory, model)
141+
def test_download_all_models(tmp_path, mock_modelzoo, model):
142+
folder = tmp_path / model
143+
folder.mkdir()
144+
_assert_download_success(folder, model)
145+
146+
147+
def test_download_with_rename_mapping_for_pt(tmp_path, mock_modelzoo):
148+
"""
149+
Explicitly test rename_mapping for a .pt model.
150+
"""
151+
# Pick one of the mocked .pt models
152+
pt_model = None
153+
for i, model in enumerate(MODELOPTIONS):
154+
if i % 2 == 1:
155+
pt_model = model
156+
break
157+
158+
assert pt_model is not None, "Expected at least one mocked .pt model"
159+
160+
folder = tmp_path / "rename_pt"
161+
folder.mkdir()
162+
163+
dlclibrary.download_huggingface_model(
164+
pt_model,
165+
str(folder),
166+
rename_mapping="renamed_weights.pt",
167+
)
168+
169+
files = {p.name for p in folder.iterdir()}
170+
assert "renamed_weights.pt" in files
171+
assert not any(name.startswith("models--") for name in files)

0 commit comments

Comments
 (0)