Skip to content

Commit 1964591

Browse files
authored
Support shared AWS compute caches (#3483)
* Log get_offers times
* Request aws quotas and zones in parallel
* Revert "Request aws quotas and zones in parallel"
  This reverts commit a0f365e.
* Add AWSQuotasSharedCache
* Refactor compute caches
1 parent 65eacc7 commit 1964591

File tree

5 files changed

+84
-49
lines changed

5 files changed

+84
-49
lines changed

src/dstack/_internal/core/backends/aws/backend.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import botocore.exceptions
24

35
from dstack._internal.core.backends.aws.compute import AWSCompute
@@ -11,9 +13,12 @@ class AWSBackend(Backend):
1113
TYPE = BackendType.AWS
1214
COMPUTE_CLASS = AWSCompute
1315

14-
def __init__(self, config: AWSConfig):
16+
def __init__(self, config: AWSConfig, compute: Optional[AWSCompute] = None):
1517
self.config = config
16-
self._compute = AWSCompute(self.config)
18+
if compute is not None:
19+
self._compute = compute
20+
else:
21+
self._compute = AWSCompute(self.config)
1722
self._check_credentials()
1823

1924
def compute(self) -> AWSCompute:

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import threading
22
from collections.abc import Iterable
33
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
from dataclasses import dataclass, field
45
from typing import Any, Callable, Dict, List, Optional, Tuple
56

67
import boto3
@@ -19,6 +20,8 @@
1920
)
2021
from dstack._internal.core.backends.base.compute import (
2122
Compute,
23+
ComputeCache,
24+
ComputeTTLCache,
2225
ComputeWithAllOffersCached,
2326
ComputeWithCreateInstanceSupport,
2427
ComputeWithGatewaySupport,
@@ -94,6 +97,11 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
9497
return hashkey(*args, **kwargs)
9598

9699

100+
@dataclass
class AWSQuotasCache(ComputeTTLCache):
    """
    TTL cache for AWS quotas that additionally carries an execution lock.

    The inherited cache lock does not prevent concurrent execution of the
    cached call, so callers hold `execution_lock` around the quotas request
    to avoid requesting quotas in parallel and hitting rate limits.
    """

    # One lock per cache object: everyone holding the same AWSQuotasCache
    # instance serializes on the same lock.
    execution_lock: threading.Lock = field(default_factory=threading.Lock)
103+
104+
97105
class AWSCompute(
98106
ComputeWithAllOffersCached,
99107
ComputeWithCreateInstanceSupport,
@@ -106,7 +114,12 @@ class AWSCompute(
106114
ComputeWithVolumeSupport,
107115
Compute,
108116
):
109-
def __init__(self, config: AWSConfig):
117+
def __init__(
118+
self,
119+
config: AWSConfig,
120+
quotas_cache: Optional[AWSQuotasCache] = None,
121+
zones_cache: Optional[ComputeCache] = None,
122+
):
110123
super().__init__()
111124
self.config = config
112125
if isinstance(config.creds, AWSAccessKeyCreds):
@@ -119,23 +132,18 @@ def __init__(self, config: AWSConfig):
119132
# Caches to avoid redundant API calls when provisioning many instances
120133
# get_offers is already cached but we still cache its sub-functions
121134
# with more aggressive/longer caches.
122-
self._offers_post_filter_cache_lock = threading.Lock()
123-
self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180)
124-
self._get_regions_to_quotas_cache_lock = threading.Lock()
125-
self._get_regions_to_quotas_execution_lock = threading.Lock()
126-
self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
127-
self._get_regions_to_zones_cache_lock = threading.Lock()
128-
self._get_regions_to_zones_cache = Cache(maxsize=10)
129-
self._get_vpc_id_subnet_id_or_error_cache_lock = threading.Lock()
130-
self._get_vpc_id_subnet_id_or_error_cache = TTLCache(maxsize=100, ttl=600)
131-
self._get_maximum_efa_interfaces_cache_lock = threading.Lock()
132-
self._get_maximum_efa_interfaces_cache = Cache(maxsize=100)
133-
self._get_subnets_availability_zones_cache_lock = threading.Lock()
134-
self._get_subnets_availability_zones_cache = Cache(maxsize=100)
135-
self._create_security_group_cache_lock = threading.Lock()
136-
self._create_security_group_cache = TTLCache(maxsize=100, ttl=600)
137-
self._get_image_id_and_username_cache_lock = threading.Lock()
138-
self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)
135+
self._offers_post_filter_cache = ComputeTTLCache(cache=TTLCache(maxsize=10, ttl=180))
136+
if quotas_cache is None:
137+
quotas_cache = AWSQuotasCache(cache=TTLCache(maxsize=10, ttl=600))
138+
self._regions_to_quotas_cache = quotas_cache
139+
if zones_cache is None:
140+
zones_cache = ComputeCache(cache=Cache(maxsize=10))
141+
self._regions_to_zones_cache = zones_cache
142+
self._vpc_id_subnet_id_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
143+
self._maximum_efa_interfaces_cache = ComputeCache(cache=Cache(maxsize=100))
144+
self._subnets_availability_zones_cache = ComputeCache(cache=Cache(maxsize=100))
145+
self._security_group_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
146+
self._image_id_and_username_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
139147

