Skip to content

Commit b9ce68f

Browse files
authored
Respect fleet nodes.max (#3164)
* Maintain at most nodes.max fleet instances * Test redundant instances termination * Fix FleetModel.name loading * Check if run cannot fit into fleet * Fix missing instance lock on SQLite * Check if a run can fit into SSH fleet * Fix FOR UPDATE with nullable join
1 parent a863f6d commit b9ce68f

File tree

4 files changed

+226
-39
lines changed

4 files changed

+226
-39
lines changed

src/dstack/_internal/server/background/tasks/process_fleets.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from collections import defaultdict
12
from datetime import timedelta
23
from typing import List
34
from uuid import UUID
45

56
from sqlalchemy import select, update
67
from sqlalchemy.ext.asyncio import AsyncSession
7-
from sqlalchemy.orm import joinedload, load_only
8+
from sqlalchemy.orm import joinedload, load_only, selectinload
89

910
from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
1011
from dstack._internal.core.models.instances import InstanceStatus
@@ -37,30 +38,68 @@
3738

3839
@sentry_utils.instrument_background_task
async def process_fleets():
    """Pick a batch of due fleets, lock them and their instances, and process them.

    Both the fleet rows and their instance rows are locked (in-process
    locksets plus ``SELECT ... FOR UPDATE SKIP LOCKED`` where the dialect
    supports it) so that fleet processing can safely modify instances.
    A fleet whose instances could not all be locked is skipped this round
    but is still added to the lockset to avoid starting a new transaction.
    """
    locker = get_locker(get_db().dialect_name)
    fleet_lock, fleet_lockset = locker.get_lockset(FleetModel.__tablename__)
    instance_lock, instance_lockset = locker.get_lockset(InstanceModel.__tablename__)
    async with get_session_ctx() as session:
        async with fleet_lock, instance_lock:
            res = await session.execute(
                select(FleetModel)
                .where(
                    FleetModel.deleted == False,
                    FleetModel.id.not_in(fleet_lockset),
                    FleetModel.last_processed_at
                    < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                )
                .options(
                    load_only(FleetModel.id, FleetModel.name),
                    selectinload(FleetModel.instances).load_only(InstanceModel.id),
                )
                .order_by(FleetModel.last_processed_at.asc())
                .limit(BATCH_SIZE)
                .with_for_update(skip_locked=True, key_share=True)
            )
            fleet_models = list(res.scalars().unique().all())
            fleet_ids = [fm.id for fm in fleet_models]
            # Lock the fleets' instances with a separate query rather than a
            # join — presumably to avoid FOR UPDATE over a nullable join
            # (see commit message) — TODO confirm.
            res = await session.execute(
                select(InstanceModel)
                .where(
                    InstanceModel.id.not_in(instance_lockset),
                    InstanceModel.fleet_id.in_(fleet_ids),
                )
                .options(load_only(InstanceModel.id, InstanceModel.fleet_id))
                .order_by(InstanceModel.id)
                .with_for_update(skip_locked=True, key_share=True)
            )
            instance_models = list(res.scalars().all())
            fleet_id_to_locked_instances = defaultdict(list)
            for instance_model in instance_models:
                fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
            # Process only fleets with all instances locked.
            # Other fleets won't be processed but will still be locked to avoid new transaction.
            # This should not be problematic as long as process_fleets is quick.
            fleet_models_to_process = []
            for fleet_model in fleet_models:
                if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
                    fleet_models_to_process.append(fleet_model)
                else:
                    logger.debug(
                        "Fleet %s processing will be skipped: some instances were not locked",
                        fleet_model.name,
                    )
            for fleet_id in fleet_ids:
                fleet_lockset.add(fleet_id)
            instance_ids = [im.id for im in instance_models]
            for instance_id in instance_ids:
                instance_lockset.add(instance_id)
            try:
                await _process_fleets(session=session, fleet_models=fleet_models_to_process)
            finally:
                # Always release the in-process locks, even if processing fails.
                fleet_lockset.difference_update(fleet_ids)
                instance_lockset.difference_update(instance_ids)
64103

65104

66105
async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
@@ -99,8 +138,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
99138
return
100139
if not _is_fleet_ready_for_consolidation(fleet_model):
101140
return
102-
added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
103-
if added_instances:
141+
changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
142+
if changed_instances:
104143
fleet_model.consolidation_attempt += 1
105144
else:
106145
# The fleet is already consolidated or consolidation is in progress.
@@ -138,28 +177,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
138177
return _CONSOLIDATION_RETRY_DELAYS[-1]
139178

140179

141-
def _maintain_fleet_nodes_min(
180+
def _maintain_fleet_nodes_in_min_max_range(
142181
session: AsyncSession,
143182
fleet_model: FleetModel,
144183
fleet_spec: FleetSpec,
145184
) -> bool:
146185
"""
147-
Ensures the fleet has at least `nodes.min` instances.
148-
Returns `True` if retried or added new instances and `False` otherwise.
186+
Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
187+
Returns `True` if retried, added new instances, or terminated redundant instances and `False` otherwise.
149188
"""
150189
assert fleet_spec.configuration.nodes is not None
151190
for instance in fleet_model.instances:
152191
# Delete terminated but not deleted instances since
153192
# they are going to be replaced with new pending instances.
154193
if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
155-
# It's safe to modify instances without instance lock since
156-
# no other task modifies already terminated instances.
157194
instance.deleted = True
158195
instance.deleted_at = get_current_datetime()
159196
active_instances = [i for i in fleet_model.instances if not i.deleted]
160197
active_instances_num = len(active_instances)
161198
if active_instances_num >= fleet_spec.configuration.nodes.min:
162-
return False
199+
if (
200+
fleet_spec.configuration.nodes.max is None
201+
or active_instances_num <= fleet_spec.configuration.nodes.max
202+
):
203+
return False
204+
# Fleet has more instances than allowed by nodes.max.
205+
# This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
206+
# or if nodes.max is updated.
207+
nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
208+
for instance in fleet_model.instances:
209+
if nodes_redundant == 0:
210+
break
211+
if instance.status in [InstanceStatus.IDLE]:
212+
instance.status = InstanceStatus.TERMINATING
213+
instance.termination_reason = "Fleet has too many instances"
214+
nodes_redundant -= 1
215+
logger.info(
216+
"Terminating instance %s: %s",
217+
instance.name,
218+
instance.termination_reason,
219+
)
220+
return True
163221
nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
164222
for i in range(nodes_missing):
165223
instance_model = create_fleet_instance_model(

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
260260

261261
instance_filters = [
262262
InstanceModel.deleted == False,
263-
InstanceModel.total_blocks > InstanceModel.busy_blocks,
264263
InstanceModel.id.not_in(detaching_instances_ids),
265264
]
266265

@@ -514,9 +513,6 @@ async def _find_optimal_fleet_with_offers(
514513
)
515514
return run_model.fleet, fleet_instances_with_pool_offers
516515

517-
if len(fleet_models) == 0:
518-
return None, []
519-
520516
nodes_required_num = _get_nodes_required_num_for_run(run_spec)
521517
# The current strategy is first to consider fleets that can accommodate
522518
# the run without additional provisioning and choose the one with the cheapest pool offer.
@@ -534,31 +530,29 @@ async def _find_optimal_fleet_with_offers(
534530
]
535531
] = []
536532
for candidate_fleet_model in fleet_models:
533+
candidate_fleet = fleet_model_to_fleet(candidate_fleet_model)
537534
fleet_instances_with_pool_offers = _get_fleet_instances_with_pool_offers(
538535
fleet_model=candidate_fleet_model,
539536
run_spec=run_spec,
540537
job=job,
541538
master_job_provisioning_data=master_job_provisioning_data,
542539
volumes=volumes,
543540
)
544-
fleet_has_available_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers)
541+
fleet_has_pool_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers)
545542
fleet_cheapest_pool_offer = math.inf
546543
if len(fleet_instances_with_pool_offers) > 0:
547544
fleet_cheapest_pool_offer = fleet_instances_with_pool_offers[0][1].price
548545

