Skip to content

Commit 165de64

Browse files
authored
Implement fleet pipeline (#3623)
* Load only fleets active runs in apply_plan * Replace fleets many-to-many joinedloads with selectinloads * Optimize selects * WIP: FleetPipeline * Fixes * Add TestFleetWorker * Fix consolidation_attempt reset * Use typed dicts for update maps * Unify processing result classes * Centralize last_processed_at setting * Refactor _build_instance_update_rows() * Make result naming consistent * Refactor _create_missing_fleet_instances() * Enable FleetPipeline * Respect fleet locks in the API endpoints * Add FleetModel pipeline migration * Adjust CONSOLIDATION_RETRY_DELAYS * Fix fleet autodelete comments * Add scheduled tasks deprecated note * Cleanup comment * Add ix_fleets_pipeline_fetch_q index
1 parent 84e2c70 commit 165de64

File tree

25 files changed

+1555
-141
lines changed

25 files changed

+1555
-141
lines changed

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dstack._internal.server.background.pipeline_tasks.base import Pipeline
44
from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline
5+
from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline
56
from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline
67
from dstack._internal.server.background.pipeline_tasks.placement_groups import (
78
PlacementGroupPipeline,
@@ -16,6 +17,7 @@ class PipelineManager:
1617
def __init__(self) -> None:
1718
self._pipelines: list[Pipeline] = [
1819
ComputeGroupPipeline(),
20+
FleetPipeline(),
1921
GatewayPipeline(),
2022
PlacementGroupPipeline(),
2123
VolumePipeline(),

src/dstack/_internal/server/background/pipeline_tasks/base.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,20 @@
33
import random
44
import uuid
55
from abc import ABC, abstractmethod
6+
from collections.abc import Iterable, Sequence
67
from dataclasses import dataclass
78
from datetime import datetime, timedelta
8-
from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar
9+
from typing import (
10+
Any,
11+
ClassVar,
12+
Final,
13+
Generic,
14+
Optional,
15+
Protocol,
16+
TypedDict,
17+
TypeVar,
18+
Union,
19+
)
920

1021
from sqlalchemy import and_, or_, update
1122
from sqlalchemy.orm import Mapped
@@ -337,16 +348,71 @@ async def process(self, item: ItemT):
337348
pass
338349

339350

340-
UpdateMap = dict[str, Any]
351+
class _NowPlaceholder:
352+
pass
353+
354+
355+
NOW_PLACEHOLDER: Final = _NowPlaceholder()
356+
"""
357+
Use `NOW_PLACEHOLDER` together with `resolve_now_placeholders()` in pipeline update maps
358+
instead of `get_current_time()` to have the same current time for all updates in the transaction.
359+
"""
360+
361+
362+
UpdateMapDateTime = Union[datetime, _NowPlaceholder]
363+
364+
365+
class _UnlockUpdateMap(TypedDict, total=False):
366+
lock_expires_at: Optional[datetime]
367+
lock_token: Optional[uuid.UUID]
368+
lock_owner: Optional[str]
369+
370+
371+
class _ProcessedUpdateMap(TypedDict, total=False):
372+
last_processed_at: UpdateMapDateTime
373+
341374

375+
class ItemUpdateMap(_UnlockUpdateMap, _ProcessedUpdateMap, total=False):
376+
lock_expires_at: Optional[datetime]
377+
lock_token: Optional[uuid.UUID]
378+
lock_owner: Optional[str]
379+
last_processed_at: UpdateMapDateTime
342380

343-
def get_unlock_update_map() -> UpdateMap:
344-
return {
345-
"lock_expires_at": None,
346-
"lock_token": None,
347-
"lock_owner": None,
348-
}
349381

382+
def set_unlock_update_map_fields(update_map: _UnlockUpdateMap):
383+
update_map["lock_expires_at"] = None
384+
update_map["lock_token"] = None
385+
update_map["lock_owner"] = None
350386

351-
def get_processed_update_map() -> UpdateMap:
352-
return {"last_processed_at": get_current_datetime()}
387+
388+
def set_processed_update_map_fields(
389+
update_map: _ProcessedUpdateMap,
390+
now: UpdateMapDateTime = NOW_PLACEHOLDER,
391+
):
392+
update_map["last_processed_at"] = now
393+
394+
395+
class _ResolveNowUpdateMap(Protocol):
396+
def items(self) -> Iterable[tuple[str, object]]: ...
397+
398+
399+
_ResolveNowInput = Union[_ResolveNowUpdateMap, Sequence[_ResolveNowUpdateMap]]
400+
401+
402+
def resolve_now_placeholders(update_values: _ResolveNowInput, now: datetime):
403+
"""
404+
Replaces `NOW_PLACEHOLDER` with `now` in an update map or a sequence of update rows.
405+
"""
406+
if isinstance(update_values, Sequence):
407+
for update_row in update_values:
408+
resolve_now_placeholders(update_row, now)
409+
return
410+
# Runtime dict narrowing is required here: pyright doesn't model TypedDicts as
411+
# supporting generic dynamic-key mutation via protocol methods.
412+
if not isinstance(update_values, dict):
413+
raise TypeError(
414+
"resolve_now_placeholders() expects update maps or sequences of update maps"
415+
)
416+
for key, value in update_values.items():
417+
if value is NOW_PLACEHOLDER:
418+
update_values[key] = now

src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33
from dataclasses import dataclass, field
44
from datetime import datetime, timedelta
5-
from typing import Sequence
5+
from typing import Sequence, TypedDict
66

77
from sqlalchemy import or_, select, update
88
from sqlalchemy.orm import joinedload, load_only
@@ -12,14 +12,17 @@
1212
from dstack._internal.core.models.compute_groups import ComputeGroupStatus
1313
from dstack._internal.core.models.instances import InstanceStatus
1414
from dstack._internal.server.background.pipeline_tasks.base import (
15+
NOW_PLACEHOLDER,
1516
Fetcher,
1617
Heartbeater,
18+
ItemUpdateMap,
1719
Pipeline,
1820
PipelineItem,
19-
UpdateMap,
21+
UpdateMapDateTime,
2022
Worker,
21-
get_processed_update_map,
22-
get_unlock_update_map,
23+
resolve_now_placeholders,
24+
set_processed_update_map_fields,
25+
set_unlock_update_map_fields,
2326
)
2427
from dstack._internal.server.db import get_db, get_session_ctx
2528
from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel
@@ -199,25 +202,28 @@ async def process(self, item: PipelineItem):
199202
)
200203
return
201204

202-
terminate_result = _TerminateResult()
205+
result = _TerminateResult()
203206
# TODO: Fetch only compute groups with all instances terminating.
204207
if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances):
205-
terminate_result = await _terminate_compute_group(compute_group_model)
206-
if terminate_result.compute_group_update_map:
208+
result = await _terminate_compute_group(compute_group_model)
209+
set_processed_update_map_fields(result.compute_group_update_map)
210+
if result.instances_update_map:
211+
set_processed_update_map_fields(result.instances_update_map)
212+
set_unlock_update_map_fields(result.compute_group_update_map)
213+
if result.compute_group_update_map.get("deleted", False):
207214
logger.info("Terminated compute group %s", compute_group_model.id)
208-
else:
209-
terminate_result.compute_group_update_map = get_processed_update_map()
210-
211-
terminate_result.compute_group_update_map |= get_unlock_update_map()
212215

213216
async with get_session_ctx() as session:
217+
now = get_current_datetime()
218+
resolve_now_placeholders(result.compute_group_update_map, now=now)
219+
resolve_now_placeholders(result.instances_update_map, now=now)
214220
res = await session.execute(
215221
update(ComputeGroupModel)
216222
.where(
217223
ComputeGroupModel.id == compute_group_model.id,
218224
ComputeGroupModel.lock_token == compute_group_model.lock_token,
219225
)
220-
.values(**terminate_result.compute_group_update_map)
226+
.values(**result.compute_group_update_map)
221227
.returning(ComputeGroupModel.id)
222228
)
223229
updated_ids = list(res.scalars().all())
@@ -229,13 +235,13 @@ async def process(self, item: PipelineItem):
229235
item.id,
230236
)
231237
return
232-
if not terminate_result.instances_update_map:
238+
if not result.instances_update_map:
233239
return
234240
instances_ids = [i.id for i in compute_group_model.instances]
235241
res = await session.execute(
236242
update(InstanceModel)
237243
.where(InstanceModel.id.in_(instances_ids))
238-
.values(**terminate_result.instances_update_map)
244+
.values(**result.instances_update_map)
239245
)
240246
for instance_model in compute_group_model.instances:
241247
emit_instance_status_change_event(
@@ -246,10 +252,28 @@ async def process(self, item: PipelineItem):
246252
)
247253

