Skip to content

Commit 3704dd2

Browse files
committed
Extract get_group_rollout_state
1 parent fe16dc9 commit 3704dd2

2 files changed

Lines changed: 85 additions & 85 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/runs.py

Lines changed: 17 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import datetime
3-
import json
43
from dataclasses import dataclass, field
54
from typing import Dict, List, Optional, Set, Tuple
65

@@ -49,9 +48,9 @@
4948
)
5049
from dstack._internal.server.services.runs.replicas import (
5150
build_replica_lists,
51+
get_group_desired_replica_count,
52+
get_group_rollout_state,
5253
has_out_of_date_replicas,
53-
is_replica_registered,
54-
job_belongs_to_group,
5554
retry_run_replica_jobs,
5655
scale_down_replicas,
5756
scale_run_replicas,
@@ -851,86 +850,41 @@ async def _handle_rolling_deployment_for_group(
851850
"""
852851
Handle rolling deployment for a single replica group.
853852
"""
854-
if not has_out_of_date_replicas(run_model, group_filter=group.name):
853+
group_desired = get_group_desired_replica_count(run_model, group)
854+
state = get_group_rollout_state(run_model, group)
855+
if not state.has_out_of_date_replicas:
855856
return
856857

857-
desired_replica_counts = (
858-
json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {}
859-
)
860-
group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
861858
group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE
862859

863-
non_terminated_replica_count = len(
864-
{
865-
j.replica_num
866-
for j in run_model.jobs
867-
if not j.status.is_finished()
868-
and group.name is not None
869-
and job_belongs_to_group(job=j, group_name=group.name)
870-
}
871-
)
872-
873860
# Start new up-to-date replicas if needed
874-
if non_terminated_replica_count < group_max_replica_count:
875-
active_replicas, inactive_replicas = build_replica_lists(
876-
run_model=run_model,
877-
group_filter=group.name,
878-
)
879-
861+
if state.non_terminated_replica_count < group_max_replica_count:
880862
await scale_run_replicas_for_group(
881863
session=session,
882864
run_model=run_model,
883865
group=group,
884-
replicas_diff=group_max_replica_count - non_terminated_replica_count,
866+
replicas_diff=group_max_replica_count - state.non_terminated_replica_count,
885867
run_spec=run_spec,
886-
active_replicas=active_replicas,
887-
inactive_replicas=inactive_replicas,
868+
active_replicas=state.active_replicas,
869+
inactive_replicas=state.inactive_replicas,
888870
)
871+
state = get_group_rollout_state(run_model, group)
889872

890-
# Stop out-of-date replicas that are not registered
891-
replicas_to_stop_count = 0
892-
for _, jobs in group_jobs_by_replica_latest(run_model.jobs):
893-
assert group.name is not None, "Group name is always set"
894-
if not job_belongs_to_group(jobs[0], group.name):
895-
continue
896-
# Check if replica is out-of-date and not registered
897-
if (
898-
any(j.deployment_num < run_model.deployment_num for j in jobs)
899-
and any(
900-
j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses()
901-
for j in jobs
902-
)
903-
and not is_replica_registered(jobs)
904-
):
905-
replicas_to_stop_count += 1
906-
907-
# Stop excessive registered out-of-date replicas
908-
non_terminating_registered_replicas_count = 0
909-
for _, jobs in group_jobs_by_replica_latest(run_model.jobs):
910-
assert group.name is not None, "Group name is always set"
911-
if not job_belongs_to_group(jobs[0], group.name):
912-
continue
913-
914-
if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs):
915-
non_terminating_registered_replicas_count += 1
916-
917-
replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired)
873+
replicas_to_stop_count = state.unregistered_out_of_date_replica_count
874+
replicas_to_stop_count += max(
875+
0,
876+
state.registered_non_terminating_replica_count - group_desired,
877+
)
918878

919879
if replicas_to_stop_count > 0:
920-
# Build lists again to get current state
921-
active_replicas, inactive_replicas = build_replica_lists(
922-
run_model=run_model,
923-
group_filter=group.name,
924-
)
925-
926880
await scale_run_replicas_for_group(
927881
session=session,
928882
run_model=run_model,
929883
group=group,
930884
replicas_diff=-replicas_to_stop_count,
931885
run_spec=run_spec,
932-
active_replicas=active_replicas,
933-
inactive_replicas=inactive_replicas,
886+
active_replicas=state.active_replicas,
887+
inactive_replicas=state.inactive_replicas,
934888
)
935889

936890

src/dstack/_internal/server/services/runs/replicas.py

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from dataclasses import dataclass
23
from typing import List, Optional, Tuple
34

45
from sqlalchemy.ext.asyncio import AsyncSession
@@ -220,6 +221,65 @@ async def _scale_up_replicas(
220221
run_model.jobs.append(job_model)
221222

222223

224+
@dataclass
class GroupRolloutState:
    """Snapshot of replica lists and rollout counters for one replica group.

    Instances are produced by ``get_group_rollout_state``.
    """

    # The two lists returned by ``build_replica_lists`` for the group.
    # NOTE(review): the meaning of the int/bool/int tuple elements is not
    # visible here -- confirm against ``build_replica_lists``.
    active_replicas: List[Tuple[int, bool, int, List[JobModel]]]
    inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]]
    # Result of ``has_out_of_date_replicas`` filtered to this group.
    has_out_of_date_replicas: bool
    # Number of distinct replicas of the group with at least one unfinished job.
    non_terminated_replica_count: int
    # Replicas that have a job from an older deployment, have at least one job
    # that is neither terminating nor finished, and are not registered.
    unregistered_out_of_date_replica_count: int
    # Replicas that are registered and have no job in the TERMINATING status.
    registered_non_terminating_replica_count: int
