Skip to content

Commit 90fc7c9

Browse files
Implement requirements-independent offers cache (#3091)
* Cache GCP offers with availability * refactor: update get_offers method signature to remove optional requirements Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * Introduce ComputeWithAllOffersCached * feat: migrate AWSCompute to use ComputeWithAllOffersCached with reservation handling Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * refactor: update compute classes to use flexible requirements filtering Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * Cache AWS offers with availability * refactor: migrate AzureCompute to use ComputeWithAllOffersCached Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * refactor: migrate CloudriftCompute to use ComputeWithAllOffersCached Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * refactor: migrate DatacrunchCompute to use ComputeWithAllOffersCached Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * fix missing Compute * Migrate all backends to ComputeWithAllOffersCached * refactor: inherit from ComputeWithAllOffersCached and update get_offers method Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat> * Move by requirements cache to ComputeWithFilteredOffersCached * Implement get_offers_modifier for AWS * Implement get_offers_modifier for all backends with CONFIGURABLE_DISK_SIZE * Fix backend offers * Fix nebius * Fix oci * Use ComputeWithAllOffersCached for kuberenetes * Cache AWS.get_offers_post_filter * Update template * Fix tests * Lint --------- Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) <aider@aider.chat>
1 parent 49fd9e5 commit 90fc7c9

File tree

27 files changed

+453
-193
lines changed

27 files changed

+453
-193
lines changed

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import threading
22
from concurrent.futures import ThreadPoolExecutor, as_completed
3-
from typing import Any, Dict, List, Optional, Tuple
3+
from typing import Any, Callable, Dict, List, Optional, Tuple
44

55
import boto3
66
import botocore.client
@@ -18,6 +18,7 @@
1818
)
1919
from dstack._internal.core.backends.base.compute import (
2020
Compute,
21+
ComputeWithAllOffersCached,
2122
ComputeWithCreateInstanceSupport,
2223
ComputeWithGatewaySupport,
2324
ComputeWithMultinodeSupport,
@@ -32,7 +33,7 @@
3233
get_user_data,
3334
merge_tags,
3435
)
35-
from dstack._internal.core.backends.base.offers import get_catalog_offers
36+
from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
3637
from dstack._internal.core.errors import (
3738
ComputeError,
3839
NoCapacityError,
@@ -87,6 +88,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
8788

8889

8990
class AWSCompute(
91+
ComputeWithAllOffersCached,
9092
ComputeWithCreateInstanceSupport,
9193
ComputeWithMultinodeSupport,
9294
ComputeWithReservationSupport,
@@ -109,6 +111,8 @@ def __init__(self, config: AWSConfig):
109111
# Caches to avoid redundant API calls when provisioning many instances
110112
# get_offers is already cached but we still cache its sub-functions
111113
# with more aggressive/longer caches.
114+
self._offers_post_filter_cache_lock = threading.Lock()
115+
self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180)
112116
self._get_regions_to_quotas_cache_lock = threading.Lock()
113117
self._get_regions_to_quotas_execution_lock = threading.Lock()
114118
self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
@@ -125,43 +129,11 @@ def __init__(self, config: AWSConfig):
125129
self._get_image_id_and_username_cache_lock = threading.Lock()
126130
self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)
127131

128-
def get_offers(
129-
self, requirements: Optional[Requirements] = None
130-
) -> List[InstanceOfferWithAvailability]:
131-
filter = _supported_instances
132-
if requirements and requirements.reservation:
133-
region_to_reservation = {}
134-
for region in self.config.regions:
135-
reservation = aws_resources.get_reservation(
136-
ec2_client=self.session.client("ec2", region_name=region),
137-
reservation_id=requirements.reservation,
138-
instance_count=1,
139-
)
140-
if reservation is not None:
141-
region_to_reservation[region] = reservation
142-
143-
def _supported_instances_with_reservation(offer: InstanceOffer) -> bool:
144-
# Filter: only instance types supported by dstack
145-
if not _supported_instances(offer):
146-
return False
147-
# Filter: Spot instances can't be used with reservations
148-
if offer.instance.resources.spot:
149-
return False
150-
region = offer.region
151-
reservation = region_to_reservation.get(region)
152-
# Filter: only instance types matching the capacity reservation
153-
if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
154-
return False
155-
return True
156-
157-
filter = _supported_instances_with_reservation
158-
132+
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
159133
offers = get_catalog_offers(
160134
backend=BackendType.AWS,
161135
locations=self.config.regions,
162-
requirements=requirements,
163-
configurable_disk_size=CONFIGURABLE_DISK_SIZE,
164-
extra_filter=filter,
136+
extra_filter=_supported_instances,
165137
)
166138
regions = list(set(i.region for i in offers))
167139
with self._get_regions_to_quotas_execution_lock:
@@ -185,6 +157,49 @@ def _supported_instances_with_reservation(offer: InstanceOffer) -> bool:
185157
)
186158
return availability_offers
187159

