Skip to content

Commit 580610b

Browse files
committed
reduce unnecessary diff
1 parent 2141a21 commit 580610b

File tree

1 file changed

+60
-183
lines changed

1 file changed

+60
-183
lines changed

sklbench/datasets/downloaders.py

Lines changed: 60 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -16,225 +16,102 @@
1616

1717
import os
1818
import time
19-
import warnings
2019
from typing import Callable, List, Union
2120

2221
import numpy as np
22+
import openml
2323
import pandas as pd
2424
import requests
2525
from scipy.sparse import csr_matrix
26-
from sklearn.datasets import fetch_openml
2726

2827

29-
def retrieve(url: str, filename: str, max_retries: int = 3) -> None:
    """Download a file from a URL with basic retry logic.

    Args:
        url: HTTP(S) URL to download from.
        filename: Local file path to save the downloaded content to.
        max_retries: Number of download attempts before giving up.

    Raises:
        ValueError: If ``url`` is not an HTTP(S) URL.
        AssertionError: If the server responds with a non-200 status, or the
            download fails or is incomplete after ``max_retries`` attempts.
    """
    # An already-existing file is assumed to be a complete earlier download.
    if os.path.isfile(filename):
        return

    if not url.startswith("http"):
        raise ValueError(f"URL must start with http:// or https://, got: {url}")

    for attempt in range(max_retries):
        try:
            # Use the response as a context manager: with stream=True the
            # underlying connection is only released when the response is
            # closed, so this avoids leaking pooled connections on errors.
            with requests.get(url, stream=True, timeout=120) as response:
                if response.status_code != 200:
                    raise AssertionError(
                        f"Failed to download from {url}. "
                        f"Response returned status code {response.status_code}"
                    )

                # content-length may be absent (e.g. chunked encoding) -> 0.
                total_size = int(response.headers.get("content-length", 0))
                block_size = 8192

                with open(filename, "wb") as datafile:
                    bytes_written = 0
                    for data in response.iter_content(block_size):
                        if data:  # filter out keep-alive chunks
                            datafile.write(data)
                            bytes_written += len(data)

            # Verify download completeness if size is known
            if total_size > 0 and bytes_written != total_size:
                os.remove(filename)
                if attempt < max_retries - 1:
                    time.sleep(1)
                    continue
                raise AssertionError(
                    f"Incomplete download from {url}. "
                    f"Expected {total_size} bytes, got {bytes_written}"
                )
            return

        except (
            requests.exceptions.RequestException,
            IOError,
        ) as e:
            # Drop any partially written file before retrying or failing.
            if os.path.isfile(filename):
                os.remove(filename)
            if attempt < max_retries - 1:
                time.sleep(1)
                continue
            raise AssertionError(
                f"Failed to download {url} after {max_retries} attempts: {e}"
            ) from e
16579

16680

16781
def fetch_and_correct_openml(
    data_id: int, raw_data_cache_dir: str, as_frame: str = "auto"
):
    """Fetch OpenML dataset using the openml package."""
    # Point the openml package's cache at a subdirectory of our raw cache.
    cache_dir = os.path.join(raw_data_cache_dir, "openml")
    os.makedirs(cache_dir, exist_ok=True)
    openml.config.set_root_cache_directory(cache_dir)

    dataset = openml.datasets.get_dataset(
        data_id,
        download_data=True,
        download_qualities=False,
        download_features_meta_data=False,
    )

    # "auto" (or any truthy value) requests dataframes; otherwise arrays.
    wants_frame = as_frame == "auto" or as_frame
    x, y, _, _ = dataset.get_data(
        dataset_format="dataframe" if wants_frame else "array",
        target=dataset.default_target_attribute,
    )

    # Reject feature containers we don't know how to handle downstream.
    accepted_x_types = (csr_matrix, pd.DataFrame, np.ndarray)
    if not isinstance(x, accepted_x_types):
        raise ValueError(f'Unknown x type "{type(x)}" returned from openml')

    # Normalize labels to a numpy array; categorical labels become codes.
    if isinstance(y, pd.Series):
        if isinstance(y.dtype, pd.CategoricalDtype):
            y = y.cat.codes
        y = y.values
    elif not isinstance(y, np.ndarray):
        raise ValueError(f'Unknown y type "{type(y)}" returned from openml')

    return x, y
240117

0 commit comments

Comments
 (0)