248254

255+
class _ComputeGroupUpdateMap(ItemUpdateMap, total=False):
256+
status: ComputeGroupStatus
257+
deleted: bool
258+
deleted_at: UpdateMapDateTime
259+
first_termination_retry_at: UpdateMapDateTime
260+
last_termination_retry_at: UpdateMapDateTime
261+
262+
263+
class _InstanceBulkUpdateMap(TypedDict, total=False):
264+
last_processed_at: UpdateMapDateTime
265+
deleted: bool
266+
deleted_at: UpdateMapDateTime
267+
finished_at: UpdateMapDateTime
268+
status: InstanceStatus
269+
270+
249271
@dataclass
250272
class _TerminateResult:
251-
compute_group_update_map: UpdateMap = field(default_factory=dict)
252-
instances_update_map: UpdateMap = field(default_factory=dict)
273+
compute_group_update_map: _ComputeGroupUpdateMap = field(
274+
default_factory=_ComputeGroupUpdateMap
275+
)
276+
instances_update_map: _InstanceBulkUpdateMap = field(default_factory=_InstanceBulkUpdateMap)
253277

254278

255279
async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _TerminateResult:
@@ -283,15 +307,15 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T
283307
compute_group,
284308
)
285309
except Exception as e:
310+
retry_at = get_current_datetime()
311+
first_termination_retry_at = compute_group_model.first_termination_retry_at
286312
if compute_group_model.first_termination_retry_at is None:
287-
result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime()
288-
result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime()
289-
if _next_termination_retry_at(
290-
result.compute_group_update_map["last_termination_retry_at"]
291-
) < _get_termination_deadline(
292-
result.compute_group_update_map.get(
293-
"first_termination_retry_at", compute_group_model.first_termination_retry_at
294-
)
313+
result.compute_group_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER
314+
first_termination_retry_at = retry_at
315+
assert first_termination_retry_at is not None
316+
result.compute_group_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER
317+
if _next_termination_retry_at(retry_at) < _get_termination_deadline(
318+
first_termination_retry_at
295319
):
296320
logger.warning(
297321
"Failed to terminate compute group %s. Will retry. Error: %r",
@@ -309,11 +333,9 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T
309333
exc_info=not isinstance(e, BackendError),
310334
)
311335
terminated_result = _get_terminated_result()
312-
return _TerminateResult(
313-
compute_group_update_map=result.compute_group_update_map
314-
| terminated_result.compute_group_update_map,
315-
instances_update_map=result.instances_update_map | terminated_result.instances_update_map,
316-
)
336+
terminated_result.compute_group_update_map.update(result.compute_group_update_map)
337+
terminated_result.instances_update_map.update(result.instances_update_map)
338+
return terminated_result
317339

318340

319341
def _next_termination_retry_at(last_termination_retry_at: datetime) -> datetime:
@@ -325,19 +347,16 @@ def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime:
325347

326348

327349
def _get_terminated_result() -> _TerminateResult:
328-
now = get_current_datetime()
329350
return _TerminateResult(
330351
compute_group_update_map={
331-
"last_processed_at": now,
332352
"deleted": True,
333-
"deleted_at": now,
353+
"deleted_at": NOW_PLACEHOLDER,
334354
"status": ComputeGroupStatus.TERMINATED,
335355
},
336356
instances_update_map={
337-
"last_processed_at": now,
338357
"deleted": True,
339-
"deleted_at": now,
340-
"finished_at": now,
358+
"deleted_at": NOW_PLACEHOLDER,
359+
"finished_at": NOW_PLACEHOLDER,
341360
"status": InstanceStatus.TERMINATED,
342361
},
343362
)

0 commit comments

Comments (0)