Skip to content

Commit 6d2687c

Browse files
committed
Maintain at most nodes.max fleet instances
1 parent 4e7ff02 commit 6d2687c

1 file changed

Lines changed: 70 additions & 14 deletions

File tree

src/dstack/_internal/server/background/tasks/process_fleets.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from datetime import timedelta
23
from typing import List
34
from uuid import UUID
@@ -37,30 +38,64 @@
3738

3839
@sentry_utils.instrument_background_task
3940
async def process_fleets():
40-
lock, lockset = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__)
41+
fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset(
42+
FleetModel.__tablename__
43+
)
44+
instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
45+
InstanceModel.__tablename__
46+
)
4147
async with get_session_ctx() as session:
42-
async with lock:
48+
async with fleet_lock, instance_lock:
4349
res = await session.execute(
4450
select(FleetModel)
4551
.where(
4652
FleetModel.deleted == False,
47-
FleetModel.id.not_in(lockset),
53+
FleetModel.id.not_in(fleet_lockset),
4854
FleetModel.last_processed_at
4955
< get_current_datetime() - MIN_PROCESSING_INTERVAL,
5056
)
51-
.options(load_only(FleetModel.id))
57+
.options(
58+
load_only(FleetModel.id),
59+
joinedload(FleetModel.instances).load_only(InstanceModel.id),
60+
)
5261
.order_by(FleetModel.last_processed_at.asc())
5362
.limit(BATCH_SIZE)
5463
.with_for_update(skip_locked=True, key_share=True)
5564
)
56-
fleet_models = list(res.scalars().all())
65+
fleet_models = list(res.scalars().unique().all())
5766
fleet_ids = [fm.id for fm in fleet_models]
67+
res = await session.execute(
68+
select(InstanceModel)
69+
.where(
70+
InstanceModel.id.not_in(instance_lockset),
71+
InstanceModel.fleet_id.in_(fleet_ids),
72+
)
73+
.options(load_only(InstanceModel.id, InstanceModel.fleet_id))
74+
.order_by(InstanceModel.id)
75+
.with_for_update(skip_locked=True, key_share=True)
76+
)
77+
instance_models = list(res.scalars().all())
78+
fleet_id_to_locked_instances = defaultdict(list)
79+
for instance_model in instance_models:
80+
fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
81+
# Process only fleets with all instances locked.
82+
# Other fleets won't be processed but will still be locked to avoid starting a new transaction.
83+
# This should not be problematic as long as process_fleets is quick.
84+
fleet_models_to_process = []
85+
for fleet_model in fleet_models:
86+
if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
87+
fleet_models_to_process.append(fleet_model)
88+
else:
89+
logger.debug(
90+
"Fleet %s processing will be skipped: some instance were not locked",
91+
fleet_model.name,
92+
)
5893
for fleet_id in fleet_ids:
59-
lockset.add(fleet_id)
94+
fleet_lockset.add(fleet_id)
6095
try:
61-
await _process_fleets(session=session, fleet_models=fleet_models)
96+
await _process_fleets(session=session, fleet_models=fleet_models_to_process)
6297
finally:
63-
lockset.difference_update(fleet_ids)
98+
fleet_lockset.difference_update(fleet_ids)
6499

65100

66101
async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
@@ -99,8 +134,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
99134
return
100135
if not _is_fleet_ready_for_consolidation(fleet_model):
101136
return
102-
added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
103-
if added_instances:
137+
changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
138+
if changed_instances:
104139
fleet_model.consolidation_attempt += 1
105140
else:
106141
# The fleet is already consolidated or consolidation is in progress.
@@ -138,14 +173,14 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
138173
return _CONSOLIDATION_RETRY_DELAYS[-1]
139174

140175

141-
def _maintain_fleet_nodes_min(
176+
def _maintain_fleet_nodes_in_min_max_range(
142177
session: AsyncSession,
143178
fleet_model: FleetModel,
144179
fleet_spec: FleetSpec,
145180
) -> bool:
146181
"""
147-
Ensures the fleet has at least `nodes.min` instances.
148-
Returns `True` if retried or added new instances and `False` otherwise.
182+
Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
183+
Returns `True` if retried, added new instances, or terminated redundant instances and `False` otherwise.
149184
"""
150185
assert fleet_spec.configuration.nodes is not None
151186
for instance in fleet_model.instances:
@@ -159,7 +194,28 @@ def _maintain_fleet_nodes_min(
159194
active_instances = [i for i in fleet_model.instances if not i.deleted]
160195
active_instances_num = len(active_instances)
161196
if active_instances_num >= fleet_spec.configuration.nodes.min:
162-
return False
197+
if (
198+
fleet_spec.configuration.nodes.max is None
199+
or active_instances_num <= fleet_spec.configuration.nodes.max
200+
):
201+
return False
202+
# Fleet has more instances than allowed by nodes.max.
203+
# This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
204+
# or if nodes.max is updated.
205+
nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
206+
for instance in fleet_model.instances:
207+
if nodes_redundant == 0:
208+
break
209+
if instance.status in [InstanceStatus.IDLE]:
210+
instance.status = InstanceStatus.TERMINATING
211+
instance.termination_reason = "Fleet has too many instances"
212+
nodes_redundant -= 1
213+
logger.info(
214+
"Terminating instance %s: %s",
215+
instance.name,
216+
instance.termination_reason,
217+
)
218+
return True
163219
nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
164220
for i in range(nodes_missing):
165221
instance_model = create_fleet_instance_model(

0 commit comments

Comments
 (0)