Skip to content

Commit 23602bc

Browse files
authored
feat: resume interrupted transfers (#9)
* feat: add download resuming * fix: clean up exception handling * fix: test
1 parent a7a76ad commit 23602bc

4 files changed

Lines changed: 125 additions & 23 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ raise *"Multiple suitable storage providers found"* for any HTTP(S) URL.
1717
mtime read from `Content-Length` and `Last-Modified` response headers. Servers that do not
1818
support `HEAD` requests are handled gracefully (size and mtime default to 0). No checksum
1919
is available for generic URLs.
20+
- Resumable downloads: interrupted transfers are continued from where they left off using
21+
HTTP `Range` requests (`206 Partial Content`). Servers that do not support range requests
22+
fall back to a full re-download. Partial files are preserved across retries for
23+
connection/timeout errors, and discarded on checksum mismatches or other errors.
2024

2125
### Removed
2226

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ A Snakemake storage plugin for downloading files via HTTP with local caching, ch
1919
- **Checksum verification**: Automatically verifies checksums (from Zenodo API, data.pypsa.org manifests, or GCS object metadata)
2020
- **Rate limit handling**: Automatically respects Zenodo's rate limits using `X-RateLimit-*` headers with exponential backoff retry
2121
- **Concurrent download control**: Limits simultaneous downloads to prevent overwhelming servers
22+
- **Resumable downloads**: Interrupted transfers resume from where they left off using HTTP range requests
2223
- **Progress bars**: Shows download progress with tqdm
2324
- **Immutable URLs**: Returns mtime=0 for Zenodo and data.pypsa.org (persistent URLs); uses actual mtime for GCS and generic HTTP
2425
- **Environment variable support**: Configure via environment variables for CI/CD workflows
@@ -154,6 +155,7 @@ The plugin automatically:
154155
- Uses `X-RateLimit-Reset` to calculate wait time
155156
- Retries failed requests with exponential backoff (up to 5 attempts)
156157
- Handles transient errors: HTTP errors, timeouts, checksum mismatches, and network issues
158+
- Resumes interrupted downloads using `Range` requests where supported by the server
157159

158160
## URL Handling
159161

src/snakemake_storage_plugin_cached_http/__init__.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,8 @@ def __init__(self, observed: str, expected: str):
118118
retry_decorator = retry(
119119
exceptions=( # pyright: ignore[reportArgumentType]
120120
httpx.HTTPError,
121-
TimeoutError,
122121
OSError,
123122
WrongChecksum,
124-
WorkflowError,
125123
),
126124
tries=5,
127125
delay=3,
@@ -264,19 +262,23 @@ def _get_rate_limit_wait_time(self, headers: httpx.Headers) -> float | None:
264262
return wait_seconds
265263

266264
@asynccontextmanager
267-
async def httpr(self, method: str, url: str):
265+
async def httpr(self, method: str, url: str, headers: dict[str, str] | None = None):
268266
"""
269267
HTTP request wrapper with rate limiting and exception logging.
270268
271269
Args:
272270
method: HTTP method (e.g., "get", "post")
273271
url: URL to request
272+
headers: Optional additional HTTP headers
274273
275274
Yields:
276275
httpx.Response object
277276
"""
278277
try:
279-
async with self.client() as client, client.stream(method, url) as response:
278+
async with (
279+
self.client() as client,
280+
client.stream(method, url, headers=headers) as response,
281+
):
280282
wait_time = self._get_rate_limit_wait_time(response.headers)
281283
if wait_time is not None:
282284
logger.info(
@@ -340,10 +342,7 @@ async def get_http_metadata(self, parsed: ParseResult) -> FileMetadata | None:
340342
if response.status_code == 405:
341343
# HEAD not supported; assume file exists with unknown size/mtime
342344
return FileMetadata(checksum=None, size=0, mtime=0.0)
343-
if response.status_code != 200:
344-
raise WorkflowError(
345-
f"Failed to fetch HTTP metadata: HTTP {response.status_code} ({url})"
346-
)
345+
response.raise_for_status()
347346

348347
size = int(response.headers.get("content-length", 0))
349348

@@ -391,10 +390,7 @@ async def get_zenodo_metadata(self, url: ParseResult) -> FileMetadata | None:
391390
api_url = f"https://{netloc}/api/records/{record_id}"
392391

393392
async with self.httpr("get", api_url) as response:
394-
if response.status_code != 200:
395-
raise WorkflowError(
396-
f"Failed to fetch Zenodo record metadata: HTTP {response.status_code} ({api_url})"
397-
)
393+
response.raise_for_status()
398394

399395
# Read the full response body
400396
content = await response.aread()
@@ -516,10 +512,7 @@ async def get_gcs_metadata(self, url: ParseResult) -> FileMetadata | None:
516512
async with self.httpr("get", api_url) as response:
517513
if response.status_code == 404:
518514
return None
519-
if response.status_code != 200:
520-
raise WorkflowError(
521-
f"Failed to fetch GCS object metadata: HTTP {response.status_code} ({api_url})"
522-
)
515+
response.raise_for_status()
523516

524517
content = await response.aread()
525518
data = json.loads(content)
@@ -728,20 +721,34 @@ async def managed_retrieve(self):
728721
return
729722

730723
try:
724+
# Check for existing partial file to resume
725+
offset = local_path.stat().st_size if local_path.exists() else 0
726+
headers = {"Range": f"bytes={offset}-"} if offset > 0 else None
727+
731728
# Download using a get request, rate limit errors are detected and raise
732729
# WorkflowError to trigger a retry
733-
async with self.provider.httpr("get", query) as response:
734-
if response.status_code != 200:
735-
raise WorkflowError(
736-
f"Failed to download: HTTP {response.status_code} ({query})"
730+
async with self.provider.httpr("get", query, headers=headers) as response:
731+
if response.status_code == 206:
732+
# Server supports resume - append to existing partial file
733+
mode = "ab"
734+
logger.info(f"Resuming {filename} from byte {offset}")
735+
elif response.status_code == 200:
736+
# Server doesn't support Range - discard partial and restart
737+
mode = "wb"
738+
offset = 0
739+
else:
740+
response.raise_for_status()
741+
raise AssertionError(
742+
f"Unhandled status code: {response.status_code}"
737743
)
738744

739-
total_size = int(response.headers.get("content-length", 0))
745+
total_size = int(response.headers.get("content-length", 0)) + offset
740746

741747
# Download to local path with progress bar
742-
with local_path.open(mode="wb") as f:
748+
with local_path.open(mode=mode) as f:
743749
with tqdm(
744750
total=total_size,
751+
initial=offset,
745752
unit="B",
746753
unit_scale=True,
747754
desc=filename,
@@ -758,7 +765,11 @@ async def managed_retrieve(self):
758765
if self.provider.cache:
759766
self.provider.cache.put(query, local_path)
760767

761-
except:
768+
except httpx.TransportError:
769+
# Mid-transfer interruption - keep partial file for resume on next retry
770+
raise
771+
except: # noqa: E722
772+
# Any other error (wrong checksum, HTTP error, KeyboardInterrupt) - delete and maybe restart
762773
if local_path.exists():
763774
local_path.unlink()
764775
raise

tests/test_download.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
import logging
99
import os
1010
import time
11+
from contextlib import asynccontextmanager
12+
from unittest.mock import MagicMock
1113
from urllib.parse import urlparse
1214

15+
import httpx
1316
import pytest
1417

1518
from snakemake_storage_plugin_cached_http import (
19+
FileMetadata,
1620
StorageObject,
1721
StorageProvider,
1822
StorageProviderSettings,
@@ -339,3 +343,84 @@ async def test_cache_staleness_for_mutable_sources(
339343
downloaded_content = cached_path.read_bytes()
340344
assert b'"stale": true' not in downloaded_content
341345
assert downloaded_content == original_content
346+
347+
348+
def make_mock_httpr(content: bytes, fail_at: int | None):
349+
"""
350+
Factory for a mock httpr context manager simulating a range-capable HTTP server.
351+
352+
Args:
353+
content: The full file content to serve.
354+
fail_at: If set, drop the connection after serving this many bytes on the
355+
first (non-Range) request, simulating a mid-transfer interruption.
356+
If None, serve the full content without interruption.
357+
"""
358+
359+
@asynccontextmanager
360+
async def mock_httpr(method, request_url, headers=None):
361+
range_header = headers.get("Range") if headers else None
362+
mock_httpr.received_range_headers.append(range_header)
363+
364+
response = MagicMock()
365+
if range_header is None:
366+
response.status_code = 200
367+
chunk = content
368+
drop_at = fail_at
369+
else:
370+
response.status_code = 206
371+
offset = int(range_header.removeprefix("bytes=").removesuffix("-"))
372+
chunk = content[offset:]
373+
drop_at = None
374+
375+
async def aiter_bytes(chunk_size=8192):
376+
if drop_at is None:
377+
yield chunk
378+
else:
379+
yield chunk[:drop_at]
380+
raise httpx.ReadError("peer closed connection")
381+
382+
response.aiter_bytes = aiter_bytes
383+
response.headers = {"content-length": str(len(chunk))}
384+
385+
yield response
386+
387+
mock_httpr.received_range_headers = []
388+
389+
return mock_httpr
390+
391+
392+
@pytest.mark.asyncio
393+
async def test_resume_on_partial_file(storage_provider, tmp_path):
394+
"""Test that downloads resume from partial files using HTTP Range requests."""
395+
url = TEST_CONFIGS["zenodo"]["url"]
396+
full_content = b'{"port": "test", "lat": 1.0, "lon": 2.0}'
397+
398+
fail_at = 10
399+
mock_httpr = make_mock_httpr(full_content, fail_at=fail_at)
400+
storage_provider.httpr = mock_httpr
401+
storage_provider.cache = None
402+
403+
obj = StorageObject(
404+
query=url,
405+
keep_local=False,
406+
retrieve=True,
407+
provider=storage_provider,
408+
)
409+
410+
local_path = tmp_path / "resume_test" / "file.json"
411+
local_path.parent.mkdir(parents=True, exist_ok=True)
412+
obj.local_path = lambda: local_path
413+
414+
async def mock_get_metadata(url):
415+
return FileMetadata(checksum=None, size=0, mtime=0)
416+
417+
obj.provider.get_metadata = mock_get_metadata
418+
419+
await obj.managed_retrieve()
420+
421+
assert len(mock_httpr.received_range_headers) == 2
422+
assert mock_httpr.received_range_headers[0] is None # first attempt: no Range header
423+
assert mock_httpr.received_range_headers[1] == f"bytes={fail_at}-"
424+
assert local_path.read_bytes() == full_content
425+
426+

0 commit comments

Comments
 (0)