232+
233+
234+
def get_group_desired_replica_count(run_model: RunModel, group: ReplicaGroup) -> int:
    """Return the desired replica count for ``group``.

    The count is looked up by group name in the JSON mapping stored in
    ``run_model.desired_replica_counts``. When the mapping is empty/unset or
    has no entry for the group, ``group.count.min`` is used, with ``0``
    substituted if that minimum is falsy.
    """
    assert group.name is not None, "Group name is always set"
    serialized_counts = run_model.desired_replica_counts
    if serialized_counts:
        counts_by_group = json.loads(serialized_counts)
    else:
        counts_by_group = {}
    fallback_count = group.count.min or 0
    return counts_by_group.get(group.name, fallback_count)
240+
241+
242+
def get_group_rollout_state(run_model: RunModel, group: ReplicaGroup) -> GroupRolloutState:
    """Gather the replica lists and rollout counters for a single replica group.

    Builds the group's active/inactive replica lists, then walks the per-replica
    job groups yielded by ``group_jobs_by_replica_latest`` and, for replicas
    belonging to ``group``, counts:

    * replicas with at least one unfinished job,
    * out-of-date replicas that still have a job that is neither terminating
      nor finished and that are not registered,
    * registered replicas with no terminating job.
    """
    group_name = group.name
    assert group_name is not None, "Group name is always set"

    active, inactive = build_replica_lists(
        run_model=run_model,
        group_filter=group_name,
    )

    live_replica_nums = set()
    stale_unregistered_total = 0
    registered_steady_total = 0

    for _, replica_jobs in group_jobs_by_replica_latest(run_model.jobs):
        first_job = replica_jobs[0]
        # Group membership is decided by the replica's first job.
        if not job_belongs_to_group(first_job, group_name):
            continue

        # A replica counts as non-terminated while any of its jobs is unfinished.
        if not all(job.status.is_finished() for job in replica_jobs):
            live_replica_nums.add(first_job.replica_num)

        # Out-of-date: some job was created for an earlier deployment of the run.
        if any(job.deployment_num < run_model.deployment_num for job in replica_jobs):
            still_running = any(
                job.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses()
                for job in replica_jobs
            )
            if still_running and not is_replica_registered(replica_jobs):
                stale_unregistered_total += 1

        # Registered replicas none of whose jobs is being terminated.
        if is_replica_registered(replica_jobs) and not any(
            job.status == JobStatus.TERMINATING for job in replica_jobs
        ):
            registered_steady_total += 1

    group_is_out_of_date = has_out_of_date_replicas(run_model, group_filter=group_name)
    return GroupRolloutState(
        active_replicas=active,
        inactive_replicas=inactive,
        has_out_of_date_replicas=group_is_out_of_date,
        non_terminated_replica_count=len(live_replica_nums),
        unregistered_out_of_date_replica_count=stale_unregistered_total,
        registered_non_terminating_replica_count=registered_steady_total,
    )
281+
282+
223283
async def scale_run_replicas_for_all_groups(
224284
session: AsyncSession,
225285
run_model: RunModel,
@@ -229,41 +289,27 @@ async def scale_run_replicas_for_all_groups(
229289
if not replicas:
230290
return
231291

232-
desired_replica_counts = (
233-
json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {}
234-
)
292+
run_spec = get_run_spec(run_model)
235293

236294
for group in replicas:
237-
assert group.name is not None, "Group name is always set"
238-
group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
239-
240-
# Build replica lists filtered by this group
241-
active_replicas, inactive_replicas = build_replica_lists(
242-
run_model=run_model, group_filter=group.name
243-
)
244-
245-
# Count active replicas
246-
active_group_count = len(active_replicas)
247-
group_diff = group_desired - active_group_count
295+
group_desired = get_group_desired_replica_count(run_model, group)
296+
state = get_group_rollout_state(run_model, group)
297+
group_diff = group_desired - len(state.active_replicas)
248298

249299
if group_diff != 0:
250-
# Check if rolling deployment is in progress for THIS GROUP
251-
252-
group_has_out_of_date = has_out_of_date_replicas(run_model, group_filter=group.name)
253-
254300
# During rolling deployment, don't scale down old replicas
255301
# Let rolling deployment handle stopping old replicas
256-
if group_diff < 0 and group_has_out_of_date:
302+
if group_diff < 0 and state.has_out_of_date_replicas:
257303
# Skip scaling down during rolling deployment
258304
continue
259305
await scale_run_replicas_for_group(
260306
session=session,
261307
run_model=run_model,
262308
group=group,
263309
replicas_diff=group_diff,
264-
run_spec=get_run_spec(run_model),
265-
active_replicas=active_replicas,
266-
inactive_replicas=inactive_replicas,
310+
run_spec=run_spec,
311+
active_replicas=state.active_replicas,
312+
inactive_replicas=state.inactive_replicas,
267313
)
268314

269315

0 commit comments

Comments
 (0)