1515# ===============================================================================
1616
1717import os
18+ import time
19+ import warnings
1820from typing import Callable , List , Union
1921
2022import numpy as np
2426from sklearn .datasets import fetch_openml
2527
2628
def _parse_size(value) -> int:
    """Return *value* as a non-negative int, or 0 when it is missing or malformed."""
    try:
        size = int(value)
    except (TypeError, ValueError):
        return 0
    return max(size, 0)


def _total_from_content_range(header) -> int:
    """Extract the total size from a ``Content-Range`` value, or 0 when unknown."""
    # The header looks like "bytes <start>-<end>/<total>" or "bytes */<total>";
    # <total> may legitimately be "*" when the server does not know the size.
    _, separator, total = (header or "").rpartition("/")
    return _parse_size(total) if separator else 0


def _existing_file_is_complete(url: str, filename: str) -> bool:
    """Compare the size of *filename* with the size the server advertises for *url*.

    Returns True when the sizes match or when the size cannot be verified
    (request failure, rejected HEAD request, no usable Content-Length);
    returns False only when the server reports a different size.
    """
    actual_size = os.path.getsize(filename)
    try:
        head_response = requests.head(
            url,
            allow_redirects=True,
            timeout=30,
            # Ask for the raw representation so Content-Length is comparable
            # with the number of bytes stored on disk.
            headers={"Accept-Encoding": "identity"},
        )
    except requests.exceptions.RequestException as e:
        # If we can't verify, assume the file is complete
        warnings.warn(
            f"Could not verify file completeness for {filename}: {e}. Assuming complete.",
            RuntimeWarning,
        )
        return True

    if not head_response.ok:
        # Some servers reject HEAD; the Content-Length of an error page says
        # nothing about the real resource, so do not treat the file as broken.
        warnings.warn(
            f"Could not verify file completeness for {filename}: "
            f"HEAD request returned status code {head_response.status_code}. "
            f"Assuming complete.",
            RuntimeWarning,
        )
        return True

    expected_size = _parse_size(head_response.headers.get("content-length"))
    if expected_size == 0 or actual_size == expected_size:
        # Either the size matches or the server did not advertise one, in
        # which case there is nothing to compare against.
        return True

    warnings.warn(
        f"Existing file {filename} is incomplete ({actual_size}/{expected_size} bytes). "
        f"Will download it again.",
        RuntimeWarning,
    )
    return False


def retrieve(url: str, filename: str, max_retries: int = 5) -> None:
    """
    Download a file from a URL with retry logic and resume capability.

    Data is streamed into ``filename + ".partial"`` and moved onto ``filename``
    only once the download has been verified, so ``filename`` never holds a
    truncated download. A leftover ``.partial`` file is resumed with an HTTP
    Range request when the server supports it.

    Args:
        url: URL to download from
        filename: Local file path to save to
        max_retries: Maximum number of download attempts (must be >= 1)

    Raises:
        ValueError: If a download is required and ``url`` is not an HTTP(S) URL
            or ``max_retries`` is smaller than 1.
        AssertionError: If the server answers with an unexpected status code or
            the file could not be downloaded completely within ``max_retries``
            attempts.
    """
    is_http = url.startswith("http")

    if os.path.isfile(filename):
        # A non-HTTP source offers nothing to verify against: keep the file.
        if not is_http or _existing_file_is_complete(url, filename):
            return

    if not is_http:
        raise ValueError(f"URL must start with http:// or https://, got: {url}")
    if max_retries < 1:
        # range(max_retries) would be empty and the function would silently
        # return without having downloaded anything.
        raise ValueError(f"max_retries must be at least 1, got: {max_retries}")

    temp_filename = filename + ".partial"
    block_size = 8192

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            # Check if we can resume a partial download
            resume_pos = 0
            if os.path.isfile(temp_filename):
                resume_pos = os.path.getsize(temp_filename)

            # Identity encoding keeps Content-Length and Range offsets in the
            # same units as the bytes written to disk.
            headers = {"Accept-Encoding": "identity"}
            if resume_pos > 0:
                headers["Range"] = f"bytes={resume_pos}-"
                warnings.warn(
                    f"Resuming download of {url} from byte {resume_pos}",
                    RuntimeWarning,
                )

            # The context manager releases the streamed connection on every path.
            with requests.get(
                url, stream=True, headers=headers, timeout=60
            ) as response:
                status = response.status_code
                if status == 200:
                    # Full download (the server ignored any Range header)
                    mode = "wb"
                    total_size = _parse_size(response.headers.get("content-length"))
                elif status == 206 and resume_pos > 0:
                    # Partial content (resume successful)
                    mode = "ab"
                    total_size = _total_from_content_range(
                        response.headers.get("content-range")
                    )
                elif status == 416 and resume_pos > 0:
                    # Range not satisfiable: the partial file is at least as
                    # large as the resource.
                    total_size = _total_from_content_range(
                        response.headers.get("content-range")
                    )
                    if total_size in (0, resume_pos):
                        # Size matches (or is unknown): treat it as complete.
                        os.replace(temp_filename, filename)
                        return
                    # The partial file does not match the resource; drop it so
                    # the next attempt starts from scratch.
                    os.remove(temp_filename)
                    warnings.warn(
                        f"Discarding stale partial download {temp_filename} "
                        f"({resume_pos}/{total_size} bytes). "
                        f"Attempt {attempt + 1}/{max_retries}",
                        RuntimeWarning,
                    )
                    if is_last_attempt:
                        raise AssertionError(
                            f"Failed to completely download {url} after {max_retries} attempts. "
                            f"Got {resume_pos}/{total_size} bytes"
                        )
                    continue
                else:
                    raise AssertionError(
                        f"Failed to download from {url}. "
                        f"Response returned status code {status}"
                    )

                # Download the file
                with open(temp_filename, mode) as datafile:
                    for data in response.iter_content(block_size):
                        if data:  # filter out keep-alive chunks
                            datafile.write(data)

            # Verify download completeness
            actual_size = os.path.getsize(temp_filename)
            if total_size > 0 and actual_size != total_size:
                warnings.warn(
                    f"Download incomplete: {actual_size}/{total_size} bytes. "
                    f"Attempt {attempt + 1}/{max_retries}",
                    RuntimeWarning,
                )
                if actual_size > total_size:
                    # More data than the resource holds cannot be resumed.
                    os.remove(temp_filename)
                if is_last_attempt:
                    raise AssertionError(
                        f"Failed to completely download {url} after {max_retries} attempts. "
                        f"Got {actual_size}/{total_size} bytes"
                    )
                continue  # Retry

            # Download successful; os.replace also overwrites an existing
            # (incomplete) destination on every platform.
            os.replace(temp_filename, filename)
            return

        except (
            requests.exceptions.ChunkedEncodingError,
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
        ) as e:
            warnings.warn(
                f"Download interrupted for {url}: {type(e).__name__}: {e}. "
                f"Attempt {attempt + 1}/{max_retries}",
                RuntimeWarning,
            )
            if is_last_attempt:
                # Clean up partial file if all retries failed
                if os.path.isfile(temp_filename):
                    os.remove(temp_filename)
                raise AssertionError(
                    f"Failed to download {url} after {max_retries} attempts. "
                    f"Last error: {type(e).__name__}: {e}"
                ) from e
            wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, 8s, 16s
            warnings.warn(f"Waiting {wait_time}s before retry...", RuntimeWarning)
            time.sleep(wait_time)
