22import uuid
33from dataclasses import dataclass , field
44from datetime import datetime , timedelta
5- from typing import Sequence
5+ from typing import Sequence , TypedDict
66
77from sqlalchemy import or_ , select , update
88from sqlalchemy .orm import joinedload , load_only
1212from dstack ._internal .core .models .compute_groups import ComputeGroupStatus
1313from dstack ._internal .core .models .instances import InstanceStatus
1414from dstack ._internal .server .background .pipeline_tasks .base import (
15+ NOW_PLACEHOLDER ,
1516 Fetcher ,
1617 Heartbeater ,
18+ ItemUpdateMap ,
1719 Pipeline ,
1820 PipelineItem ,
19- UpdateMap ,
21+ UpdateMapDateTime ,
2022 Worker ,
21- get_processed_update_map ,
22- get_unlock_update_map ,
23+ resolve_now_placeholders ,
24+ set_processed_update_map_fields ,
25+ set_unlock_update_map_fields ,
2326)
2427from dstack ._internal .server .db import get_db , get_session_ctx
2528from dstack ._internal .server .models import ComputeGroupModel , InstanceModel , ProjectModel
@@ -199,25 +202,28 @@ async def process(self, item: PipelineItem):
199202 )
200203 return
201204
202- terminate_result = _TerminateResult ()
205+ result = _TerminateResult ()
203206 # TODO: Fetch only compute groups with all instances terminating.
204207 if all (i .status == InstanceStatus .TERMINATING for i in compute_group_model .instances ):
205- terminate_result = await _terminate_compute_group (compute_group_model )
206- if terminate_result .compute_group_update_map :
208+ result = await _terminate_compute_group (compute_group_model )
209+ set_processed_update_map_fields (result .compute_group_update_map )
210+ if result .instances_update_map :
211+ set_processed_update_map_fields (result .instances_update_map )
212+ set_unlock_update_map_fields (result .compute_group_update_map )
213+ if result .compute_group_update_map .get ("deleted" , False ):
207214 logger .info ("Terminated compute group %s" , compute_group_model .id )
208- else :
209- terminate_result .compute_group_update_map = get_processed_update_map ()
210-
211- terminate_result .compute_group_update_map |= get_unlock_update_map ()
212215
213216 async with get_session_ctx () as session :
217+ now = get_current_datetime ()
218+ resolve_now_placeholders (result .compute_group_update_map , now = now )
219+ resolve_now_placeholders (result .instances_update_map , now = now )
214220 res = await session .execute (
215221 update (ComputeGroupModel )
216222 .where (
217223 ComputeGroupModel .id == compute_group_model .id ,
218224 ComputeGroupModel .lock_token == compute_group_model .lock_token ,
219225 )
220- .values (** terminate_result .compute_group_update_map )
226+ .values (** result .compute_group_update_map )
221227 .returning (ComputeGroupModel .id )
222228 )
223229 updated_ids = list (res .scalars ().all ())
@@ -229,13 +235,13 @@ async def process(self, item: PipelineItem):
229235 item .id ,
230236 )
231237 return
232- if not terminate_result .instances_update_map :
238+ if not result .instances_update_map :
233239 return
234240 instances_ids = [i .id for i in compute_group_model .instances ]
235241 res = await session .execute (
236242 update (InstanceModel )
237243 .where (InstanceModel .id .in_ (instances_ids ))
238- .values (** terminate_result .instances_update_map )
244+ .values (** result .instances_update_map )
239245 )
240246 for instance_model in compute_group_model .instances :
241247 emit_instance_status_change_event (
@@ -246,10 +252,28 @@ async def process(self, item: PipelineItem):
246252 )
247253
248254
255+ class _ComputeGroupUpdateMap (ItemUpdateMap , total = False ):
256+ status : ComputeGroupStatus
257+ deleted : bool
258+ deleted_at : UpdateMapDateTime
259+ first_termination_retry_at : UpdateMapDateTime
260+ last_termination_retry_at : UpdateMapDateTime
261+
262+
263+ class _InstanceBulkUpdateMap (TypedDict , total = False ):
264+ last_processed_at : UpdateMapDateTime
265+ deleted : bool
266+ deleted_at : UpdateMapDateTime
267+ finished_at : UpdateMapDateTime
268+ status : InstanceStatus
269+
270+
249271@dataclass
250272class _TerminateResult :
251- compute_group_update_map : UpdateMap = field (default_factory = dict )
252- instances_update_map : UpdateMap = field (default_factory = dict )
273+ compute_group_update_map : _ComputeGroupUpdateMap = field (
274+ default_factory = _ComputeGroupUpdateMap
275+ )
276+ instances_update_map : _InstanceBulkUpdateMap = field (default_factory = _InstanceBulkUpdateMap )
253277
254278
255279async def _terminate_compute_group (compute_group_model : ComputeGroupModel ) -> _TerminateResult :
@@ -283,15 +307,15 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T
283307 compute_group ,
284308 )
285309 except Exception as e :
310+ retry_at = get_current_datetime ()
311+ first_termination_retry_at = compute_group_model .first_termination_retry_at
286312 if compute_group_model .first_termination_retry_at is None :
287- result .compute_group_update_map ["first_termination_retry_at" ] = get_current_datetime ()
288- result .compute_group_update_map ["last_termination_retry_at" ] = get_current_datetime ()
289- if _next_termination_retry_at (
290- result .compute_group_update_map ["last_termination_retry_at" ]
291- ) < _get_termination_deadline (
292- result .compute_group_update_map .get (
293- "first_termination_retry_at" , compute_group_model .first_termination_retry_at
294- )
313+ result .compute_group_update_map ["first_termination_retry_at" ] = NOW_PLACEHOLDER
314+ first_termination_retry_at = retry_at
315+ assert first_termination_retry_at is not None
316+ result .compute_group_update_map ["last_termination_retry_at" ] = NOW_PLACEHOLDER
317+ if _next_termination_retry_at (retry_at ) < _get_termination_deadline (
318+ first_termination_retry_at
295319 ):
296320 logger .warning (
297321 "Failed to terminate compute group %s. Will retry. Error: %r" ,
@@ -309,11 +333,9 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T
309333 exc_info = not isinstance (e , BackendError ),
310334 )
311335 terminated_result = _get_terminated_result ()
312- return _TerminateResult (
313- compute_group_update_map = result .compute_group_update_map
314- | terminated_result .compute_group_update_map ,
315- instances_update_map = result .instances_update_map | terminated_result .instances_update_map ,
316- )
336+ terminated_result .compute_group_update_map .update (result .compute_group_update_map )
337+ terminated_result .instances_update_map .update (result .instances_update_map )
338+ return terminated_result
317339
318340
319341def _next_termination_retry_at (last_termination_retry_at : datetime ) -> datetime :
@@ -325,19 +347,16 @@ def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime:
325347
326348
327349def _get_terminated_result () -> _TerminateResult :
328- now = get_current_datetime ()
329350 return _TerminateResult (
330351 compute_group_update_map = {
331- "last_processed_at" : now ,
332352 "deleted" : True ,
333- "deleted_at" : now ,
353+ "deleted_at" : NOW_PLACEHOLDER ,
334354 "status" : ComputeGroupStatus .TERMINATED ,
335355 },
336356 instances_update_map = {
337- "last_processed_at" : now ,
338357 "deleted" : True ,
339- "deleted_at" : now ,
340- "finished_at" : now ,
358+ "deleted_at" : NOW_PLACEHOLDER ,
359+ "finished_at" : NOW_PLACEHOLDER ,
341360 "status" : InstanceStatus .TERMINATED ,
342361 },
343362 )
0 commit comments