Skip to content

Commit 1c14ee5

Browse files
peterschmidt85Andrey Cheptsov
andauthored
Handle AWS outages with allowlist skip-on-failure (#226)
* Handle AWS region outages with allowlist skip-on-failure * Allow skip-on-failure for ap-southeast-5 * Handle AWS region failures and non-EC2 locations --------- Co-authored-by: Andrey Cheptsov <andrey.cheptsov@github.com>
1 parent 1804bb0 commit 1c14ee5

1 file changed

Lines changed: 169 additions & 24 deletions

File tree

src/gpuhunt/providers/aws.py

Lines changed: 169 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import boto3
1414
import requests
15-
from botocore.exceptions import ClientError, EndpointConnectionError
15+
from botocore.exceptions import ClientError, ConnectTimeoutError, EndpointConnectionError
1616

1717
from gpuhunt._internal.models import QueryFilter, RawCatalogItem
1818
from gpuhunt.providers import AbstractProvider
@@ -53,6 +53,37 @@
5353
"MarketOption": ["OnDemand"],
5454
}
5555
describe_instances_limit = 100
56+
pricing_download_retries = 3
57+
pricing_download_chunk_size = 1024 * 1024
58+
# AWS disruption workaround: if a request to one of these regions times out,
59+
# skip that region and continue collecting the catalog.
60+
TEMPORARILY_UNAVAILABLE_REGIONS = {
61+
"me-south-1",
62+
}
63+
# If this AWS account is not enabled in one of these regions,
64+
# skip that region and continue collecting the catalog.
65+
ACCOUNT_NOT_ENABLED_REGIONS = {
66+
"ap-southeast-5",
67+
"us-gov-west-1",
68+
"eu-south-1",
69+
"eu-south-2",
70+
"ap-southeast-3",
71+
"us-west-2-phx-1",
72+
"me-central-1",
73+
"il-central-1",
74+
"ap-southeast-4",
75+
"mx-central-1",
76+
"af-south-1",
77+
"ap-east-2",
78+
"us-gov-east-1",
79+
"ap-east-1",
80+
"ap-south-2",
81+
"ap-southeast-6",
82+
"eu-central-2",
83+
"ap-southeast-7",
84+
"ca-west-1",
85+
"me-south-1",
86+
}
5687

5788

