Skip to content

Commit ecb8ba7

Browse files
committed
Extracted _terminate_removed_replica_groups
1 parent f1d0180 commit ecb8ba7

1 file changed

Lines changed: 48 additions & 52 deletions

File tree

  • src/dstack/_internal/server/background/scheduled_tasks

src/dstack/_internal/server/background/scheduled_tasks/runs.py

Lines changed: 48 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -320,7 +320,7 @@ class _ReplicaAnalysis:
320320
class _ActiveRunAnalysis:
321321
"""Aggregated replica analysis used to determine the run's next status.
322322
323-
Each replica contributes `RunStatus` based on its jobs.
323+
Each replica contributes `RunStatus` based on its jobs' statuses.
324324
The run's new status is the highest-priority value across all
325325
contributing replicas: FAILED > RUNNING > PROVISIONING > SUBMITTED > DONE.
326326
Replicas that need full retry do not contribute and instead cause a PENDING transition.
@@ -520,6 +520,7 @@ def _job_needs_retry_evaluation(job_model: JobModel) -> bool:
520520

521521

522522
def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _ActiveRunTransition:
523+
# Check `analysis.contributed_statuses` in priority order.
523524
if RunStatus.FAILED in analysis.contributed_statuses:
524525
if RunTerminationReason.JOB_FAILED in analysis.termination_reasons:
525526
termination_reason = RunTerminationReason.JOB_FAILED
@@ -660,12 +661,10 @@ async def _handle_run_replicas(
660661
replicas_info: list[autoscalers.ReplicaInfo],
661662
) -> None:
662663
"""
663-
Does ONE of:
664-
- replica retry
665-
- replica scaling
666-
- replica rolling deployment
667-
668-
Does not do everything at once to avoid conflicts between the stages and long DB transactions.
664+
Performs one or more steps:
665+
- replicas retry
666+
- replicas scaling
667+
- replicas rolling deployment
669668
"""
670669

671670
if replicas_to_retry:
@@ -683,53 +682,26 @@ async def _handle_run_replicas(
683682
# FIXME: should only include scaling events, not retries and deployments
684683
last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
685684
)
686-
replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups
687-
assert replicas, "replica groups should always return at least one group"
685+
replica_groups: List[ReplicaGroup] = run_spec.configuration.replica_groups
686+
assert replica_groups, "replica groups should always return at least one group"
688687

689-
await scale_run_replicas_for_all_groups(session, run_model, replicas)
688+
await scale_run_replicas_for_all_groups(session, run_model, replica_groups)
690689

691-
# Handle per-group rolling deployment
692690
await _update_jobs_to_new_deployment_in_place(
693691
session=session,
694692
run_model=run_model,
695693
run_spec=run_spec,
696-
replicas=replicas,
694+
replicas=replica_groups,
697695
)
698-
# Process per-group rolling deployment
699-
for group in replicas:
696+
697+
for group in replica_groups:
700698
await _handle_rolling_deployment_for_group(
701699
session=session, run_model=run_model, group=group, run_spec=run_spec
702700
)
703-
# Terminate replicas from groups that were removed from the configuration
704-
existing_group_names = set()
705-
for job in run_model.jobs:
706-
if job.status.is_finished():
707-
continue
708-
job_spec = get_job_spec(job)
709-
existing_group_names.add(job_spec.replica_group)
710-
new_group_names = {group.name for group in replicas}
711-
removed_group_names = existing_group_names - new_group_names
712-
for removed_group_name in removed_group_names:
713-
# Build replica lists for this removed group
714-
active_replicas, inactive_replicas = build_replica_lists(
715-
run_model=run_model,
716-
group_filter=removed_group_name,
717-
)
718701

719-
total_replicas = len(active_replicas) + len(inactive_replicas)
720-
if total_replicas > 0:
721-
logger.info(
722-
"%s: terminating %d replica(s) from removed group '%s'",
723-
fmt(run_model),
724-
total_replicas,
725-
removed_group_name,
726-
)
727-
# Terminate all active replicas in the removed group
728-
if active_replicas:
729-
scale_down_replicas(session, active_replicas, len(active_replicas))
730-
# Terminate all inactive replicas in the removed group
731-
if inactive_replicas:
732-
scale_down_replicas(session, inactive_replicas, len(inactive_replicas))
702+
_terminate_removed_replica_groups(
703+
session=session, run_model=run_model, replica_groups=replica_groups
704+
)
733705
return
734706

735707
await _update_jobs_to_new_deployment_in_place(
@@ -883,21 +855,15 @@ async def _handle_rolling_deployment_for_group(
883855
"""
884856
Handle rolling deployment for a single replica group.
885857
"""
858+
if not has_out_of_date_replicas(run_model, group_filter=group.name):
859+
return
860+
886861
desired_replica_counts = (
887862
json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {}
888863
)
889-
890864
group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
891-
892-
# Check if group has out-of-date replicas
893-
if not has_out_of_date_replicas(run_model, group_filter=group.name):
894-
return # Group is up-to-date
895-
896-
# Calculate max replicas (allow surge during deployment)
897865
group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE
898866

899-
# Count non-terminated replicas for this group only
900-
901867
non_terminated_replica_count = len(
902868
{
903869
j.replica_num
@@ -970,3 +936,33 @@ async def _handle_rolling_deployment_for_group(
970936
active_replicas=active_replicas,
971937
inactive_replicas=inactive_replicas,
972938
)
939+
940+
941+
def _terminate_removed_replica_groups(
    session: AsyncSession, run_model: RunModel, replica_groups: List[ReplicaGroup]
):
    """
    Scale down all replicas that belong to groups missing from `replica_groups`.

    Group names are gathered from the unfinished jobs of `run_model`. Every name
    that is not the name of a group in `replica_groups` is treated as removed, and
    both its active and inactive replicas are handed to `scale_down_replicas`.
    """
    # Group names still referenced by jobs that have not finished yet.
    names_in_use = {
        get_job_spec(job).replica_group for job in run_model.jobs if not job.status.is_finished()
    }
    configured_names = {group.name for group in replica_groups}

    for stale_name in names_in_use - configured_names:
        active, inactive = build_replica_lists(run_model=run_model, group_filter=stale_name)
        replica_total = len(active) + len(inactive)
        if replica_total == 0:
            # Nothing left to terminate for this group.
            continue
        logger.info(
            "%s: terminating %d replica(s) from removed group '%s'",
            fmt(run_model),
            replica_total,
            stale_name,
        )
        # Active replicas are scaled down first, then inactive ones; each list is
        # scaled down by its full length, i.e. emptied completely.
        for replicas in (active, inactive):
            if replicas:
                scale_down_replicas(session, replicas, len(replicas))

0 commit comments

Comments
 (0)