Skip to content

Commit 6342f1d

Browse files
committed
Initial attempt at dataset download issue resolutions
1 parent 399d9eb commit 6342f1d

File tree

1 file changed: +188 additions, -20 deletions (lines changed)

sklbench/datasets/downloaders.py

Lines changed: 188 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
# ===============================================================================
1616

1717
import os
18+
import time
19+
import warnings
1820
from typing import Callable, List, Union
1921

2022
import numpy as np
@@ -24,33 +26,197 @@
2426
from sklearn.datasets import fetch_openml
2527

2628

27-
def retrieve(url: str, filename: str) -> None:
29+
def retrieve(url: str, filename: str, max_retries: int = 5) -> None:
30+
"""
31+
Download a file from a URL with retry logic and resume capability.
32+
33+
Args:
34+
url: URL to download from
35+
filename: Local file path to save to
36+
max_retries: Maximum number of retry attempts for failed downloads
37+
"""
2838
if os.path.isfile(filename):
29-
return
30-
elif url.startswith("http"):
31-
response = requests.get(url, stream=True)
32-
if response.status_code != 200:
33-
raise AssertionError(
34-
f"Failed to download from {url}.\n"
35-
f"Response returned status code {response.status_code}"
39+
# Check if file is complete by comparing size
40+
try:
41+
head_response = requests.head(url, allow_redirects=True, timeout=30)
42+
expected_size = int(head_response.headers.get("content-length", 0))
43+
actual_size = os.path.getsize(filename)
44+
45+
if expected_size > 0 and actual_size == expected_size:
46+
# File exists and is complete
47+
return
48+
else:
49+
warnings.warn(
50+
f"Existing file {filename} is incomplete ({actual_size}/{expected_size} bytes). "
51+
f"Will attempt to resume download.",
52+
RuntimeWarning
53+
)
54+
except Exception as e:
55+
# If we can't verify, assume file is complete
56+
warnings.warn(
57+
f"Could not verify file completeness for {filename}: {e}. Assuming complete.",
58+
RuntimeWarning
59+
)
60+
return
61+
62+
if not url.startswith("http"):
63+
raise ValueError(f"URL must start with http:// or https://, got: {url}")
64+
65+
temp_filename = filename + ".partial"
66+
block_size = 8192
67+
68+
for attempt in range(max_retries):
69+
try:
70+
# Check if we can resume a partial download
71+
resume_pos = 0
72+
if os.path.isfile(temp_filename):
73+
resume_pos = os.path.getsize(temp_filename)
74+
headers = {"Range": f"bytes={resume_pos}-"}
75+
mode = "ab" # Append mode
76+
warnings.warn(
77+
f"Resuming download of {url} from byte {resume_pos}",
78+
RuntimeWarning
79+
)
80+
else:
81+
headers = {}
82+
mode = "wb"
83+
84+
response = requests.get(url, stream=True, headers=headers, timeout=60)
85+
86+
# Handle different response codes
87+
if response.status_code == 200:
88+
# Full download
89+
mode = "wb"
90+
resume_pos = 0
91+
elif response.status_code == 206:
92+
# Partial content (resume successful)
93+
pass
94+
elif response.status_code == 416:
95+
# Range not satisfiable - file might be complete
96+
if os.path.isfile(temp_filename):
97+
os.rename(temp_filename, filename)
98+
return
99+
else:
100+
raise AssertionError(
101+
f"Failed to download from {url}. "
102+
f"Response returned status code {response.status_code}"
103+
)
104+
105+
# Get expected total size
106+
if response.status_code == 206:
107+
content_range = response.headers.get("content-range", "")
108+
if content_range:
109+
total_size = int(content_range.split("/")[1])
110+
else:
111+
total_size = 0
112+
else:
113+
total_size = int(response.headers.get("content-length", 0))
114+
115+
# Download the file
116+
bytes_downloaded = resume_pos
117+
with open(temp_filename, mode) as datafile:
118+
for data in response.iter_content(block_size):
119+
if data: # filter out keep-alive chunks
120+
datafile.write(data)
121+
bytes_downloaded += len(data)
122+
123+
# Verify download completeness
124+
if total_size > 0:
125+
actual_size = os.path.getsize(temp_filename)
126+
if actual_size != total_size:
127+
warnings.warn(
128+
f"Download incomplete: {actual_size}/{total_size} bytes. "
129+
f"Attempt {attempt + 1}/{max_retries}",
130+
RuntimeWarning
131+
)
132+
if attempt < max_retries - 1:
133+
continue # Retry
134+
else:
135+
raise AssertionError(
136+
f"Failed to completely download {url} after {max_retries} attempts. "
137+
f"Got {actual_size}/{total_size} bytes"
138+
)
139+
140+
# Download successful, rename temp file to final filename
141+
os.rename(temp_filename, filename)
142+
return
143+
144+
except (requests.exceptions.ChunkedEncodingError,
145+
requests.exceptions.ConnectionError,
146+
requests.exceptions.Timeout) as e:
147+
warnings.warn(
148+
f"Download interrupted for {url}: {type(e).__name__}: {e}. "
149+
f"Attempt {attempt + 1}/{max_retries}",
150+
RuntimeWarning
36151
)
37-
total_size = int(response.headers.get("content-length", 0))
38-
block_size = 8192
39-
n = 0
40-
with open(filename, "wb+") as datafile:
41-
for data in response.iter_content(block_size):
42-
n += len(data) / 1024
43-
datafile.write(data)
44-
if total_size != 0 and n != total_size / 1024:
45-
raise AssertionError("Some content was present but not downloaded/written")
152+
if attempt < max_retries - 1:
153+
wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s, 8s, 16s
154+
warnings.warn(f"Waiting {wait_time}s before retry...", RuntimeWarning)
155+
time.sleep(wait_time)
156+
continue
157+
else:
158+
# Clean up partial file if all retries failed
159+
if os.path.isfile(temp_filename):
160+
os.remove(temp_filename)
161+
raise AssertionError(
162+
f"Failed to download {url} after {max_retries} attempts. "
163+
f"Last error: {type(e).__name__}: {e}"
164+
) from e
46165

47166

48167
def fetch_and_correct_openml(
49168
data_id: int, raw_data_cache_dir: str, as_frame: str = "auto"
50169
):
51-
x, y = fetch_openml(
52-
data_id=data_id, return_X_y=True, as_frame=as_frame, data_home=raw_data_cache_dir
53-
)
170+
"""
171+
Fetch OpenML dataset with fallback for MD5 checksum errors.
172+
173+
First tries sklearn's fetch_openml. If that fails due to MD5 checksum mismatch,
174+
falls back to using the openml package directly, which has updated checksums.
175+
"""
176+
try:
177+
# Try sklearn's fetch_openml first
178+
x, y = fetch_openml(
179+
data_id=data_id, return_X_y=True, as_frame=as_frame, data_home=raw_data_cache_dir
180+
)
181+
except ValueError as e:
182+
# Check if it's an MD5 checksum error
183+
if "md5 checksum" in str(e).lower():
184+
warnings.warn(
185+
f"MD5 checksum validation failed for OpenML dataset {data_id}. "
186+
f"Falling back to using openml package directly. "
187+
f"Original error: {e}",
188+
RuntimeWarning
189+
)
190+
191+
# Fall back to openml package which might have updated checksums
192+
try:
193+
import openml
194+
# Configure openml to use the provided cache directory
195+
openml_cache = os.path.join(raw_data_cache_dir, "openml_direct")
196+
os.makedirs(openml_cache, exist_ok=True)
197+
openml.config.set_root_cache_directory(openml_cache)
198+
199+
dataset = openml.datasets.get_dataset(
200+
data_id,
201+
download_data=True,
202+
download_qualities=False,
203+
download_features_meta_data=False
204+
)
205+
#Get the data with target column specified
206+
x, y, _, _ = dataset.get_data(
207+
dataset_format="dataframe" if as_frame == "auto" or as_frame else "array",
208+
target=dataset.default_target_attribute
209+
)
210+
except Exception as openml_error:
211+
raise ValueError(
212+
f"Failed to load OpenML dataset {data_id} using both sklearn and openml package. "
213+
f"sklearn error: {e}. openml error: {openml_error}"
214+
) from openml_error
215+
else:
216+
# Not a checksum error, re-raise
217+
raise
218+
219+
# Validate and convert return types
54220
if (
55221
isinstance(x, csr_matrix)
56222
or isinstance(x, pd.DataFrame)
@@ -59,6 +225,7 @@ def fetch_and_correct_openml(
59225
pass
60226
else:
61227
raise ValueError(f'Unknown "{type(x)}" x type was returned from fetch_openml')
228+
62229
if isinstance(y, pd.Series):
63230
# label transforms to cat.codes if it is passed as categorical series
64231
if isinstance(y.dtype, pd.CategoricalDtype):
@@ -68,6 +235,7 @@ def fetch_and_correct_openml(
68235
pass
69236
else:
70237
raise ValueError(f'Unknown "{type(y)}" y type was returned from fetch_openml')
238+
71239
return x, y
72240

73241

0 commit comments

Comments
 (0)