Skip to content

Commit b414616

Browse files
authored
fix: retry transient errors and clean up partial files in model downloads (#422)
1 parent 803212a commit b414616

2 files changed

Lines changed: 477 additions & 22 deletions

File tree

comfy_cli/file_utils.py

Lines changed: 111 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import pathlib
44
import subprocess
5+
import time
56
import zipfile
67

78
import httpx
@@ -162,6 +163,83 @@ def _download_file_aria2(url: str, local_filepath: pathlib.Path, headers: dict |
162163

163164
_VALID_DOWNLOADERS = {"httpx", "aria2"}
164165

166+
_DOWNLOAD_MAX_RETRIES = 3
167+
_DOWNLOAD_RETRY_BACKOFF = 2 # seconds multiplier
168+
_DOWNLOAD_TIMEOUT = httpx.Timeout(10.0, read=300.0)
169+
_TRANSIENT_EXCEPTIONS = (
170+
httpx.TimeoutException,
171+
httpx.NetworkError,
172+
httpx.ProtocolError,
173+
httpx.ProxyError,
174+
)
175+
176+
177+
def _cleanup_partial(filepath: pathlib.Path) -> None:
178+
"""Remove a partially downloaded file if it exists."""
179+
try:
180+
filepath.unlink(missing_ok=True)
181+
except OSError:
182+
pass
183+
184+
185+
def _friendly_network_error(exc: Exception) -> str:
186+
"""Return a user-friendly description of a network error."""
187+
if isinstance(exc, httpx.ReadTimeout):
188+
return "the server stopped sending data (read timeout)"
189+
if isinstance(exc, httpx.ConnectTimeout):
190+
return "could not connect to the server (connect timeout)"
191+
if isinstance(exc, httpx.TimeoutException):
192+
return f"the operation timed out ({type(exc).__name__})"
193+
if isinstance(exc, httpx.NetworkError):
194+
return f"a network error occurred ({type(exc).__name__}: {exc})"
195+
if isinstance(exc, httpx.ProtocolError):
196+
return f"a protocol error occurred ({type(exc).__name__}: {exc})"
197+
if isinstance(exc, httpx.ProxyError):
198+
return f"a proxy error occurred ({type(exc).__name__}: {exc})"
199+
return str(exc)
200+
201+
202+
def _download_file_httpx(
203+
url: str,
204+
local_filepath: pathlib.Path,
205+
headers: dict | None = None,
206+
*,
207+
state: dict | None = None,
208+
) -> None:
209+
"""Download a file using httpx streaming. Raises on HTTP or network errors.
210+
211+
If ``state`` is provided, ``state["file_opened"]`` is set to True immediately
212+
after the output file is opened for writing. Callers use this to distinguish
213+
failures raised *before* the destination was touched (HTTP errors, ConnectError,
214+
etc.) from failures raised *after* writing started (mid-stream ReadTimeout),
215+
so they can avoid deleting an unrelated pre-existing file at the destination.
216+
"""
217+
with httpx.stream("GET", url, follow_redirects=True, headers=headers, timeout=_DOWNLOAD_TIMEOUT) as response:
218+
if response.status_code != 200:
219+
try:
220+
error_body = response.read()
221+
except _TRANSIENT_EXCEPTIONS:
222+
error_body = ""
223+
status_reason = guess_status_code_reason(response.status_code, error_body)
224+
raise DownloadException(f"Failed to download file.\n{status_reason}")
225+
226+
content_length = response.headers.get("Content-Length")
227+
total = int(content_length) if content_length is not None else None
228+
if total is not None:
229+
description = f"Downloading {total // 1024 // 1024} MB"
230+
else:
231+
description = "Downloading..."
232+
233+
with open(local_filepath, "wb") as f:
234+
if state is not None:
235+
state["file_opened"] = True
236+
for data in ui.show_progress(
237+
response.iter_bytes(),
238+
total,
239+
description=description,
240+
):
241+
f.write(data)
242+
165243

166244
def download_file(url: str, local_filepath: pathlib.Path, headers: dict | None = None, downloader: str = "httpx"):
167245
"""Helper function to download a file."""
@@ -170,34 +248,45 @@ def download_file(url: str, local_filepath: pathlib.Path, headers: dict | None =
170248
f"Unknown downloader: {downloader!r}. Valid options: {', '.join(sorted(_VALID_DOWNLOADERS))}"
171249
)
172250

173-
local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists
251+
local_filepath.parent.mkdir(parents=True, exist_ok=True)
174252

175253
if downloader == "aria2":
176254
return _download_file_aria2(url, local_filepath, headers)
177255

178-
with httpx.stream("GET", url, follow_redirects=True, headers=headers) as response:
179-
if response.status_code == 200:
180-
content_length = response.headers.get("Content-Length")
181-
total = int(content_length) if content_length is not None else None
182-
if total is not None:
183-
description = f"Downloading {total // 1024 // 1024} MB"
184-
else:
185-
description = "Downloading..."
186-
try:
187-
with open(local_filepath, "wb") as f:
188-
for data in ui.show_progress(
189-
response.iter_bytes(),
190-
total,
191-
description=description,
192-
):
193-
f.write(data)
194-
except KeyboardInterrupt:
256+
last_exc: Exception | None = None
257+
state: dict = {"file_opened": False}
258+
259+
for attempt in range(_DOWNLOAD_MAX_RETRIES):
260+
state["file_opened"] = False
261+
try:
262+
_download_file_httpx(url, local_filepath, headers, state=state)
263+
return
264+
except _TRANSIENT_EXCEPTIONS as exc:
265+
last_exc = exc
266+
# Only clean up if _download_file_httpx actually opened the destination —
267+
# otherwise we'd delete an unrelated pre-existing file at the same path.
268+
if state["file_opened"]:
269+
_cleanup_partial(local_filepath)
270+
if attempt < _DOWNLOAD_MAX_RETRIES - 1:
271+
wait = _DOWNLOAD_RETRY_BACKOFF * (attempt + 1)
272+
print(f"Download error (attempt {attempt + 1}/{_DOWNLOAD_MAX_RETRIES}): {_friendly_network_error(exc)}")
273+
print(f"Retrying in {wait}s...")
274+
time.sleep(wait)
275+
except KeyboardInterrupt:
276+
# Only prompt/cleanup if we actually opened the destination this attempt.
277+
# If the interrupt arrived during connection setup, there is no partial
278+
# file and the destination may hold an unrelated pre-existing file.
279+
if state["file_opened"]:
195280
delete_eh = ui.prompt_confirm_action("Download interrupted, cleanup files?", True)
196281
if delete_eh:
197-
local_filepath.unlink()
198-
else:
199-
status_reason = guess_status_code_reason(response.status_code, response.read())
200-
raise DownloadException(f"Failed to download file.\n{status_reason}")
282+
_cleanup_partial(local_filepath)
283+
raise
284+
285+
raise DownloadException(
286+
f"Download failed after {_DOWNLOAD_MAX_RETRIES} attempts: "
287+
f"{_friendly_network_error(last_exc)}\n"
288+
f"Please try again later."
289+
) from last_exc
201290

202291

203292
def _load_comfyignore_spec(ignore_filename: str = ".comfyignore") -> PathSpec | None:

0 commit comments

Comments
 (0)