160+
def get_offers_modifier(
161+
self, requirements: Requirements
162+
) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
163+
return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
164+
165+
def _get_offers_cached_key(self, requirements: Requirements) -> int:
166+
# Requirements is not hashable, so we use a hack to get arguments hash
167+
return hash(requirements.json())
168+
169+
@cachedmethod(
170+
cache=lambda self: self._offers_post_filter_cache,
171+
key=_get_offers_cached_key,
172+
lock=lambda self: self._offers_post_filter_cache_lock,
173+
)
174+
def get_offers_post_filter(
175+
self, requirements: Requirements
176+
) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
177+
if requirements.reservation:
178+
region_to_reservation = {}
179+
for region in get_or_error(self.config.regions):
180+
reservation = aws_resources.get_reservation(
181+
ec2_client=self.session.client("ec2", region_name=region),
182+
reservation_id=requirements.reservation,
183+
instance_count=1,
184+
)
185+
if reservation is not None:
186+
region_to_reservation[region] = reservation
187+
188+
def reservation_filter(offer: InstanceOfferWithAvailability) -> bool:
189+
# Filter: Spot instances can't be used with reservations
190+
if offer.instance.resources.spot:
191+
return False
192+
region = offer.region
193+
reservation = region_to_reservation.get(region)
194+
# Filter: only instance types matching the capacity reservation
195+
if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
196+
return False
197+
return True
198+
199+
return reservation_filter
200+
201+
return None
202+
188203
def terminate_instance(
189204
self, instance_id: str, region: str, backend_data: Optional[str] = None
190205
) -> None:

src/dstack/_internal/core/backends/azure/compute.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import enum
33
import re
44
from concurrent.futures import ThreadPoolExecutor, as_completed
5-
from typing import Dict, List, Optional, Tuple
5+
from typing import Callable, Dict, List, Optional, Tuple
66

