Skip to content

Commit b6a32f6

Browse files
committed
Use temporary HF cache and fix tar extraction
Use a temporary directory for HuggingFace cache (tempfile.TemporaryDirectory) so HF cache folders are not created under the target_dir. Improve tar.gz handling in _handle_downloaded_file by extracting members with tar.extractfile and shutil.copyfileobj (and fall back to shutil.copy2 for non-tar files), and remove the previous symlink-resolution and explicit HF-folder removal logic. Update hf_hub_download call to use repo_id/filename named args and thread per-file rename mappings through to the extractor. Adjusted tests to reflect that the HF cache is no longer created inside the target directory and to ensure the final artifact still exists.
1 parent 8fe4e6b commit b6a32f6

2 files changed

Lines changed: 30 additions & 36 deletions

File tree

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import json
1414
import os
1515
import tarfile
16+
import shutil
17+
import tempfile
1618
from pathlib import Path
1719

1820
from huggingface_hub import hf_hub_download
@@ -111,32 +113,27 @@ def get_available_models(dataset: str) -> list[str]:
111113
return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())
112114

113115

116+
114117
def _handle_downloaded_file(
115118
file_path: str, target_dir: str, rename_mapping: dict | None = None
116119
):
117-
"""Handle the downloaded file from HuggingFace"""
120+
"""Handle the downloaded file from HuggingFace cache and place the final artifact in target_dir."""
118121
file_name = os.path.basename(file_path)
122+
119123
try:
120124
with tarfile.open(file_path, mode="r:gz") as tar:
121125
for member in tar:
122126
if not member.isdir():
123127
fname = Path(member.name).name
124-
tar.makefile(member, os.path.join(target_dir, fname))
125-
except tarfile.ReadError: # The model is a .pt file
128+
extracted_path = os.path.join(target_dir, fname)
129+
with tar.extractfile(member) as src, open(extracted_path, "wb") as dst:
130+
shutil.copyfileobj(src, dst)
131+
except tarfile.ReadError:
126132
if rename_mapping is not None:
127133
file_name = rename_mapping.get(file_name, file_name)
128-
if os.path.islink(file_path):
129-
file_path_ = os.readlink(file_path)
130-
if not os.path.isabs(file_path_):
131-
file_path_ = os.path.abspath(
132-
os.path.join(os.path.dirname(file_path), file_path_)
133-
)
134-
file_path = file_path_
135-
import shutil
136134
shutil.copy2(file_path, os.path.join(target_dir, file_name))
137135

138136

139-
140137
def download_huggingface_model(
141138
model_name: str,
142139
target_dir: str = ".",
@@ -190,27 +187,22 @@ def download_huggingface_model(
190187

191188
if not os.path.isabs(target_dir):
192189
target_dir = os.path.abspath(target_dir)
193-
194190
os.makedirs(target_dir, exist_ok=True)
195191

196-
last_hf_folder = None
197-
198-
for url in urls:
199-
url = url.split("/")
200-
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
201-
202-
downloaded = hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
203-
204-
if isinstance(rename_mapping, str):
205-
mapping = {targzfn: rename_mapping}
206-
else:
207-
mapping = rename_mapping
208-
209-
_handle_downloaded_file(downloaded, target_dir, mapping)
192+
with tempfile.TemporaryDirectory(prefix="dlc_hf_") as hf_cache_dir:
193+
for url in urls:
194+
url = url.split("/")
195+
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
210196

211-
last_hf_folder = f"models--{url[0]}--{url[1]}"
197+
downloaded = hf_hub_download(
198+
repo_id=repo_id,
199+
filename=targzfn,
200+
cache_dir=hf_cache_dir,
201+
)
212202

213-
if remove_hf_folder and last_hf_folder is not None:
214-
import shutil
215-
shutil.rmtree(os.path.join(target_dir, last_hf_folder), ignore_errors=True)
203+
if isinstance(rename_mapping, str):
204+
mapping = {targzfn: rename_mapping}
205+
else:
206+
mapping = rename_mapping
216207

208+
_handle_downloaded_file(downloaded, target_dir, mapping)

tests/test_modeldownload.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,7 @@ def test_download_with_rename_mapping_for_pt(tmp_path, mock_modelzoo):
171171
assert not any(name.startswith("models--") for name in files)
172172

173173

174-
def test_keep_hf_folder_when_requested(tmp_path, mock_modelzoo):
175-
"""
176-
If remove_hf_folder=False, the cache structure should still exist.
177-
"""
174+
def test_remove_hf_folder_flag_no_longer_affects_target_dir(tmp_path, mock_modelzoo):
178175
folder = tmp_path / "keep_cache"
179176
folder.mkdir()
180177

@@ -185,4 +182,9 @@ def test_keep_hf_folder_when_requested(tmp_path, mock_modelzoo):
185182
)
186183

187184
files = {p.name for p in folder.iterdir()}
188-
assert any(name.startswith("models--") for name in files)
185+
186+
# Final artifact should still be there
187+
assert "full_cat.pt" in files
188+
189+
# HF cache should no longer be created inside target_dir
190+
assert not any(name.startswith("models--") for name in files)

0 commit comments

Comments
 (0)