1212
1313import boto3
1414import requests
15- from botocore .exceptions import ClientError , EndpointConnectionError
15+ from botocore .exceptions import ClientError , ConnectTimeoutError , EndpointConnectionError
1616
1717from gpuhunt ._internal .models import QueryFilter , RawCatalogItem
1818from gpuhunt .providers import AbstractProvider
5353 "MarketOption" : ["OnDemand" ],
5454}
# Maximum number of instance types sent in one describe_instance_types request.
describe_instances_limit = 100

# Pricing file download settings: number of attempts and streamed chunk size (1 MiB).
pricing_download_retries = 3
pricing_download_chunk_size = 1024**2

# AWS disruption workaround: a connect timeout against one of these regions
# is logged and the region is skipped instead of failing catalog collection.
TEMPORARILY_UNAVAILABLE_REGIONS = {"me-south-1"}

# Location codes for which an AuthFailure or an endpoint connection error is
# taken to mean "this AWS account is not enabled here"; such locations are
# skipped and catalog collection continues.
# Note: "me-south-1" is also listed in TEMPORARILY_UNAVAILABLE_REGIONS.
ACCOUNT_NOT_ENABLED_REGIONS = {
    "af-south-1",
    "ap-east-1",
    "ap-east-2",
    "ap-south-2",
    "ap-southeast-3",
    "ap-southeast-4",
    "ap-southeast-5",
    "ap-southeast-6",
    "ap-southeast-7",
    "ca-west-1",
    "eu-central-2",
    "eu-south-1",
    "eu-south-2",
    "il-central-1",
    "me-central-1",
    "me-south-1",
    "mx-central-1",
    "us-gov-east-1",
    "us-gov-west-1",
    "us-west-2-phx-1",
}
5687
5788
5889class AWSProvider (AbstractProvider ):
@@ -71,6 +102,7 @@ def __init__(self, cache_path: Optional[str] = None):
71102 else :
72103 self .temp_dir = tempfile .TemporaryDirectory ()
73104 self .cache_path = self .temp_dir .name + "/index.csv"
105+ self .ec2_api_regions = _get_ec2_api_regions ()
74106 # todo aws creds
75107 self .preview_gpus = {
76108 "p4de.24xlarge" : ("A100" , 80.0 ),
@@ -80,12 +112,7 @@ def get(
80112 self , query_filter : Optional [QueryFilter ] = None , balance_resources : bool = True
81113 ) -> list [RawCatalogItem ]:
82114 if not os .path .exists (self .cache_path ):
83- logger .info ("Downloading EC2 prices to %s" , self .cache_path )
84- with requests .get (ec2_pricing_url , stream = True , timeout = 20 ) as r :
85- r .raise_for_status ()
86- with open (self .cache_path , "wb" ) as f :
87- for chunk in r .iter_content (chunk_size = 8192 ):
88- f .write (chunk )
115+ self ._download_pricing_file ()
89116
90117 offers = []
91118 with open (self .cache_path , newline = "" ) as f :
@@ -126,30 +153,73 @@ def skip(self, row: dict[str, str]) -> bool:
126153
127154 def fill_gpu_details (self , offers : list [RawCatalogItem ]):
128155 regions = defaultdict (list )
156+ non_ec2_api_regions = set ()
129157 for offer in offers :
130158 if offer .gpu_count > 0 and offer .instance_name not in self .preview_gpus :
159+ if offer .location not in self .ec2_api_regions :
160+ non_ec2_api_regions .add (offer .location )
161+ continue
131162 regions [offer .location ].append (offer .instance_name )
163+ if non_ec2_api_regions :
164+ logger .info (
165+ "Skipping non-EC2 location codes for GPU details: %s" ,
166+ ", " .join (sorted (non_ec2_api_regions )),
167+ )
132168
133169 gpus = copy .deepcopy (self .preview_gpus )
134170 while regions :
135171 region = max (regions , key = lambda r : len (regions [r ]))
136172 instance_types = regions .pop (region )
137173
138- client = boto3 .client ("ec2" , region_name = region )
139- paginator = client .get_paginator ("describe_instance_types" )
140- for offset in range (0 , len (instance_types ), describe_instances_limit ):
141- logger .info ("Fetching GPU details for %s (offset=%s)" , region , offset )
142- pages = paginator .paginate (
143- InstanceTypes = instance_types [offset : offset + describe_instances_limit ]
144- )
145- for page in pages :
146- for i in page ["InstanceTypes" ]:
147- if "GpuInfo" in i :
148- gpu = i ["GpuInfo" ]["Gpus" ][0 ]
149- gpus [i ["InstanceType" ]] = (
150- gpu ["Name" ],
151- _get_gpu_memory_gib (gpu ["Name" ], gpu ["MemoryInfo" ]["SizeInMiB" ]),
152- )
174+ try :
175+ client = boto3 .client ("ec2" , region_name = region )
176+ paginator = client .get_paginator ("describe_instance_types" )
177+ for offset in range (0 , len (instance_types ), describe_instances_limit ):
178+ logger .info ("Fetching GPU details for %s (offset=%s)" , region , offset )
179+ pages = paginator .paginate (
180+ InstanceTypes = instance_types [offset : offset + describe_instances_limit ]
181+ )
182+ for page in pages :
183+ for i in page ["InstanceTypes" ]:
184+ if "GpuInfo" in i :
185+ gpu = i ["GpuInfo" ]["Gpus" ][0 ]
186+ gpus [i ["InstanceType" ]] = (
187+ gpu ["Name" ],
188+ _get_gpu_memory_gib (
189+ gpu ["Name" ], gpu ["MemoryInfo" ]["SizeInMiB" ]
190+ ),
191+ )
192+ except ConnectTimeoutError as e :
193+ if region in TEMPORARILY_UNAVAILABLE_REGIONS :
194+ logger .warning (
195+ "Skipping AWS region %s for GPU details due to temporary AWS regional disruption "
196+ "(connect timeout): %s" ,
197+ region ,
198+ e ,
199+ )
200+ continue
201+ raise RuntimeError (f"Failed AWS GPU details fetch in region { region } : { e } " ) from e
202+ except ClientError as e :
203+ code = e .response .get ("Error" , {}).get ("Code" )
204+ if code == "AuthFailure" and region in ACCOUNT_NOT_ENABLED_REGIONS :
205+ logger .warning (
206+ "Skipping AWS region %s for GPU details because account is not enabled "
207+ "in this region (AuthFailure): %s" ,
208+ region ,
209+ e ,
210+ )
211+ continue
212+ raise RuntimeError (f"Failed AWS GPU details fetch in region { region } : { e } " ) from e
213+ except EndpointConnectionError as e :
214+ if region in ACCOUNT_NOT_ENABLED_REGIONS :
215+ logger .warning (
216+ "Skipping AWS region %s for GPU details because account is not enabled "
217+ "in this region (EndpointConnectionError): %s" ,
218+ region ,
219+ e ,
220+ )
221+ continue
222+ raise RuntimeError (f"Failed AWS GPU details fetch in region { region } : { e } " ) from e
153223
154224 regions = {
155225 region : left
@@ -195,14 +265,80 @@ def _add_spots_worker(
195265 zone_prices ,
196266 ) in instance_prices .items (): # reduce zone prices to a single value
197267 spot_prices [(instance_type , region )] = min (zone_prices )
198- except (ClientError , EndpointConnectionError ):
199- return {}
268+ except ConnectTimeoutError as e :
269+ if region in TEMPORARILY_UNAVAILABLE_REGIONS :
270+ logger .warning (
271+ "Skipping AWS region %s for spot prices due to temporary AWS regional disruption "
272+ "(connect timeout): %s" ,
273+ region ,
274+ e ,
275+ )
276+ return {}
277+ raise RuntimeError (f"Failed AWS spot price fetch in region { region } : { e } " ) from e
278+ except ClientError as e :
279+ code = e .response .get ("Error" , {}).get ("Code" )
280+ if code == "AuthFailure" and region in ACCOUNT_NOT_ENABLED_REGIONS :
281+ logger .warning (
282+ "Skipping AWS region %s for spot prices because account is not enabled "
283+ "in this region (AuthFailure): %s" ,
284+ region ,
285+ e ,
286+ )
287+ return {}
288+ raise RuntimeError (f"Failed AWS spot price fetch in region { region } : { e } " ) from e
289+ except EndpointConnectionError as e :
290+ if region in ACCOUNT_NOT_ENABLED_REGIONS :
291+ logger .warning (
292+ "Skipping AWS region %s for spot prices because account is not enabled "
293+ "in this region (EndpointConnectionError): %s" ,
294+ region ,
295+ e ,
296+ )
297+ return {}
298+ raise RuntimeError (f"Failed AWS spot price fetch in region { region } : { e } " ) from e
200299 return spot_prices
201300
def _download_pricing_file(self) -> None:
    """Download the EC2 pricing file to ``self.cache_path``, retrying on failure.

    The response body is streamed into a sibling ``<cache_path>.part`` file,
    which is moved over ``self.cache_path`` with ``os.replace`` only after the
    whole body has been written. A failed attempt therefore never leaves a
    partial file at the cache path.

    Raises:
        RuntimeError: if every one of the ``pricing_download_retries`` attempts
            fails; the last underlying error is chained as the cause.
    """
    logger.info("Downloading EC2 prices to %s", self.cache_path)
    temp_cache_path = f"{self.cache_path}.part"
    for attempt in range(1, pricing_download_retries + 1):
        try:
            with requests.get(ec2_pricing_url, stream=True, timeout=20) as r:
                r.raise_for_status()
                with open(temp_cache_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=pricing_download_chunk_size):
                        if chunk:
                            f.write(chunk)
            # Rename only after the file handle and the response are closed.
            os.replace(temp_cache_path, self.cache_path)
            return
        except (requests.RequestException, OSError) as e:
            # Best-effort cleanup: a failure to delete the partial file must
            # neither mask the download error nor abort the remaining retries.
            try:
                os.remove(temp_cache_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(
                    "Failed to remove partial AWS pricing file %s: %s",
                    temp_cache_path,
                    cleanup_error,
                )
            if attempt == pricing_download_retries:
                raise RuntimeError(
                    f"Failed to download AWS pricing file after {pricing_download_retries} "
                    f"attempts: {e}"
                ) from e
            logger.warning(
                "Failed to download AWS pricing file (attempt %s/%s), retrying: %s",
                attempt,
                pricing_download_retries,
                e,
            )
328+
202329 def add_spots (self , offers : list [RawCatalogItem ]) -> list [RawCatalogItem ]:
203330 region_instances = defaultdict (set )
331+ non_ec2_api_regions = set ()
204332 for offer in offers :
333+ if offer .location not in self .ec2_api_regions :
334+ non_ec2_api_regions .add (offer .location )
335+ continue
205336 region_instances [offer .location ].add (offer .instance_name )
337+ if non_ec2_api_regions :
338+ logger .info (
339+ "Skipping non-EC2 location codes for spot prices: %s" ,
340+ ", " .join (sorted (non_ec2_api_regions )),
341+ )
206342
207343 spot_prices = dict ()
208344 with ThreadPoolExecutor (max_workers = 8 ) as executor :
@@ -282,3 +418,12 @@ def _parse_gpu_count(s: str) -> Optional[int]:
282418 # AWS fractional GPUs not supported
283419 return None
284420 return int (count )
421+
422+
def _get_ec2_api_regions() -> set[str]:
    """Return the EC2 region names boto3 reports across all available partitions."""
    session = boto3.session.Session()
    ec2_regions: set[str] = set()
    for partition_name in session.get_available_partitions():
        ec2_regions.update(
            session.get_available_regions("ec2", partition_name=partition_name)
        )
    return ec2_regions
0 commit comments