549-
candidate_fleet = fleet_model_to_fleet(candidate_fleet_model)
550-
profile = None
551-
requirements = None
552546
try:
547+
_check_can_create_new_instance_in_fleet(candidate_fleet)
553548
profile, requirements = _get_run_profile_and_requirements_in_fleet(
554549
job=job,
555550
run_spec=run_spec,
556551
fleet=candidate_fleet,
557552
)
558553
except ValueError:
559-
pass
560-
fleet_backend_offers = []
561-
if profile is not None and requirements is not None:
554+
fleet_backend_offers = []
555+
else:
562556
multinode = (
563557
candidate_fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
564558
or job.job_spec.jobs_per_replica > 1
@@ -579,8 +573,12 @@ async def _find_optimal_fleet_with_offers(
579573
if len(fleet_backend_offers) > 0:
580574
fleet_cheapest_backend_offer = fleet_backend_offers[0][1].price
581575

576+
if not _run_can_fit_into_fleet(run_spec, candidate_fleet):
577+
logger.debug("Skipping fleet %s from consideration: run cannot fit into fleet", candidate_fleet.name)
578+
continue
579+
582580
fleet_priority = (
583-
not fleet_has_available_capacity,
581+
not fleet_has_pool_capacity,
584582
fleet_cheapest_pool_offer,
585583
fleet_cheapest_backend_offer,
586584
)
@@ -593,10 +591,13 @@ async def _find_optimal_fleet_with_offers(
593591
fleet_priority,
594592
)
595593
)
594+
if len(candidate_fleets_with_offers) == 0:
595+
return None, []
596596
if run_spec.merged_profile.fleets is None and all(
597597
t[2] == 0 and t[3] == 0 for t in candidate_fleets_with_offers
598598
):
599-
# If fleets are not specified and no fleets have available pool or backend offers, create a new fleet.
599+
# If fleets are not specified and no fleets have available pool
600+
# or backend offers, create a new fleet.
600601
# This is for compatibility with non-fleet-first UX when runs created new fleets
601602
# if there are no instances to reuse.
602603
return None, []
@@ -616,6 +617,39 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int:
616617
return nodes_required_num
617618

618619

620+
def _run_can_fit_into_fleet(run_spec: RunSpec, fleet: Fleet) -> bool:
    """
    Return `False` if the run cannot fit into the fleet for sure.

    This is a helpful heuristic to avoid even considering fleets that are
    too small for a run. A run may still not fit even when this function
    returns `True`. That will lead to some jobs failing due to exceeding
    `nodes.max`, or to more than `nodes.max` instances being provisioned
    and eventually removed by the fleet consolidation logic.
    """
    nodes_required_num = _get_nodes_required_num_for_run(run_spec)
    # No check for cloud fleets with blocks > 1 since we don't know
    # how many jobs such fleets can accommodate.
    if (
        fleet.spec.configuration.nodes is not None
        and fleet.spec.configuration.blocks == 1
        and fleet.spec.configuration.nodes.max is not None
    ):
        # With one block per instance, any instance with busy blocks cannot
        # take another job, so remaining capacity is nodes.max minus the
        # number of busy instances.
        busy_instances = [i for i in fleet.instances if i.busy_blocks > 0]
        fleet_available_capacity = fleet.spec.configuration.nodes.max - len(busy_instances)
        if fleet_available_capacity < nodes_required_num:
            return False
    elif fleet.spec.configuration.ssh_config is not None:
        # SSH fleets have a fixed set of hosts.
        # Currently assume that each idle block can run a job.
        # TODO: Take resources / eligible offers into account.
        total_idle_blocks = 0
        for instance in fleet.instances:
            total_blocks = instance.total_blocks or 1
            total_idle_blocks += total_blocks - instance.busy_blocks
        if total_idle_blocks < nodes_required_num:
            return False
    return True
651+
652+
619653
def _get_fleet_instances_with_pool_offers(
620654
fleet_model: FleetModel,
621655
run_spec: RunSpec,
@@ -713,6 +747,7 @@ async def _run_job_on_new_instance(
713747
if fleet_model is not None:
714748
fleet = fleet_model_to_fleet(fleet_model)
715749
try:
750+
_check_can_create_new_instance_in_fleet(fleet)
716751
profile, requirements = _get_run_profile_and_requirements_in_fleet(
717752
job=job,
718753
run_spec=run.run_spec,
@@ -787,8 +822,6 @@ def _get_run_profile_and_requirements_in_fleet(
787822
run_spec: RunSpec,
788823
fleet: Fleet,
789824
) -> tuple[Profile, Requirements]:
790-
if not _check_can_create_new_instance_in_fleet(fleet):
791-
raise ValueError("Cannot fit new instance into fleet")
792825
profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, run_spec.merged_profile)
793826
if profile is None:
794827
raise ValueError("Cannot combine fleet profile")
@@ -801,13 +834,23 @@ def _get_run_profile_and_requirements_in_fleet(
801834
return profile, requirements
802835

803836

804-
def _check_can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
837+
def _check_can_create_new_instance_in_fleet(fleet: Fleet):
    """Raise `ValueError` if a new instance cannot be fit into the fleet."""
    if not _can_create_new_instance_in_fleet(fleet):
        raise ValueError("Cannot fit new instance into fleet")
840+
841+
842+
def _can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
805843
if fleet.spec.configuration.ssh_config is not None:
806844
return False
807-
# TODO: Respect nodes.max
808-
# Ensure concurrent provisioning does not violate nodes.max
809-
# E.g. lock fleet and split instance model creation
810-
# and instance provisioning into separate transactions.
845+
active_instances = [i for i in fleet.instances if i.status.is_active()]
846+
# nodes.max is a soft limit that can be exceeded when provisioning concurrently.
847+
# The fleet consolidation logic will remove redundant nodes eventually.
848+
if (
849+
fleet.spec.configuration.nodes is not None
850+
and fleet.spec.configuration.nodes.max is not None
851+
and len(active_instances) >= fleet.spec.configuration.nodes.max
852+
):
853+
return False
811854
return True
812855

813856

src/tests/_internal/server/background/tasks/test_process_fleets.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,36 @@ async def test_consolidation_creates_missing_instances(self, test_db, session: A
126126
instances = (await session.execute(select(InstanceModel))).scalars().all()
127127
assert len(instances) == 2
128128
assert {i.instance_num for i in instances} == {0, 1} # uses 0 for next instance num
129+
130+
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_consolidation_terminates_redundant_instances(
    self, test_db, session: AsyncSession
):
    # NOTE(review): method of the process_fleets test class (class header not
    # visible in this chunk). A fleet configured with nodes.max=1 but holding
    # two active instances must have a redundant instance terminated by
    # consolidation, and the idle instance should be chosen over the busy one.
    project = await create_project(session)
    spec = get_fleet_spec()
    spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
    fleet = await create_fleet(
        session=session,
        project=project,
        spec=spec,
    )
    busy_instance = await create_instance(
        session=session,
        project=project,
        fleet=fleet,
        status=InstanceStatus.BUSY,
        instance_num=0,
    )
    idle_instance = await create_instance(
        session=session,
        project=project,
        fleet=fleet,
        status=InstanceStatus.IDLE,
        instance_num=1,
    )
    await process_fleets()
    await session.refresh(busy_instance)
    await session.refresh(idle_instance)
    # The busy instance is kept; only the idle one is marked for termination.
    assert busy_instance.status == InstanceStatus.BUSY
    assert idle_instance.status == InstanceStatus.TERMINATING

0 commit comments

Comments
 (0)