@@ -320,7 +320,7 @@ class _ReplicaAnalysis:
320320class _ActiveRunAnalysis :
321321 """Aggregated replica analysis used to determine the run's next status.
322322
323- Each replica contributes `RunStatus` based on its jobs.
323+ Each replica contributes `RunStatus` based on its jobs' statuses .
324324 The run's new status is the highest-priority value across all
325325 contributing replicas: FAILED > RUNNING > PROVISIONING > SUBMITTED > DONE.
326326 Replicas that need full retry do not contribute and instead cause a PENDING transition.
@@ -520,6 +520,7 @@ def _job_needs_retry_evaluation(job_model: JobModel) -> bool:
520520
521521
522522def _get_active_run_transition (run : Run , analysis : _ActiveRunAnalysis ) -> _ActiveRunTransition :
523+ # Check `analysis.contributed_statuses` in priority order.
523524 if RunStatus .FAILED in analysis .contributed_statuses :
524525 if RunTerminationReason .JOB_FAILED in analysis .termination_reasons :
525526 termination_reason = RunTerminationReason .JOB_FAILED
@@ -660,12 +661,10 @@ async def _handle_run_replicas(
660661 replicas_info : list [autoscalers .ReplicaInfo ],
661662) -> None :
662663 """
663- Does ONE of:
664- - replica retry
665- - replica scaling
666- - replica rolling deployment
667-
668- Does not do everything at once to avoid conflicts between the stages and long DB transactions.
664+ Performs one or more steps:
665+ - replicas retry
666+ - replicas scaling
667+ - replicas rolling deployment
669668 """
670669
671670 if replicas_to_retry :
@@ -683,53 +682,26 @@ async def _handle_run_replicas(
683682 # FIXME: should only include scaling events, not retries and deployments
684683 last_scaled_at = max ((r .timestamp for r in replicas_info ), default = None ),
685684 )
686- replicas : List [ReplicaGroup ] = run_spec .configuration .replica_groups
687- assert replicas , "replica groups should always return at least one group"
685+ replica_groups : List [ReplicaGroup ] = run_spec .configuration .replica_groups
686+ assert replica_groups , "replica groups should always return at least one group"
688687
689- await scale_run_replicas_for_all_groups (session , run_model , replicas )
688+ await scale_run_replicas_for_all_groups (session , run_model , replica_groups )
690689
691- # Handle per-group rolling deployment
692690 await _update_jobs_to_new_deployment_in_place (
693691 session = session ,
694692 run_model = run_model ,
695693 run_spec = run_spec ,
696- replicas = replicas ,
694+ replicas = replica_groups ,
697695 )
698- # Process per-group rolling deployment
699- for group in replicas :
696+
697+ for group in replica_groups :
700698 await _handle_rolling_deployment_for_group (
701699 session = session , run_model = run_model , group = group , run_spec = run_spec
702700 )
703- # Terminate replicas from groups that were removed from the configuration
704- existing_group_names = set ()
705- for job in run_model .jobs :
706- if job .status .is_finished ():
707- continue
708- job_spec = get_job_spec (job )
709- existing_group_names .add (job_spec .replica_group )
710- new_group_names = {group .name for group in replicas }
711- removed_group_names = existing_group_names - new_group_names
712- for removed_group_name in removed_group_names :
713- # Build replica lists for this removed group
714- active_replicas , inactive_replicas = build_replica_lists (
715- run_model = run_model ,
716- group_filter = removed_group_name ,
717- )
718701
719- total_replicas = len (active_replicas ) + len (inactive_replicas )
720- if total_replicas > 0 :
721- logger .info (
722- "%s: terminating %d replica(s) from removed group '%s'" ,
723- fmt (run_model ),
724- total_replicas ,
725- removed_group_name ,
726- )
727- # Terminate all active replicas in the removed group
728- if active_replicas :
729- scale_down_replicas (session , active_replicas , len (active_replicas ))
730- # Terminate all inactive replicas in the removed group
731- if inactive_replicas :
732- scale_down_replicas (session , inactive_replicas , len (inactive_replicas ))
702+ _terminate_removed_replica_groups (
703+ session = session , run_model = run_model , replica_groups = replica_groups
704+ )
733705 return
734706
735707 await _update_jobs_to_new_deployment_in_place (
@@ -883,21 +855,15 @@ async def _handle_rolling_deployment_for_group(
883855 """
884856 Handle rolling deployment for a single replica group.
885857 """
858+ if not has_out_of_date_replicas (run_model , group_filter = group .name ):
859+ return
860+
886861 desired_replica_counts = (
887862 json .loads (run_model .desired_replica_counts ) if run_model .desired_replica_counts else {}
888863 )
889-
890864 group_desired = desired_replica_counts .get (group .name , group .count .min or 0 )
891-
892- # Check if group has out-of-date replicas
893- if not has_out_of_date_replicas (run_model , group_filter = group .name ):
894- return # Group is up-to-date
895-
896- # Calculate max replicas (allow surge during deployment)
897865 group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE
898866
899- # Count non-terminated replicas for this group only
900-
901867 non_terminated_replica_count = len (
902868 {
903869 j .replica_num
@@ -970,3 +936,33 @@ async def _handle_rolling_deployment_for_group(
970936 active_replicas = active_replicas ,
971937 inactive_replicas = inactive_replicas ,
972938 )
939+
940+
def _terminate_removed_replica_groups(
    session: AsyncSession, run_model: RunModel, replica_groups: List[ReplicaGroup]
):
    """
    Scale down all replicas of groups that are no longer in `replica_groups`.

    A group counts as removed when at least one unfinished job of `run_model`
    references it via its job spec, but no entry of `replica_groups` has that name.
    For each removed group, both the active and the inactive replica lists returned
    by `build_replica_lists` are passed to `scale_down_replicas` with a count equal
    to the list length, i.e. the whole list is requested to be scaled down.
    """
    # Group names still referenced by jobs that have not finished yet.
    groups_in_use = {
        get_job_spec(job).replica_group
        for job in run_model.jobs
        if not job.status.is_finished()
    }
    configured_groups = {group.name for group in replica_groups}

    for group_name in groups_in_use - configured_groups:
        active, inactive = build_replica_lists(
            run_model=run_model,
            group_filter=group_name,
        )
        replica_total = len(active) + len(inactive)
        if replica_total > 0:
            logger.info(
                "%s: terminating %d replica(s) from removed group '%s'",
                fmt(run_model),
                replica_total,
                group_name,
            )
        # Active replicas are handled first, then inactive ones; empty lists are skipped.
        for replicas in (active, inactive):
            if replicas:
                scale_down_replicas(session, replicas, len(replicas))
0 commit comments