5889
class AWSProvider(AbstractProvider):
@@ -71,6 +102,7 @@ def __init__(self, cache_path: Optional[str] = None):
71102
else:
72103
self.temp_dir = tempfile.TemporaryDirectory()
73104
self.cache_path = self.temp_dir.name + "/index.csv"
105+
self.ec2_api_regions = _get_ec2_api_regions()
74106
# todo aws creds
75107
self.preview_gpus = {
76108
"p4de.24xlarge": ("A100", 80.0),
@@ -80,12 +112,7 @@ def get(
80112
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
81113
) -> list[RawCatalogItem]:
82114
if not os.path.exists(self.cache_path):
83-
logger.info("Downloading EC2 prices to %s", self.cache_path)
84-
with requests.get(ec2_pricing_url, stream=True, timeout=20) as r:
85-
r.raise_for_status()
86-
with open(self.cache_path, "wb") as f:
87-
for chunk in r.iter_content(chunk_size=8192):
88-
f.write(chunk)
115+
self._download_pricing_file()
89116

90117
offers = []
91118
with open(self.cache_path, newline="") as f:
@@ -126,30 +153,73 @@ def skip(self, row: dict[str, str]) -> bool:
126153

127154
def fill_gpu_details(self, offers: list[RawCatalogItem]):
128155
regions = defaultdict(list)
156+
non_ec2_api_regions = set()
129157
for offer in offers:
130158
if offer.gpu_count > 0 and offer.instance_name not in self.preview_gpus:
159+
if offer.location not in self.ec2_api_regions:
160+
non_ec2_api_regions.add(offer.location)
161+
continue
131162
regions[offer.location].append(offer.instance_name)
163+
if non_ec2_api_regions:
164+
logger.info(
165+
"Skipping non-EC2 location codes for GPU details: %s",
166+
", ".join(sorted(non_ec2_api_regions)),
167+
)
132168

133169
gpus = copy.deepcopy(self.preview_gpus)
134170
while regions:
135171
region = max(regions, key=lambda r: len(regions[r]))
136172
instance_types = regions.pop(region)
137173

138-
client = boto3.client("ec2", region_name=region)
139-
paginator = client.get_paginator("describe_instance_types")
140-
for offset in range(0, len(instance_types), describe_instances_limit):
141-
logger.info("Fetching GPU details for %s (offset=%s)", region, offset)
142-
pages = paginator.paginate(
143-
InstanceTypes=instance_types[offset : offset + describe_instances_limit]
144-
)
145-
for page in pages:
146-
for i in page["InstanceTypes"]:
147-
if "GpuInfo" in i:
148-
gpu = i["GpuInfo"]["Gpus"][0]
149-
gpus[i["InstanceType"]] = (
150-
gpu["Name"],
151-
_get_gpu_memory_gib(gpu["Name"], gpu["MemoryInfo"]["SizeInMiB"]),
152-
)
174+
try:
175+
client = boto3.client("ec2", region_name=region)
176+
paginator = client.get_paginator("describe_instance_types")
177+
for offset in range(0, len(instance_types), describe_instances_limit):
178+
logger.info("Fetching GPU details for %s (offset=%s)", region, offset)
179+
pages = paginator.paginate(
180+
InstanceTypes=instance_types[offset : offset + describe_instances_limit]
181+
)
182+
for page in pages:
183+
for i in page["InstanceTypes"]:
184+
if "GpuInfo" in i:
185+
gpu = i["GpuInfo"]["Gpus"][0]
186+
gpus[i["InstanceType"]] = (
187+
gpu["Name"],
188+
_get_gpu_memory_gib(
189+
gpu["Name"], gpu["MemoryInfo"]["SizeInMiB"]
190+
),
191+
)
192+
except ConnectTimeoutError as e:
193+
if region in TEMPORARILY_UNAVAILABLE_REGIONS:
194+
logger.warning(
195+
"Skipping AWS region %s for GPU details due to temporary AWS regional disruption "
196+
"(connect timeout): %s",
197+
region,
198+
e,
199+
)
200+
continue
201+
raise RuntimeError(f"Failed AWS GPU details fetch in region {region}: {e}") from e
202+
except ClientError as e:
203+
code = e.response.get("Error", {}).get("Code")
204+
if code == "AuthFailure" and region in ACCOUNT_NOT_ENABLED_REGIONS:
205+
logger.warning(
206+
"Skipping AWS region %s for GPU details because account is not enabled "
207+
"in this region (AuthFailure): %s",
208+
region,
209+
e,
210+
)
211+
continue
212+
raise RuntimeError(f"Failed AWS GPU details fetch in region {region}: {e}") from e
213+
except EndpointConnectionError as e:
214+
if region in ACCOUNT_NOT_ENABLED_REGIONS:
215+
logger.warning(
216+
"Skipping AWS region %s for GPU details because account is not enabled "
217+
"in this region (EndpointConnectionError): %s",
218+
region,
219+
e,
220+
)
221+
continue
222+
raise RuntimeError(f"Failed AWS GPU details fetch in region {region}: {e}") from e
153223

154224
regions = {
155225
region: left
@@ -195,14 +265,80 @@ def _add_spots_worker(
195265
zone_prices,
196266
) in instance_prices.items(): # reduce zone prices to a single value
197267
spot_prices[(instance_type, region)] = min(zone_prices)
198-
except (ClientError, EndpointConnectionError):
199-
return {}
268+
except ConnectTimeoutError as e:
269+
if region in TEMPORARILY_UNAVAILABLE_REGIONS:
270+
logger.warning(
271+
"Skipping AWS region %s for spot prices due to temporary AWS regional disruption "
272+
"(connect timeout): %s",
273+
region,
274+
e,
275+
)
276+
return {}
277+
raise RuntimeError(f"Failed AWS spot price fetch in region {region}: {e}") from e
278+
except ClientError as e:
279+
code = e.response.get("Error", {}).get("Code")
280+
if code == "AuthFailure" and region in ACCOUNT_NOT_ENABLED_REGIONS:
281+
logger.warning(
282+
"Skipping AWS region %s for spot prices because account is not enabled "
283+
"in this region (AuthFailure): %s",
284+
region,
285+
e,
286+
)
287+
return {}
288+
raise RuntimeError(f"Failed AWS spot price fetch in region {region}: {e}") from e
289+
except EndpointConnectionError as e:
290+
if region in ACCOUNT_NOT_ENABLED_REGIONS:
291+
logger.warning(
292+
"Skipping AWS region %s for spot prices because account is not enabled "
293+
"in this region (EndpointConnectionError): %s",
294+
region,
295+
e,
296+
)
297+
return {}
298+
raise RuntimeError(f"Failed AWS spot price fetch in region {region}: {e}") from e
200299
return spot_prices
201300

301+
def _download_pricing_file(self) -> None:
302+
logger.info("Downloading EC2 prices to %s", self.cache_path)
303+
temp_cache_path = f"{self.cache_path}.part"
304+
for attempt in range(1, pricing_download_retries + 1):
305+
try:
306+
with requests.get(ec2_pricing_url, stream=True, timeout=20) as r:
307+
r.raise_for_status()
308+
with open(temp_cache_path, "wb") as f:
309+
for chunk in r.iter_content(chunk_size=pricing_download_chunk_size):
310+
if chunk:
311+
f.write(chunk)
312+
os.replace(temp_cache_path, self.cache_path)
313+
return
314+
except (requests.RequestException, OSError) as e:
315+
if os.path.exists(temp_cache_path):
316+
os.remove(temp_cache_path)
317+
if attempt == pricing_download_retries:
318+
raise RuntimeError(
319+
f"Failed to download AWS pricing file after {pricing_download_retries} "
320+
f"attempts: {e}"
321+
) from e
322+
logger.warning(
323+
"Failed to download AWS pricing file (attempt %s/%s), retrying: %s",
324+
attempt,
325+
pricing_download_retries,
326+
e,
327+
)
328+
202329
def add_spots(self, offers: list[RawCatalogItem]) -> list[RawCatalogItem]:
203330
region_instances = defaultdict(set)
331+
non_ec2_api_regions = set()
204332
for offer in offers:
333+
if offer.location not in self.ec2_api_regions:
334+
non_ec2_api_regions.add(offer.location)
335+
continue
205336
region_instances[offer.location].add(offer.instance_name)
337+
if non_ec2_api_regions:
338+
logger.info(
339+
"Skipping non-EC2 location codes for spot prices: %s",
340+
", ".join(sorted(non_ec2_api_regions)),
341+
)
206342

207343
spot_prices = dict()
208344
with ThreadPoolExecutor(max_workers=8) as executor:
@@ -282,3 +418,12 @@ def _parse_gpu_count(s: str) -> Optional[int]:
282418
# AWS fractional GPUs not supported
283419
return None
284420
return int(count)
421+
422+
423+
def _get_ec2_api_regions() -> set[str]:
424+
session = boto3.session.Session()
425+
return {
426+
region
427+
for partition in session.get_available_partitions()
428+
for region in session.get_available_regions("ec2", partition_name=partition)
429+
}

0 commit comments

Comments
 (0)