88#
99# Licensed under GNU Lesser General Public License v3.0
1010#
11- import dlclibrary
11+ from __future__ import annotations
12+
13+ import io
1214import os
15+ import tarfile
16+ from pathlib import Path
17+
1318import pytest
19+
20+ import dlclibrary
21+ import dlclibrary .dlcmodelzoo .modelzoo_download as modelzoo_download
1422from dlclibrary .dlcmodelzoo .modelzoo_download import MODELOPTIONS
1523
1624
17- def test_download_huggingface_model (tmp_path_factory , model = "full_cat" ):
18- folder = tmp_path_factory .mktemp ("temp" )
25+ def _fake_model_names ():
26+ """
27+ Return a deterministic fake URL for each model.
28+ Alternate between tar.gz and .pt to test both branches.
29+ """
30+ mapping = {}
31+ for i , model in enumerate (MODELOPTIONS ):
32+ ext = ".tar.gz" if i % 2 == 0 else ".pt"
33+ mapping [model ] = f"fakeorg/{ model } -repo/{ model } { ext } "
34+ return mapping
35+
36+
37+ def _write_fake_tar_gz (path : Path ):
38+ """
39+ Create a fake tar.gz archive with the files the downloader expects
40+ for archive-based DLC models.
41+ """
42+ path .parent .mkdir (parents = True , exist_ok = True )
43+
44+ with tarfile .open (path , mode = "w:gz" ) as tar :
45+ files = {
46+ "pose_cfg.yaml" : b"all_joints: [0, 1]\n " ,
47+ "snapshot-103000.index" : b"fake index" ,
48+ "snapshot-103000.data-00000-of-00001" : b"fake weights" ,
49+ "snapshot-103000.meta" : b"fake meta" ,
50+ }
51+
52+ for name , content in files .items ():
53+ info = tarfile .TarInfo (name = name )
54+ info .size = len (content )
55+ tar .addfile (info , io .BytesIO (content ))
56+
57+
58+ def _write_fake_pt (path : Path ):
59+ """
60+ Create a fake .pt / .pth weight file.
61+ """
62+ path .parent .mkdir (parents = True , exist_ok = True )
63+ path .write_bytes (b"fake pytorch weights" )
64+
65+
66+ @pytest .fixture
67+ def mock_modelzoo (monkeypatch ):
68+ """
69+ Patch both:
70+ - model name resolution
71+ - hf_hub_download network call
72+
73+ so all downloads are local and deterministic.
74+ """
75+ fake_names = _fake_model_names ()
76+
77+ monkeypatch .setattr (modelzoo_download , "_load_model_names" , lambda : fake_names )
78+
79+ def fake_hf_hub_download (repo_id , filename , cache_dir ):
80+ cache_dir = Path (cache_dir )
81+ hf_folder = cache_dir / f"models--{ repo_id .replace ('/' , '--' )} "
82+ snapshot_dir = hf_folder / "snapshots" / "fakecommit123"
83+ returned_file = snapshot_dir / filename
84+
85+ if filename .endswith (".tar.gz" ):
86+ _write_fake_tar_gz (returned_file )
87+ elif filename .endswith (".pt" ) or filename .endswith (".pth" ):
88+ _write_fake_pt (returned_file )
89+ else :
90+ raise AssertionError (f"Unexpected mocked filename: { filename } " )
91+
92+ return str (returned_file )
93+
94+ monkeypatch .setattr (modelzoo_download , "hf_hub_download" , fake_hf_hub_download )
95+
96+ return fake_names
97+
98+
99+ def _assert_download_success (folder : Path , model : str ):
100+ """
101+ Shared assertion helper for download_huggingface_model.
102+ """
19103 dlclibrary .download_huggingface_model (model , str (folder ))
20104
21- try : # These are not created for .pt models
22- assert os .path .exists (folder / "pose_cfg.yaml" )
23- assert any (f .startswith ("snapshot-" ) for f in os .listdir (folder ))
24- except AssertionError :
25- assert any (f .endswith (".pth" ) for f in os .listdir (folder ))
105+ files = {p .name for p in folder .iterdir ()}
106+
107+ # Archive-based DLC model
108+ if "pose_cfg.yaml" in files :
109+ assert "pose_cfg.yaml" in files
110+ assert any (name .startswith ("snapshot-" ) for name in files )
111+
112+ # Direct PyTorch model
113+ else :
114+ assert any (name .endswith ((".pt" , ".pth" )) for name in files )
115+
116+ # Verify that the Hugging Face cache folder was removed
117+ assert not any (name .startswith ("models--" ) for name in files )
26118
27- # Verify that the Hugging Face folder was removed
28- assert not any (f .startswith ("models--" ) for f in os .listdir (folder ))
29119
120+ def test_download_huggingface_model_tar_or_pt (tmp_path , mock_modelzoo ):
121+ folder = tmp_path / "download_one"
122+ folder .mkdir ()
30123
31- def test_download_huggingface_wrong_model ():
124+ # "full_cat" may map to tar.gz or .pt depending on ordering;
125+ # this assertion helper supports both branches.
126+ _assert_download_success (folder , "full_cat" )
127+
128+
129+ def test_download_huggingface_wrong_model (mock_modelzoo ):
32130 with pytest .raises (ValueError ):
33131 dlclibrary .download_huggingface_model ("wrong_model_name" )
34132
@@ -40,6 +138,34 @@ def test_parse_superanimal_models():
40138
41139
42140@pytest .mark .parametrize ("model" , MODELOPTIONS )
43- def test_download_all_models (tmp_path_factory , model ):
44- print ("Downloading ..." , model )
45- test_download_huggingface_model (tmp_path_factory , model )
141+ def test_download_all_models (tmp_path , mock_modelzoo , model ):
142+ folder = tmp_path / model
143+ folder .mkdir ()
144+ _assert_download_success (folder , model )
145+
146+
147+ def test_download_with_rename_mapping_for_pt (tmp_path , mock_modelzoo ):
148+ """
149+ Explicitly test rename_mapping for a .pt model.
150+ """
151+ # Pick one of the mocked .pt models
152+ pt_model = None
153+ for i , model in enumerate (MODELOPTIONS ):
154+ if i % 2 == 1 :
155+ pt_model = model
156+ break
157+
158+ assert pt_model is not None , "Expected at least one mocked .pt model"
159+
160+ folder = tmp_path / "rename_pt"
161+ folder .mkdir ()
162+
163+ dlclibrary .download_huggingface_model (
164+ pt_model ,
165+ str (folder ),
166+ rename_mapping = "renamed_weights.pt" ,
167+ )
168+
169+ files = {p .name for p in folder .iterdir ()}
170+ assert "renamed_weights.pt" in files
171+ assert not any (name .startswith ("models--" ) for name in files )
0 commit comments