77
from azure.core.credentials import TokenCredential
88
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
@@ -39,6 +39,7 @@
3939
from dstack._internal.core.backends.azure.models import AzureConfig
4040
from dstack._internal.core.backends.base.compute import (
4141
Compute,
42+
ComputeWithAllOffersCached,
4243
ComputeWithCreateInstanceSupport,
4344
ComputeWithGatewaySupport,
4445
ComputeWithMultinodeSupport,
@@ -48,7 +49,7 @@
4849
get_user_data,
4950
merge_tags,
5051
)
51-
from dstack._internal.core.backends.base.offers import get_catalog_offers
52+
from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
5253
from dstack._internal.core.errors import ComputeError, NoCapacityError
5354
from dstack._internal.core.models.backends.base import BackendType
5455
from dstack._internal.core.models.gateways import (
@@ -73,6 +74,7 @@
7374

7475

7576
class AzureCompute(
77+
ComputeWithAllOffersCached,
7678
ComputeWithCreateInstanceSupport,
7779
ComputeWithMultinodeSupport,
7880
ComputeWithGatewaySupport,
@@ -89,14 +91,10 @@ def __init__(self, config: AzureConfig, credential: TokenCredential):
8991
credential=credential, subscription_id=config.subscription_id
9092
)
9193

92-
def get_offers(
93-
self, requirements: Optional[Requirements] = None
94-
) -> List[InstanceOfferWithAvailability]:
94+
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
9595
offers = get_catalog_offers(
9696
backend=BackendType.AZURE,
9797
locations=self.config.regions,
98-
requirements=requirements,
99-
configurable_disk_size=CONFIGURABLE_DISK_SIZE,
10098
extra_filter=_supported_instances,
10199
)
102100
offers_with_availability = _get_offers_with_availability(
@@ -106,6 +104,11 @@ def get_offers(
106104
)
107105
return offers_with_availability
108106

107+
def get_offers_modifier(
108+
self, requirements: Requirements
109+
) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
110+
return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
111+
109112
def create_instance(
110113
self,
111114
instance_offer: InstanceOfferWithAvailability,

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from collections.abc import Iterable
88
from functools import lru_cache
99
from pathlib import Path
10-
from typing import Dict, List, Literal, Optional
10+
from typing import Callable, Dict, List, Literal, Optional
1111

1212
import git
1313
import requests
1414
import yaml
1515
from cachetools import TTLCache, cachedmethod
1616

1717
from dstack._internal import settings
18+
from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
1819
from dstack._internal.core.consts import (
1920
DSTACK_RUNNER_HTTP_PORT,
2021
DSTACK_RUNNER_SSH_PORT,
@@ -57,14 +58,8 @@ class Compute(ABC):
5758
If a compute supports additional features, it must also subclass `ComputeWith*` classes.
5859
"""
5960

60-
def __init__(self):
61-
self._offers_cache_lock = threading.Lock()
62-
self._offers_cache = TTLCache(maxsize=10, ttl=180)
63-
6461
@abstractmethod
65-
def get_offers(
66-
self, requirements: Optional[Requirements] = None
67-
) -> List[InstanceOfferWithAvailability]:
62+
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
6863
"""
6964
Returns offers with availability matching `requirements`.
7065
If the provider is added to gpuhunt, typically gets offers using `base.offers.get_catalog_offers()`
@@ -121,21 +116,108 @@ def update_provisioning_data(
121116
"""
122117
pass
123118

124-
def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int:
119+
120+
class ComputeWithAllOffersCached(ABC):
121+
"""
122+
Provides common `get_offers()` implementation for backends
123+
whose offers do not depend on requirements.
124+
It caches all offers with availability and post-filters by requirements.
125+
"""
126+
127+
def __init__(self) -> None:
128+
super().__init__()
129+
self._offers_cache_lock = threading.Lock()
130+
self._offers_cache = TTLCache(maxsize=1, ttl=180)
131+
132+
@abstractmethod
133+
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
134+
"""
135+
Returns all backend offers with availability.
136+
"""
137+
pass
138+
139+
def get_offers_modifier(
140+
self, requirements: Requirements
141+
) -> Optional[
142+
Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]
143+
]:
144+
"""
145+
Returns a modifier function that modifies offers before they are filtered by requirements.
146+
Can return `None` to exclude the offer.
147+
E.g. can be used to set appropriate disk size based on requirements.
148+
"""
149+
return None
150+
151+
def get_offers_post_filter(
152+
self, requirements: Requirements
153+
) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
154+
"""
155+
Returns a filter function to apply to offers based on requirements.
156+
This allows backends to implement custom post-filtering logic for specific requirements.
157+
"""
158+
return None
159+
160+
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
161+
offers = self._get_all_offers_with_availability_cached()
162+
modifier = self.get_offers_modifier(requirements)
163+
if modifier is not None:
164+
modified_offers = []
165+
for o in offers:
166+
modified_offer = modifier(o)
167+
if modified_offer is not None:
168+
modified_offers.append(modified_offer)
169+
offers = modified_offers
170+
offers = filter_offers_by_requirements(offers, requirements)
171+
post_filter = self.get_offers_post_filter(requirements)
172+
if post_filter is not None:
173+
offers = [o for o in offers if post_filter(o)]
174+
return offers
175+
176+
@cachedmethod(
177+
cache=lambda self: self._offers_cache,
178+
lock=lambda self: self._offers_cache_lock,
179+
)
180+
def _get_all_offers_with_availability_cached(self) -> List[InstanceOfferWithAvailability]:
181+
return self.get_all_offers_with_availability()
182+
183+
184+
class ComputeWithFilteredOffersCached(ABC):
185+
"""
186+
Provides common `get_offers()` implementation for backends
187+
whose offers depend on requirements.
188+
It caches offers using requirements as key.
189+
"""
190+
191+
def __init__(self) -> None:
192+
super().__init__()
193+
self._offers_cache_lock = threading.Lock()
194+
self._offers_cache = TTLCache(maxsize=10, ttl=180)
195+
196+
@abstractmethod
197+
def get_offers_by_requirements(
198+
self, requirements: Requirements
199+
) -> List[InstanceOfferWithAvailability]:
200+
"""
201+
Returns backend offers with availability matching requirements.
202+
"""
203+
pass
204+
205+
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
206+
return self._get_offers_cached(requirements)
207+
208+
def _get_offers_cached_key(self, requirements: Requirements) -> int:
125209
# Requirements is not hashable, so we use a hack to get arguments hash
126-
if requirements is None:
127-
return hash(None)
128210
return hash(requirements.json())
129211

130212
@cachedmethod(
131213
cache=lambda self: self._offers_cache,
132214
key=_get_offers_cached_key,
133215
lock=lambda self: self._offers_cache_lock,
134216
)
135-
def get_offers_cached(
136-
self, requirements: Optional[Requirements] = None
217+
def _get_offers_cached(
218+
self, requirements: Requirements
137219
) -> List[InstanceOfferWithAvailability]:
138-
return self.get_offers(requirements)
220+
return self.get_offers_by_requirements(requirements)
139221

140222

141223
class ComputeWithCreateInstanceSupport(ABC):

0 commit comments

Comments
 (0)