140148
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
141149
offers = get_catalog_offers(
@@ -144,7 +152,7 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability
144152
extra_filter=_supported_instances,
145153
)
146154
regions = list(set(i.region for i in offers))
147-
with self._get_regions_to_quotas_execution_lock:
155+
with self._regions_to_quotas_cache.execution_lock:
148156
# Cache lock does not prevent concurrent execution.
149157
# We use a separate lock to avoid requesting quotas in parallel and hitting rate limits.
150158
regions_to_quotas = self._get_regions_to_quotas(self.session, regions)
@@ -173,9 +181,9 @@ def _get_offers_cached_key(self, requirements: Requirements) -> int:
173181
return hash(requirements.json())
174182

175183
@cachedmethod(
176-
cache=lambda self: self._offers_post_filter_cache,
184+
cache=lambda self: self._offers_post_filter_cache.cache,
177185
key=_get_offers_cached_key,
178-
lock=lambda self: self._offers_post_filter_cache_lock,
186+
lock=lambda self: self._offers_post_filter_cache.lock,
179187
)
180188
def get_offers_post_filter(
181189
self, requirements: Requirements
@@ -789,9 +797,9 @@ def _get_regions_to_quotas_key(
789797
return hashkey(tuple(regions))
790798

791799
@cachedmethod(
792-
cache=lambda self: self._get_regions_to_quotas_cache,
800+
cache=lambda self: self._regions_to_quotas_cache.cache,
793801
key=_get_regions_to_quotas_key,
794-
lock=lambda self: self._get_regions_to_quotas_cache_lock,
802+
lock=lambda self: self._regions_to_quotas_cache.lock,
795803
)
796804
def _get_regions_to_quotas(
797805
self,
@@ -808,9 +816,9 @@ def _get_regions_to_zones_key(
808816
return hashkey(tuple(regions))
809817

810818
@cachedmethod(
811-
cache=lambda self: self._get_regions_to_zones_cache,
819+
cache=lambda self: self._regions_to_zones_cache.cache,
812820
key=_get_regions_to_zones_key,
813-
lock=lambda self: self._get_regions_to_zones_cache_lock,
821+
lock=lambda self: self._regions_to_zones_cache.lock,
814822
)
815823
def _get_regions_to_zones(
816824
self,
@@ -832,9 +840,9 @@ def _get_vpc_id_subnet_id_or_error_cache_key(
832840
)
833841

834842
@cachedmethod(
835-
cache=lambda self: self._get_vpc_id_subnet_id_or_error_cache,
843+
cache=lambda self: self._vpc_id_subnet_id_cache.cache,
836844
key=_get_vpc_id_subnet_id_or_error_cache_key,
837-
lock=lambda self: self._get_vpc_id_subnet_id_or_error_cache_lock,
845+
lock=lambda self: self._vpc_id_subnet_id_cache.lock,
838846
)
839847
def _get_vpc_id_subnet_id_or_error(
840848
self,
@@ -853,9 +861,9 @@ def _get_vpc_id_subnet_id_or_error(
853861
)
854862

855863
@cachedmethod(
856-
cache=lambda self: self._get_maximum_efa_interfaces_cache,
864+
cache=lambda self: self._maximum_efa_interfaces_cache.cache,
857865
key=_ec2client_cache_methodkey,
858-
lock=lambda self: self._get_maximum_efa_interfaces_cache_lock,
866+
lock=lambda self: self._maximum_efa_interfaces_cache.lock,
859867
)
860868
def _get_maximum_efa_interfaces(
861869
self,
@@ -877,9 +885,9 @@ def _get_subnets_availability_zones_key(
877885
return hashkey(region, tuple(subnet_ids))
878886

879887
@cachedmethod(
880-
cache=lambda self: self._get_subnets_availability_zones_cache,
888+
cache=lambda self: self._subnets_availability_zones_cache.cache,
881889
key=_get_subnets_availability_zones_key,
882-
lock=lambda self: self._get_subnets_availability_zones_cache_lock,
890+
lock=lambda self: self._subnets_availability_zones_cache.lock,
883891
)
884892
def _get_subnets_availability_zones(
885893
self,
@@ -893,9 +901,9 @@ def _get_subnets_availability_zones(
893901
)
894902

895903
@cachedmethod(
896-
cache=lambda self: self._create_security_group_cache,
904+
cache=lambda self: self._security_group_cache.cache,
897905
key=_ec2client_cache_methodkey,
898-
lock=lambda self: self._create_security_group_cache_lock,
906+
lock=lambda self: self._security_group_cache.lock,
899907
)
900908
def _create_security_group(
901909
self,
@@ -923,9 +931,9 @@ def _get_image_id_and_username_cache_key(
923931
)
924932

925933
@cachedmethod(
926-
cache=lambda self: self._get_image_id_and_username_cache,
934+
cache=lambda self: self._image_id_and_username_cache.cache,
927935
key=_get_image_id_and_username_cache_key,
928-
lock=lambda self: self._get_image_id_and_username_cache_lock,
936+
lock=lambda self: self._image_id_and_username_cache.lock,
929937
)
930938
def _get_image_id_and_username(
931939
self,

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import threading
77
from abc import ABC, abstractmethod
88
from collections.abc import Iterable, Iterator
9+
from dataclasses import dataclass, field
910
from enum import Enum
1011
from functools import lru_cache
1112
from pathlib import Path
@@ -14,7 +15,7 @@
1415
import git
1516
import requests
1617
import yaml
17-
from cachetools import TTLCache, cachedmethod
18+
from cachetools import Cache, TTLCache, cachedmethod
1819
from gpuhunt import CPUArchitecture
1920

2021
from dstack._internal import settings
@@ -89,6 +90,18 @@ def to_cpu_architecture(self) -> CPUArchitecture:
8990
assert False, self
9091

9192

93+
@dataclass
class ComputeCache:
    """
    A `cachetools.Cache` bundled with the lock used when accessing it.

    Intended to be handed to `cachedmethod` as `cache=obj.cache` and
    `lock=obj.lock`, so the cache and its lock always travel together.
    """

    cache: Cache
    # default_factory gives every instance its own lock object.
    lock: threading.Lock = field(default_factory=threading.Lock)
97+
98+
99+
@dataclass
class ComputeTTLCache:
    """
    A `cachetools.TTLCache` bundled with the lock used when accessing it.

    Same shape as `ComputeCache`, but the wrapped cache expires its entries
    after the TTL configured on the `TTLCache` itself.
    """

    cache: TTLCache
    # default_factory gives every instance its own lock object.
    lock: threading.Lock = field(default_factory=threading.Lock)
103+
104+
92105
class Compute(ABC):
93106
"""
94107
A base class for all compute implementations with minimal features.

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import concurrent.futures
22
import json
33
import re
4-
import threading
54
from collections import defaultdict
65
from collections.abc import Iterable
76
from dataclasses import dataclass
@@ -19,6 +18,7 @@
1918
from dstack import version
2019
from dstack._internal.core.backends.base.compute import (
2120
Compute,
21+
ComputeTTLCache,
2222
ComputeWithAllOffersCached,
2323
ComputeWithCreateInstanceSupport,
2424
ComputeWithGatewaySupport,
@@ -127,11 +127,9 @@ def __init__(self, config: GCPConfig):
127127
credentials=self.credentials
128128
)
129129
self.reservations_client = compute_v1.ReservationsClient(credentials=self.credentials)
130-
self._usable_subnets_cache_lock = threading.Lock()
131-
self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120)
132-
self._find_reservation_cache_lock = threading.Lock()
133-
# smaller TTL, since we check the reservation's in_use_count, which can change often
134-
self._find_reservation_cache = TTLCache(maxsize=8, ttl=20)
130+
self._usable_subnets_cache = ComputeTTLCache(cache=TTLCache(maxsize=1, ttl=120))
131+
# Smaller TTL since we check the reservation's in_use_count, which can change often
132+
self._reservation_cache = ComputeTTLCache(cache=TTLCache(maxsize=8, ttl=20))
135133

136134
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
137135
regions = get_or_error(self.config.regions)
@@ -948,8 +946,8 @@ def _get_roce_subnets(
948946
return nic_subnets
949947

950948
@cachedmethod(
951-
cache=lambda self: self._usable_subnets_cache,
952-
lock=lambda self: self._usable_subnets_cache_lock,
949+
cache=lambda self: self._usable_subnets_cache.cache,
950+
lock=lambda self: self._usable_subnets_cache.lock,
953951
)
954952
def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]:
955953
# To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets
@@ -969,8 +967,8 @@ def _get_vpc_subnet(self, region: str) -> Optional[str]:
969967
)
970968

971969
@cachedmethod(
972-
cache=lambda self: self._find_reservation_cache,
973-
lock=lambda self: self._find_reservation_cache_lock,
970+
cache=lambda self: self._reservation_cache.cache,
971+
lock=lambda self: self._reservation_cache.lock,
974972
)
975973
def _find_reservation(self, configured_name: str) -> dict[str, compute_v1.Reservation]:
976974
if match := RESERVATION_PATTERN.fullmatch(configured_name):

src/dstack/_internal/server/services/backends/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import heapq
3+
import time
34
from collections.abc import Iterable, Iterator
45
from typing import Callable, Coroutine, Dict, List, Optional, Tuple
56
from uuid import UUID
@@ -361,7 +362,7 @@ def get_filtered_offers_with_backends(
361362
yield (backend, offer)
362363

363364
logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
364-
tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends]
365+
tasks = [run_async(get_offers_tracked, backend, requirements) for backend in backends]
365366
offers_by_backend = []
366367
for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)):
367368
if isinstance(result, BackendError):
@@ -391,3 +392,13 @@ def check_backend_type_available(backend_type: BackendType):
391392
" Ensure that backend dependencies are installed."
392393
f" Available backends: {[b.value for b in list_available_backend_types()]}."
393394
)
395+
396+
397+
def get_offers_tracked(
    backend: Backend, requirements: Requirements
) -> Iterator[InstanceOfferWithAvailability]:
    """
    Request offers from the backend's compute and log how long the call took.

    Args:
        backend: The backend whose compute is queried.
        requirements: Requirements passed through to `get_offers` unchanged.

    Returns:
        The result of `backend.compute().get_offers(requirements)`, unmodified.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) if the system clock is adjusted during the call.
    start = time.perf_counter()
    res = backend.compute().get_offers(requirements)
    duration = time.perf_counter() - start
    # NOTE(review): if get_offers returns a lazy iterator, this covers only the
    # call itself, not the consumption of the offers -- confirm.
    logger.debug("Got offers from %s in %.6fs", backend.TYPE.value, duration)
    return res

0 commit comments

Comments
 (0)