11import json
2+ from dataclasses import dataclass
23from typing import List , Optional , Tuple
34
45from sqlalchemy .ext .asyncio import AsyncSession
@@ -220,6 +221,65 @@ async def _scale_up_replicas(
220221 run_model .jobs .append (job_model )
221222
222223
224+ @dataclass
225+ class GroupRolloutState :
226+ active_replicas : List [Tuple [int , bool , int , List [JobModel ]]]
227+ inactive_replicas : List [Tuple [int , bool , int , List [JobModel ]]]
228+ has_out_of_date_replicas : bool
229+ non_terminated_replica_count : int
230+ unregistered_out_of_date_replica_count : int
231+ registered_non_terminating_replica_count : int
232+
233+
234+ def get_group_desired_replica_count (run_model : RunModel , group : ReplicaGroup ) -> int :
235+ assert group .name is not None , "Group name is always set"
236+ desired_replica_counts = (
237+ json .loads (run_model .desired_replica_counts ) if run_model .desired_replica_counts else {}
238+ )
239+ return desired_replica_counts .get (group .name , group .count .min or 0 )
240+
241+
242+ def get_group_rollout_state (run_model : RunModel , group : ReplicaGroup ) -> GroupRolloutState :
243+ assert group .name is not None , "Group name is always set"
244+ active_replicas , inactive_replicas = build_replica_lists (
245+ run_model = run_model ,
246+ group_filter = group .name ,
247+ )
248+
249+ non_terminated_replica_nums = set ()
250+ unregistered_out_of_date_replica_count = 0
251+ registered_non_terminating_replica_count = 0
252+
253+ for _ , jobs in group_jobs_by_replica_latest (run_model .jobs ):
254+ if not job_belongs_to_group (jobs [0 ], group .name ):
255+ continue
256+
257+ if any (not j .status .is_finished () for j in jobs ):
258+ non_terminated_replica_nums .add (jobs [0 ].replica_num )
259+
260+ if (
261+ any (j .deployment_num < run_model .deployment_num for j in jobs )
262+ and any (
263+ j .status not in [JobStatus .TERMINATING ] + JobStatus .finished_statuses ()
264+ for j in jobs
265+ )
266+ and not is_replica_registered (jobs )
267+ ):
268+ unregistered_out_of_date_replica_count += 1
269+
270+ if is_replica_registered (jobs ) and all (j .status != JobStatus .TERMINATING for j in jobs ):
271+ registered_non_terminating_replica_count += 1
272+
273+ return GroupRolloutState (
274+ active_replicas = active_replicas ,
275+ inactive_replicas = inactive_replicas ,
276+ has_out_of_date_replicas = has_out_of_date_replicas (run_model , group_filter = group .name ),
277+ non_terminated_replica_count = len (non_terminated_replica_nums ),
278+ unregistered_out_of_date_replica_count = unregistered_out_of_date_replica_count ,
279+ registered_non_terminating_replica_count = registered_non_terminating_replica_count ,
280+ )
281+
282+
223283async def scale_run_replicas_for_all_groups (
224284 session : AsyncSession ,
225285 run_model : RunModel ,
@@ -229,41 +289,27 @@ async def scale_run_replicas_for_all_groups(
229289 if not replicas :
230290 return
231291
232- desired_replica_counts = (
233- json .loads (run_model .desired_replica_counts ) if run_model .desired_replica_counts else {}
234- )
292+ run_spec = get_run_spec (run_model )
235293
236294 for group in replicas :
237- assert group .name is not None , "Group name is always set"
238- group_desired = desired_replica_counts .get (group .name , group .count .min or 0 )
239-
240- # Build replica lists filtered by this group
241- active_replicas , inactive_replicas = build_replica_lists (
242- run_model = run_model , group_filter = group .name
243- )
244-
245- # Count active replicas
246- active_group_count = len (active_replicas )
247- group_diff = group_desired - active_group_count
295+ group_desired = get_group_desired_replica_count (run_model , group )
296+ state = get_group_rollout_state (run_model , group )
297+ group_diff = group_desired - len (state .active_replicas )
248298
249299 if group_diff != 0 :
250- # Check if rolling deployment is in progress for THIS GROUP
251-
252- group_has_out_of_date = has_out_of_date_replicas (run_model , group_filter = group .name )
253-
254300 # During rolling deployment, don't scale down old replicas
255301 # Let rolling deployment handle stopping old replicas
256- if group_diff < 0 and group_has_out_of_date :
302+ if group_diff < 0 and state . has_out_of_date_replicas :
257303 # Skip scaling down during rolling deployment
258304 continue
259305 await scale_run_replicas_for_group (
260306 session = session ,
261307 run_model = run_model ,
262308 group = group ,
263309 replicas_diff = group_diff ,
264- run_spec = get_run_spec ( run_model ) ,
265- active_replicas = active_replicas ,
266- inactive_replicas = inactive_replicas ,
310+ run_spec = run_spec ,
311+ active_replicas = state . active_replicas ,
312+ inactive_replicas = state . inactive_replicas ,
267313 )
268314
269315
0 commit comments