46165
47166
48167def fetch_and_correct_openml (
49168 data_id : int , raw_data_cache_dir : str , as_frame : str = "auto"
50169):
51- x , y = fetch_openml (
52- data_id = data_id , return_X_y = True , as_frame = as_frame , data_home = raw_data_cache_dir
53- )
170+ """
171+ Fetch OpenML dataset with fallback for MD5 checksum errors.
172+
173+ First tries sklearn's fetch_openml. If that fails due to MD5 checksum mismatch,
174+ falls back to using the openml package directly, which has updated checksums.
175+ """
176+ try :
177+ # Try sklearn's fetch_openml first
178+ x , y = fetch_openml (
179+ data_id = data_id , return_X_y = True , as_frame = as_frame , data_home = raw_data_cache_dir
180+ )
181+ except ValueError as e :
182+ # Check if it's an MD5 checksum error
183+ if "md5 checksum" in str (e ).lower ():
184+ warnings .warn (
185+ f"MD5 checksum validation failed for OpenML dataset { data_id } . "
186+ f"Falling back to using openml package directly. "
187+ f"Original error: { e } " ,
188+ RuntimeWarning
189+ )
190+
191+ # Fall back to openml package which might have updated checksums
192+ try :
193+ import openml
194+ # Configure openml to use the provided cache directory
195+ openml_cache = os .path .join (raw_data_cache_dir , "openml_direct" )
196+ os .makedirs (openml_cache , exist_ok = True )
197+ openml .config .set_root_cache_directory (openml_cache )
198+
199+ dataset = openml .datasets .get_dataset (
200+ data_id ,
201+ download_data = True ,
202+ download_qualities = False ,
203+ download_features_meta_data = False
204+ )
205+ #Get the data with target column specified
206+ x , y , _ , _ = dataset .get_data (
207+ dataset_format = "dataframe" if as_frame == "auto" or as_frame else "array" ,
208+ target = dataset .default_target_attribute
209+ )
210+ except Exception as openml_error :
211+ raise ValueError (
212+ f"Failed to load OpenML dataset { data_id } using both sklearn and openml package. "
213+ f"sklearn error: { e } . openml error: { openml_error } "
214+ ) from openml_error
215+ else :
216+ # Not a checksum error, re-raise
217+ raise
218+
219+ # Validate and convert return types
54220 if (
55221 isinstance (x , csr_matrix )
56222 or isinstance (x , pd .DataFrame )
@@ -59,6 +225,7 @@ def fetch_and_correct_openml(
59225 pass
60226 else :
61227 raise ValueError (f'Unknown "{ type (x )} " x type was returned from fetch_openml' )
228+
62229 if isinstance (y , pd .Series ):
63230 # label transforms to cat.codes if it is passed as categorical series
64231 if isinstance (y .dtype , pd .CategoricalDtype ):
@@ -68,6 +235,7 @@ def fetch_and_correct_openml(
68235 pass
69236 else :
70237 raise ValueError (f'Unknown "{ type (y )} " y type was returned from fetch_openml' )
238+
71239 return x , y
72240
73241
0 commit comments