11import threading
22from collections .abc import Iterable
33from concurrent .futures import ThreadPoolExecutor , as_completed
4+ from dataclasses import dataclass , field
45from typing import Any , Callable , Dict , List , Optional , Tuple
56
67import boto3
1920)
2021from dstack ._internal .core .backends .base .compute import (
2122 Compute ,
23+ ComputeCache ,
24+ ComputeTTLCache ,
2225 ComputeWithAllOffersCached ,
2326 ComputeWithCreateInstanceSupport ,
2427 ComputeWithGatewaySupport ,
@@ -94,6 +97,11 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
9497 return hashkey (* args , ** kwargs )
9598
9699
@dataclass
class AWSQuotasCache(ComputeTTLCache):
    """Cache holder for AWS region quotas with an extra execution lock.

    Extends ``ComputeTTLCache`` with one additional field. ``AWSCompute``
    uses the inherited ``cache`` and ``lock`` attributes for memoizing
    ``_get_regions_to_quotas`` and holds ``execution_lock`` around the
    call itself.
    """

    # The cache lock does not prevent concurrent execution of the cached
    # method, so callers hold this separate lock to avoid requesting quotas
    # in parallel and hitting rate limits.
    # A fresh lock is created per instance via ``default_factory``.
    execution_lock: threading.Lock = field(default_factory=threading.Lock)
103+
104+
97105class AWSCompute (
98106 ComputeWithAllOffersCached ,
99107 ComputeWithCreateInstanceSupport ,
@@ -106,7 +114,12 @@ class AWSCompute(
106114 ComputeWithVolumeSupport ,
107115 Compute ,
108116):
109- def __init__ (self , config : AWSConfig ):
117+ def __init__ (
118+ self ,
119+ config : AWSConfig ,
120+ quotas_cache : Optional [AWSQuotasCache ] = None ,
121+ zones_cache : Optional [ComputeCache ] = None ,
122+ ):
110123 super ().__init__ ()
111124 self .config = config
112125 if isinstance (config .creds , AWSAccessKeyCreds ):
@@ -119,23 +132,18 @@ def __init__(self, config: AWSConfig):
119132 # Caches to avoid redundant API calls when provisioning many instances
120133 # get_offers is already cached but we still cache its sub-functions
121134 # with more aggressive/longer caches.
122- self ._offers_post_filter_cache_lock = threading .Lock ()
123- self ._offers_post_filter_cache = TTLCache (maxsize = 10 , ttl = 180 )
124- self ._get_regions_to_quotas_cache_lock = threading .Lock ()
125- self ._get_regions_to_quotas_execution_lock = threading .Lock ()
126- self ._get_regions_to_quotas_cache = TTLCache (maxsize = 10 , ttl = 300 )
127- self ._get_regions_to_zones_cache_lock = threading .Lock ()
128- self ._get_regions_to_zones_cache = Cache (maxsize = 10 )
129- self ._get_vpc_id_subnet_id_or_error_cache_lock = threading .Lock ()
130- self ._get_vpc_id_subnet_id_or_error_cache = TTLCache (maxsize = 100 , ttl = 600 )
131- self ._get_maximum_efa_interfaces_cache_lock = threading .Lock ()
132- self ._get_maximum_efa_interfaces_cache = Cache (maxsize = 100 )
133- self ._get_subnets_availability_zones_cache_lock = threading .Lock ()
134- self ._get_subnets_availability_zones_cache = Cache (maxsize = 100 )
135- self ._create_security_group_cache_lock = threading .Lock ()
136- self ._create_security_group_cache = TTLCache (maxsize = 100 , ttl = 600 )
137- self ._get_image_id_and_username_cache_lock = threading .Lock ()
138- self ._get_image_id_and_username_cache = TTLCache (maxsize = 100 , ttl = 600 )
135+ self ._offers_post_filter_cache = ComputeTTLCache (cache = TTLCache (maxsize = 10 , ttl = 180 ))
136+ if quotas_cache is None :
137+ quotas_cache = AWSQuotasCache (cache = TTLCache (maxsize = 10 , ttl = 600 ))
138+ self ._regions_to_quotas_cache = quotas_cache
139+ if zones_cache is None :
140+ zones_cache = ComputeCache (cache = Cache (maxsize = 10 ))
141+ self ._regions_to_zones_cache = zones_cache
142+ self ._vpc_id_subnet_id_cache = ComputeTTLCache (cache = TTLCache (maxsize = 100 , ttl = 600 ))
143+ self ._maximum_efa_interfaces_cache = ComputeCache (cache = Cache (maxsize = 100 ))
144+ self ._subnets_availability_zones_cache = ComputeCache (cache = Cache (maxsize = 100 ))
145+ self ._security_group_cache = ComputeTTLCache (cache = TTLCache (maxsize = 100 , ttl = 600 ))
146+ self ._image_id_and_username_cache = ComputeTTLCache (cache = TTLCache (maxsize = 100 , ttl = 600 ))
139147
140148 def get_all_offers_with_availability (self ) -> List [InstanceOfferWithAvailability ]:
141149 offers = get_catalog_offers (
@@ -144,7 +152,7 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability
144152 extra_filter = _supported_instances ,
145153 )
146154 regions = list (set (i .region for i in offers ))
147- with self ._get_regions_to_quotas_execution_lock :
155+ with self ._regions_to_quotas_cache . execution_lock :
148156 # Cache lock does not prevent concurrent execution.
149157 # We use a separate lock to avoid requesting quotas in parallel and hitting rate limits.
150158 regions_to_quotas = self ._get_regions_to_quotas (self .session , regions )
@@ -173,9 +181,9 @@ def _get_offers_cached_key(self, requirements: Requirements) -> int:
173181 return hash (requirements .json ())
174182
175183 @cachedmethod (
176- cache = lambda self : self ._offers_post_filter_cache ,
184+ cache = lambda self : self ._offers_post_filter_cache . cache ,
177185 key = _get_offers_cached_key ,
178- lock = lambda self : self ._offers_post_filter_cache_lock ,
186+ lock = lambda self : self ._offers_post_filter_cache . lock ,
179187 )
180188 def get_offers_post_filter (
181189 self , requirements : Requirements
@@ -789,9 +797,9 @@ def _get_regions_to_quotas_key(
789797 return hashkey (tuple (regions ))
790798
791799 @cachedmethod (
792- cache = lambda self : self ._get_regions_to_quotas_cache ,
800+ cache = lambda self : self ._regions_to_quotas_cache . cache ,
793801 key = _get_regions_to_quotas_key ,
794- lock = lambda self : self ._get_regions_to_quotas_cache_lock ,
802+ lock = lambda self : self ._regions_to_quotas_cache . lock ,
795803 )
796804 def _get_regions_to_quotas (
797805 self ,
@@ -808,9 +816,9 @@ def _get_regions_to_zones_key(
808816 return hashkey (tuple (regions ))
809817
810818 @cachedmethod (
811- cache = lambda self : self ._get_regions_to_zones_cache ,
819+ cache = lambda self : self ._regions_to_zones_cache . cache ,
812820 key = _get_regions_to_zones_key ,
813- lock = lambda self : self ._get_regions_to_zones_cache_lock ,
821+ lock = lambda self : self ._regions_to_zones_cache . lock ,
814822 )
815823 def _get_regions_to_zones (
816824 self ,
@@ -832,9 +840,9 @@ def _get_vpc_id_subnet_id_or_error_cache_key(
832840 )
833841
834842 @cachedmethod (
835- cache = lambda self : self ._get_vpc_id_subnet_id_or_error_cache ,
843+ cache = lambda self : self ._vpc_id_subnet_id_cache . cache ,
836844 key = _get_vpc_id_subnet_id_or_error_cache_key ,
837- lock = lambda self : self ._get_vpc_id_subnet_id_or_error_cache_lock ,
845+ lock = lambda self : self ._vpc_id_subnet_id_cache . lock ,
838846 )
839847 def _get_vpc_id_subnet_id_or_error (
840848 self ,
@@ -853,9 +861,9 @@ def _get_vpc_id_subnet_id_or_error(
853861 )
854862
855863 @cachedmethod (
856- cache = lambda self : self ._get_maximum_efa_interfaces_cache ,
864+ cache = lambda self : self ._maximum_efa_interfaces_cache . cache ,
857865 key = _ec2client_cache_methodkey ,
858- lock = lambda self : self ._get_maximum_efa_interfaces_cache_lock ,
866+ lock = lambda self : self ._maximum_efa_interfaces_cache . lock ,
859867 )
860868 def _get_maximum_efa_interfaces (
861869 self ,
@@ -877,9 +885,9 @@ def _get_subnets_availability_zones_key(
877885 return hashkey (region , tuple (subnet_ids ))
878886
879887 @cachedmethod (
880- cache = lambda self : self ._get_subnets_availability_zones_cache ,
888+ cache = lambda self : self ._subnets_availability_zones_cache . cache ,
881889 key = _get_subnets_availability_zones_key ,
882- lock = lambda self : self ._get_subnets_availability_zones_cache_lock ,
890+ lock = lambda self : self ._subnets_availability_zones_cache . lock ,
883891 )
884892 def _get_subnets_availability_zones (
885893 self ,
@@ -893,9 +901,9 @@ def _get_subnets_availability_zones(
893901 )
894902
895903 @cachedmethod (
896- cache = lambda self : self ._create_security_group_cache ,
904+ cache = lambda self : self ._security_group_cache . cache ,
897905 key = _ec2client_cache_methodkey ,
898- lock = lambda self : self ._create_security_group_cache_lock ,
906+ lock = lambda self : self ._security_group_cache . lock ,
899907 )
900908 def _create_security_group (
901909 self ,
@@ -923,9 +931,9 @@ def _get_image_id_and_username_cache_key(
923931 )
924932
925933 @cachedmethod (
926- cache = lambda self : self ._get_image_id_and_username_cache ,
934+ cache = lambda self : self ._image_id_and_username_cache . cache ,
927935 key = _get_image_id_and_username_cache_key ,
928- lock = lambda self : self ._get_image_id_and_username_cache_lock ,
936+ lock = lambda self : self ._image_id_and_username_cache . lock ,
929937 )
930938 def _get_image_id_and_username (
931939 self ,
0 commit comments