Skip to content

Commit 9919240

Browse files
committed
Make tar extraction more robust and permissive
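Open archives with a permissive compression mode (`r:*`) and extract only regular files. Skip members whose basename is empty or for which `extractfile()` returns `None`, manage the file streams with a context manager, and raise `tarfile.ReadError` when no regular files were extracted, with the intent of failing loudly. Note that this raise sits inside the same `try` block whose `except tarfile.ReadError` handler implements the fallback, so an archive containing no regular files is not surfaced as an error; it is copied as a direct model file instead. The previous fallback is preserved: if the file is not an archive, it is copied as a direct model file. The example usage is also updated to drop the explicit `remove_hf_folder` argument.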
1 parent 673ff9a commit 9919240

1 file changed

Lines changed: 27 additions & 7 deletions

File tree

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,34 @@ def _handle_downloaded_file(
121121
file_name = os.path.basename(file_path)
122122

123123
try:
124-
with tarfile.open(file_path, mode="r:gz") as tar:
125-
for member in tar:
124+
# Be permissive about compression type
125+
with tarfile.open(file_path, mode="r:*") as tar:
126+
extracted_any = False
127+
for member in tar.getmembers():
128+
# Only extract regular files
126129
if not member.isfile():
127-
fname = Path(member.name).name
128-
extracted_path = os.path.join(target_dir, fname)
129-
with tar.extractfile(member) as src, open(extracted_path, "wb") as dst:
130-
shutil.copyfileobj(src, dst)
130+
continue
131+
132+
fname = Path(member.name).name
133+
if not fname:
134+
continue
135+
136+
src = tar.extractfile(member)
137+
if src is None:
138+
continue
139+
140+
extracted_path = os.path.join(target_dir, fname)
141+
with src, open(extracted_path, "wb") as dst:
142+
shutil.copyfileobj(src, dst)
143+
144+
extracted_any = True
145+
146+
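# If it opened as a tar but contained nothing useful, raise ReadError (note: the `except tarfile.ReadError` below catches this, so the file is then copied as a direct model file rather than failing loudly)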
147+
if not extracted_any:
148+
raise tarfile.ReadError(f"No regular files extracted from archive: {file_path}")
149+
131150
except tarfile.ReadError:
151+
# Not an archive -> treat as a direct model file (.pt/.pth/etc.)
132152
if rename_mapping is not None:
133153
file_name = rename_mapping.get(file_name, file_name)
134154
shutil.copy2(file_path, os.path.join(target_dir, file_name))
@@ -158,7 +178,7 @@ def download_huggingface_model(
158178
159179
Examples:
160180
>>> # Download without renaming, keep original filename
161-
download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False)
181+
download_huggingface_model("superanimal_bird_resnet_50")
162182
163183
>>> # Download and rename by specifying the new name directly
164184
download_huggingface_model(

0 commit comments

Comments
 (0)