Skip to content

Commit 54c4a9f

Browse files
authored
Respect fleet spec when provisioning on run apply (#3022)
* Implement basic fleet-run spec combiners
* Test combine profiles
* Test combine requirements
* Refactor optional combine
* Fix _get_min_optional typing
* Move combine to separate module
* Combine profile tags
* Fix typing
* Respect fleet specs when provisioning new instance on run apply
* Remove match from test
1 parent 5f700f0 commit 54c4a9f

File tree

9 files changed

+714
-17
lines changed

9 files changed

+714
-17
lines changed

src/dstack/_internal/core/models/profiles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ class ProfileProps(CoreModel):
397397
Field(
398398
description="The name of the profile that can be passed as `--profile` to `dstack apply`"
399399
),
400-
]
400+
] = ""
401401
default: Annotated[
402402
bool, Field(description="If set to true, `dstack apply` will use this profile by default.")
403403
] = False

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from dstack._internal.server.services.backends import get_project_backend_by_type_or_error
5454
from dstack._internal.server.services.fleets import (
5555
fleet_model_to_fleet,
56+
get_fleet_requirements,
5657
)
5758
from dstack._internal.server.services.instances import (
5859
filter_pool_instances,
@@ -71,6 +72,10 @@
7172
from dstack._internal.server.services.locking import get_locker
7273
from dstack._internal.server.services.logging import fmt
7374
from dstack._internal.server.services.offers import get_offers_by_requirements
75+
from dstack._internal.server.services.requirements.combine import (
76+
combine_fleet_and_run_profiles,
77+
combine_fleet_and_run_requirements,
78+
)
7479
from dstack._internal.server.services.runs import (
7580
check_run_spec_requires_instance_mounts,
7681
run_model_to_run,
@@ -646,6 +651,8 @@ async def _run_job_on_new_instance(
646651
) -> Optional[Tuple[JobProvisioningData, InstanceOfferWithAvailability]]:
647652
if volumes is None:
648653
volumes = []
654+
profile = run.run_spec.merged_profile
655+
requirements = job.job_spec.requirements
649656
fleet = None
650657
if fleet_model is not None:
651658
fleet = fleet_model_to_fleet(fleet_model)
@@ -654,13 +661,26 @@ async def _run_job_on_new_instance(
654661
"%s: cannot fit new instance into fleet %s", fmt(job_model), fleet_model.name
655662
)
656663
return None
664+
profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, profile)
665+
if profile is None:
666+
logger.debug("%s: cannot combine fleet %s profile", fmt(job_model), fleet_model.name)
667+
return None
668+
fleet_requirements = get_fleet_requirements(fleet.spec)
669+
requirements = combine_fleet_and_run_requirements(fleet_requirements, requirements)
670+
if requirements is None:
671+
logger.debug(
672+
"%s: cannot combine fleet %s requirements", fmt(job_model), fleet_model.name
673+
)
674+
return None
675+
# TODO: Respect fleet provisioning properties such as tags
676+
657677
multinode = job.job_spec.jobs_per_replica > 1 or (
658678
fleet is not None and fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
659679
)
660680
offers = await get_offers_by_requirements(
661681
project=project,
662-
profile=run.run_spec.merged_profile,
663-
requirements=job.job_spec.requirements,
682+
profile=profile,
683+
requirements=requirements,
664684
exclude_not_available=True,
665685
multinode=multinode,
666686
master_job_provisioning_data=master_job_provisioning_data,

src/dstack/_internal/server/services/fleets.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ async def get_plan(
279279
offers_with_backends = await get_create_instance_offers(
280280
project=project,
281281
profile=effective_spec.merged_profile,
282-
requirements=_get_fleet_requirements(effective_spec),
282+
requirements=get_fleet_requirements(effective_spec),
283283
fleet_spec=effective_spec,
284284
blocks=effective_spec.configuration.blocks,
285285
)
@@ -458,7 +458,7 @@ async def create_fleet_instance_model(
458458
instance_num: int,
459459
) -> InstanceModel:
460460
profile = spec.merged_profile
461-
requirements = _get_fleet_requirements(spec)
461+
requirements = get_fleet_requirements(spec)
462462
instance_model = await instances_services.create_instance_model(
463463
session=session,
464464
project=project,
@@ -644,6 +644,17 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool:
644644
return len(active_instances) == 0
645645

646646

647+
def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
    """
    Builds the `Requirements` implied by a fleet spec.

    Resources and reservation are read from the fleet configuration (an empty
    `ResourcesSpec()` is used when no resources are configured); `max_price` and
    the spot setting are read from the fleet's merged profile, with the spot
    policy defaulting to `SpotPolicy.ONDEMAND`.
    """
    merged = fleet_spec.merged_profile
    return Requirements(
        resources=fleet_spec.configuration.resources or ResourcesSpec(),
        max_price=merged.max_price,
        spot=get_policy_map(merged.spot_policy, default=SpotPolicy.ONDEMAND),
        reservation=fleet_spec.configuration.reservation,
    )
656+
657+
647658
async def _create_fleet(
648659
session: AsyncSession,
649660
project: ProjectModel,
@@ -1004,17 +1015,6 @@ def _terminate_fleet_instances(fleet_model: FleetModel, instance_nums: Optional[
10041015
instance.status = InstanceStatus.TERMINATING
10051016

10061017

1007-
def _get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
1008-
profile = fleet_spec.merged_profile
1009-
requirements = Requirements(
1010-
resources=fleet_spec.configuration.resources or ResourcesSpec(),
1011-
max_price=profile.max_price,
1012-
spot=get_policy_map(profile.spot_policy, default=SpotPolicy.ONDEMAND),
1013-
reservation=fleet_spec.configuration.reservation,
1014-
)
1015-
return requirements
1016-
1017-
10181018
def _get_next_instance_num(instance_nums: set[int]) -> int:
10191019
if not instance_nums:
10201020
return 0

src/dstack/_internal/server/services/requirements/__init__.py

Whitespace-only changes.
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
from typing import Callable, Optional, Protocol, TypeVar
2+
3+
from pydantic import BaseModel
4+
from typing_extensions import Self
5+
6+
from dstack._internal.core.models.profiles import Profile, SpotPolicy
7+
from dstack._internal.core.models.resources import (
8+
CPUSpec,
9+
DiskSpec,
10+
GPUSpec,
11+
Memory,
12+
Range,
13+
ResourcesSpec,
14+
)
15+
from dstack._internal.core.models.runs import Requirements
16+
from dstack._internal.utils.typing import SupportsRichComparison
17+
18+
19+
class CombineError(ValueError):
    """Raised by the combine helpers when two values cannot be combined."""
21+
22+
23+
def combine_fleet_and_run_profiles(
    fleet_profile: Profile, run_profile: Profile
) -> Optional[Profile]:
    """
    Combines fleet and run profile parameters that affect offer selection or provisioning.

    Returns a new `Profile` built from the combined values, or `None` when any
    pair of values cannot be combined (a helper raised `CombineError`).
    """
    fleet, run = fleet_profile, run_profile
    try:
        # Fields are combined one at a time, in the order they are passed to `Profile`.
        backends = _intersect_lists_optional(fleet.backends, run.backends)
        regions = _intersect_lists_optional(fleet.regions, run.regions)
        availability_zones = _intersect_lists_optional(
            fleet.availability_zones, run.availability_zones
        )
        instance_types = _intersect_lists_optional(fleet.instance_types, run.instance_types)
        reservation = _get_single_value_optional(fleet.reservation, run.reservation)
        spot_policy = _combine_spot_policy_optional(fleet.spot_policy, run.spot_policy)
        max_price = _get_min_optional(fleet.max_price, run.max_price)
        idle_duration = _combine_idle_duration_optional(fleet.idle_duration, run.idle_duration)
        tags = _combine_tags_optional(fleet.tags, run.tags)
        return Profile(
            backends=backends,
            regions=regions,
            availability_zones=availability_zones,
            instance_types=instance_types,
            reservation=reservation,
            spot_policy=spot_policy,
            max_price=max_price,
            idle_duration=idle_duration,
            tags=tags,
        )
    except CombineError:
        return None
53+
54+
55+
def combine_fleet_and_run_requirements(
    fleet_requirements: Requirements, run_requirements: Requirements
) -> Optional[Requirements]:
    """
    Combines fleet and run requirements into a single `Requirements` object.

    Returns `None` when any pair of values cannot be combined (a helper raised
    `CombineError`).
    """
    fleet, run = fleet_requirements, run_requirements
    try:
        resources = _combine_resources(fleet.resources, run.resources)
        max_price = _get_min_optional(fleet.max_price, run.max_price)
        spot = _combine_spot_optional(fleet.spot, run.spot)
        reservation = _get_single_value_optional(fleet.reservation, run.reservation)
        return Requirements(
            resources=resources,
            max_price=max_price,
            spot=spot,
            reservation=reservation,
        )
    except CombineError:
        return None
69+
70+
71+
_T = TypeVar("_T")
72+
_ModelT = TypeVar("_ModelT", bound=BaseModel)
73+
_CompT = TypeVar("_CompT", bound=SupportsRichComparison)
74+
75+
76+
class _SupportsCopy(Protocol):
77+
def copy(self) -> Self: ...
78+
79+
80+
_CopyT = TypeVar("_CopyT", bound=_SupportsCopy)
81+
82+
83+
def _intersect_lists_optional(
84+
list1: Optional[list[_T]], list2: Optional[list[_T]]
85+
) -> Optional[list[_T]]:
86+
if list1 is None:
87+
if list2 is None:
88+
return None
89+
return list2.copy()
90+
if list2 is None:
91+
return list1.copy()
92+
return [x for x in list1 if x in list2]
93+
94+
95+
def _get_min(value1: _CompT, value2: _CompT) -> _CompT:
96+
return min(value1, value2)
97+
98+
99+
def _get_min_optional(value1: Optional[_CompT], value2: Optional[_CompT]) -> Optional[_CompT]:
    """Returns the smaller value, treating `None` as "not set" (the other value wins)."""
    return _combine_optional(value1, value2, combiner=_get_min)
101+
102+
103+
def _get_single_value(value1: _T, value2: _T) -> _T:
104+
if value1 == value2:
105+
return value1
106+
raise CombineError(f"Values {value1} and {value2} cannot be combined")
107+
108+
109+
def _get_single_value_optional(value1: Optional[_T], value2: Optional[_T]) -> Optional[_T]:
    """Like `_get_single_value`, but a `None` on either side yields the other value."""
    return _combine_optional(value1, value2, combiner=_get_single_value)
111+
112+
113+
def _combine_spot_policy(value1: SpotPolicy, value2: SpotPolicy) -> SpotPolicy:
    """
    Combines two spot policies.

    `SpotPolicy.AUTO` on either side yields the other value; otherwise both
    values must be equal.

    Raises:
        CombineError: if neither value is `AUTO` and the values differ.
    """
    if value1 == SpotPolicy.AUTO:
        return value2
    if value2 == SpotPolicy.AUTO or value1 == value2:
        return value1
    raise CombineError(f"spot_policy values {value1} and {value2} cannot be combined")
121+
122+
123+
def _combine_spot_policy_optional(
    value1: Optional[SpotPolicy], value2: Optional[SpotPolicy]
) -> Optional[SpotPolicy]:
    """Like `_combine_spot_policy`, but a `None` on either side yields the other value."""
    return _combine_optional(value1, value2, combiner=_combine_spot_policy)
127+
128+
129+
def _combine_idle_duration(value1: int, value2: int) -> int:
130+
if value1 < 0 and value2 >= 0 or value2 < 0 and value1 >= 0:
131+
raise CombineError(f"idle_duration values {value1} and {value2} cannot be combined")
132+
return min(value1, value2)
133+
134+
135+
def _combine_idle_duration_optional(value1: Optional[int], value2: Optional[int]) -> Optional[int]:
    """Like `_combine_idle_duration`, but a `None` on either side yields the other value."""
    return _combine_optional(value1, value2, combiner=_combine_idle_duration)
137+
138+
139+
def _combine_tags_optional(
    value1: Optional[dict[str, str]], value2: Optional[dict[str, str]]
) -> Optional[dict[str, str]]:
    """
    Merges two optional tag dicts via `_combine_tags`.

    When only one dict is given, a copy of it is returned; `None` when both are `None`.
    """
    return _combine_copy_optional(value1, value2, combiner=_combine_tags)
143+
144+
145+
def _combine_tags(value1: dict[str, str], value2: dict[str, str]) -> dict[str, str]:
146+
return value1 | value2
147+
148+
149+
def _combine_resources(value1: ResourcesSpec, value2: ResourcesSpec) -> ResourcesSpec:
    """
    Combines two resource specs field by field into a new `ResourcesSpec`.

    A `CombineError` raised by any per-field helper propagates to the caller.
    """
    cpu = _combine_cpu(value1.cpu, value2.cpu)  # type: ignore[attr-defined]
    memory = _combine_memory(value1.memory, value2.memory)
    shm_size = _combine_shm_size_optional(value1.shm_size, value2.shm_size)
    gpu = _combine_gpu_optional(value1.gpu, value2.gpu)
    disk = _combine_disk_optional(value1.disk, value2.disk)
    return ResourcesSpec(cpu=cpu, memory=memory, shm_size=shm_size, gpu=gpu, disk=disk)
157+
158+
159+
def _combine_cpu(value1: CPUSpec, value2: CPUSpec) -> CPUSpec:
    """
    Combines two CPU specs into a new `CPUSpec`.

    `arch` must match when set on both sides; `count` is the intersection of
    the two ranges. A `CombineError` from either helper propagates.
    """
    arch = _get_single_value_optional(value1.arch, value2.arch)
    count = _combine_range(value1.count, value2.count)
    return CPUSpec(arch=arch, count=count)
164+
165+
166+
def _combine_memory(value1: Range[Memory], value2: Range[Memory]) -> Range[Memory]:
    """Intersects two memory ranges; delegates to `_combine_range`."""
    combined = _combine_range(value1, value2)
    return combined
168+
169+
170+
def _combine_shm_size_optional(
    value1: Optional[Memory], value2: Optional[Memory]
) -> Optional[Memory]:
    """
    Returns the smaller of two optional shm sizes; `None` yields the other value.
    """
    # NOTE(review): the smaller value wins here — confirm min (not max) is the
    # intended way to combine two shm_size requests.
    return _get_min_optional(value1=value1, value2=value2)
174+
175+
176+
def _combine_gpu(value1: GPUSpec, value2: GPUSpec) -> GPUSpec:
    """
    Combines two GPU specs field by field into a new `GPUSpec`.

    `vendor` must match when set on both sides, `name` lists are intersected,
    ranges (`count`, `memory`, `total_memory`) are intersected, and the smaller
    `compute_capability` is kept. A `CombineError` from any helper propagates.
    """
    vendor = _get_single_value_optional(value1.vendor, value2.vendor)
    name = _intersect_lists_optional(value1.name, value2.name)
    count = _combine_range(value1.count, value2.count)
    memory = _combine_range_optional(value1.memory, value2.memory)
    total_memory = _combine_range_optional(value1.total_memory, value2.total_memory)
    # NOTE(review): presumably compute_capability is a lower bound; confirm that
    # taking the min (the weaker of the two) is intended.
    compute_capability = _get_min_optional(value1.compute_capability, value2.compute_capability)
    return GPUSpec(
        vendor=vendor,
        name=name,
        count=count,
        memory=memory,
        total_memory=total_memory,
        compute_capability=compute_capability,
    )
185+
186+
187+
def _combine_gpu_optional(
    value1: Optional[GPUSpec], value2: Optional[GPUSpec]
) -> Optional[GPUSpec]:
    """
    Combines two optional GPU specs via `_combine_gpu`.

    When only one spec is given, a deep copy of it is returned; `None` when both are `None`.
    """
    return _combine_models_optional(value1, value2, combiner=_combine_gpu)
191+
192+
193+
def _combine_disk(value1: DiskSpec, value2: DiskSpec) -> DiskSpec:
    """Returns a new `DiskSpec` whose size is the intersection of both size ranges."""
    size = _combine_range(value1.size, value2.size)
    return DiskSpec(size=size)
195+
196+
197+
def _combine_disk_optional(
    value1: Optional[DiskSpec], value2: Optional[DiskSpec]
) -> Optional[DiskSpec]:
    """
    Combines two optional disk specs via `_combine_disk`.

    When only one spec is given, a deep copy of it is returned; `None` when both are `None`.
    """
    return _combine_models_optional(value1, value2, combiner=_combine_disk)
201+
202+
203+
def _combine_spot(value1: bool, value2: bool) -> bool:
204+
if value1 != value2:
205+
raise CombineError(f"spot values {value1} and {value2} cannot be combined")
206+
return value1
207+
208+
209+
def _combine_spot_optional(value1: Optional[bool], value2: Optional[bool]) -> Optional[bool]:
    """Like `_combine_spot`, but a `None` on either side yields the other value."""
    return _combine_optional(value1, value2, combiner=_combine_spot)
211+
212+
213+
def _combine_range(value1: Range, value2: Range) -> Range:
    """
    Returns the result of `value1.intersect(value2)`.

    Raises:
        CombineError: if `intersect` returns `None`.
    """
    intersection = value1.intersect(value2)
    if intersection is not None:
        return intersection
    raise CombineError(f"Ranges {value1} and {value2} cannot be combined")
218+
219+
220+
def _combine_range_optional(value1: Optional[Range], value2: Optional[Range]) -> Optional[Range]:
    """
    Intersects two optional ranges via `_combine_range`.

    When only one range is given, a deep copy of it is returned; `None` when both are `None`.
    """
    return _combine_models_optional(value1, value2, combiner=_combine_range)
222+
223+
224+
def _combine_optional(
225+
value1: Optional[_T], value2: Optional[_T], combiner: Callable[[_T, _T], _T]
226+
) -> Optional[_T]:
227+
if value1 is None:
228+
return value2
229+
if value2 is None:
230+
return value1
231+
return combiner(value1, value2)
232+
233+
234+
def _combine_models_optional(
235+
value1: Optional[_ModelT],
236+
value2: Optional[_ModelT],
237+
combiner: Callable[[_ModelT, _ModelT], _ModelT],
238+
) -> Optional[_ModelT]:
239+
if value1 is None:
240+
if value2 is not None:
241+
return value2.copy(deep=True)
242+
return None
243+
if value2 is None:
244+
return value1.copy(deep=True)
245+
return combiner(value1, value2)
246+
247+
248+
def _combine_copy_optional(
249+
value1: Optional[_CopyT],
250+
value2: Optional[_CopyT],
251+
combiner: Callable[[_CopyT, _CopyT], _CopyT],
252+
) -> Optional[_CopyT]:
253+
if value1 is None:
254+
if value2 is not None:
255+
return value2.copy()
256+
return None
257+
if value2 is None:
258+
return value1.copy()
259+
return combiner(value1, value2)

src/dstack/_internal/server/testing/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def get_fleet_spec(conf: Optional[FleetConfiguration] = None) -> FleetSpec:
573573
return FleetSpec(
574574
configuration=conf,
575575
configuration_path="fleet.dstack.yml",
576-
profile=Profile(name=""),
576+
profile=Profile(),
577577
)
578578

579579

0 commit comments

Comments
 (0)