Skip to content

Commit 4e7ff02

Browse files
authored
Consider multinode replica inactive only if all jobs done (#3157)
1 parent 30a0dd5 commit 4e7ff02

File tree

3 files changed

+78
-27
lines changed

3 files changed

+78
-27
lines changed

src/dstack/_internal/server/background/tasks/process_instances.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,7 @@ async def _add_remote(instance: InstanceModel) -> None:
259259
if instance.status == InstanceStatus.PENDING:
260260
instance.status = InstanceStatus.PROVISIONING
261261

262-
retry_duration_deadline = instance.created_at.replace(
263-
tzinfo=datetime.timezone.utc
264-
) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
262+
retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
265263
if retry_duration_deadline < get_current_datetime():
266264
instance.status = InstanceStatus.TERMINATED
267265
instance.termination_reason = "Provisioning timeout expired"

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
256256
for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
257257
replica_statuses: Set[RunStatus] = set()
258258
replica_needs_retry = False
259-
260259
replica_active = True
260+
jobs_done_num = 0
261261
for job_model in job_models:
262262
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
263263
if (
@@ -272,8 +272,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
272272
):
273273
# the job is done or going to be done
274274
replica_statuses.add(RunStatus.DONE)
275-
# for some reason the replica is done, it's not active
276-
replica_active = False
275+
jobs_done_num += 1
277276
elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN:
278277
# the job was scaled down
279278
replica_active = False
@@ -313,26 +312,14 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
313312
if not replica_needs_retry or retry_single_job:
314313
run_statuses.update(replica_statuses)
315314

316-
if replica_active:
317-
# submitted_at = replica created
318-
replicas_info.append(
319-
autoscalers.ReplicaInfo(
320-
active=True,
321-
timestamp=min(job.submitted_at for job in job_models).replace(
322-
tzinfo=datetime.timezone.utc
323-
),
324-
)
325-
)
326-
else:
327-
# last_processed_at = replica scaled down
328-
replicas_info.append(
329-
autoscalers.ReplicaInfo(
330-
active=False,
331-
timestamp=max(job.last_processed_at for job in job_models).replace(
332-
tzinfo=datetime.timezone.utc
333-
),
334-
)
335-
)
315+
if jobs_done_num == len(job_models):
316+
# Consider replica inactive if all its jobs are done for some reason.
317+
# If only some jobs are done, replica is considered active to avoid
318+
# provisioning new replicas for partially done multi-node tasks.
319+
replica_active = False
320+
321+
replica_info = _get_replica_info(job_models, replica_active)
322+
replicas_info.append(replica_info)
336323

337324
termination_reason: Optional[RunTerminationReason] = None
338325
if RunStatus.FAILED in run_statuses:
@@ -410,6 +397,23 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
410397
run_model.resubmission_attempt += 1
411398

412399

400+
def _get_replica_info(
401+
replica_job_models: list[JobModel],
402+
replica_active: bool,
403+
) -> autoscalers.ReplicaInfo:
404+
if replica_active:
405+
# submitted_at = replica created
406+
return autoscalers.ReplicaInfo(
407+
active=True,
408+
timestamp=min(job.submitted_at for job in replica_job_models),
409+
)
410+
# last_processed_at = replica scaled down
411+
return autoscalers.ReplicaInfo(
412+
active=False,
413+
timestamp=max(job.last_processed_at for job in replica_job_models),
414+
)
415+
416+
413417
async def _handle_run_replicas(
414418
session: AsyncSession,
415419
run_model: RunModel,

src/tests/_internal/server/background/tasks/test_process_runs.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ async def test_some_failed_to_terminating(
373373
session: AsyncSession,
374374
job_status: JobStatus,
375375
job_termination_reason: JobTerminationReason,
376-
) -> None:
376+
):
377377
run = await make_run(session, status=RunStatus.RUNNING, replicas=2)
378378
await create_job(
379379
session=session,
@@ -389,6 +389,55 @@ async def test_some_failed_to_terminating(
389389
assert run.status == RunStatus.TERMINATING
390390
assert run.termination_reason == RunTerminationReason.JOB_FAILED
391391

392+
@pytest.mark.asyncio
393+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
394+
async def test_considers_replicas_inactive_only_when_all_jobs_done(
395+
self,
396+
test_db,
397+
session: AsyncSession,
398+
):
399+
project = await create_project(session=session)
400+
user = await create_user(session=session)
401+
repo = await create_repo(session=session, project_id=project.id)
402+
run_name = "test-run"
403+
run_spec = get_run_spec(
404+
repo_id=repo.name,
405+
run_name=run_name,
406+
configuration=TaskConfiguration(
407+
commands=["echo hello"],
408+
nodes=2,
409+
),
410+
)
411+
run = await create_run(
412+
session=session,
413+
project=project,
414+
repo=repo,
415+
user=user,
416+
run_name=run_name,
417+
run_spec=run_spec,
418+
status=RunStatus.RUNNING,
419+
)
420+
await create_job(
421+
session=session,
422+
run=run,
423+
status=JobStatus.DONE,
424+
termination_reason=JobTerminationReason.DONE_BY_RUNNER,
425+
replica_num=0,
426+
job_num=0,
427+
)
428+
await create_job(
429+
session=session,
430+
run=run,
431+
status=JobStatus.RUNNING,
432+
replica_num=0,
433+
job_num=1,
434+
)
435+
await process_runs.process_runs()
436+
await session.refresh(run)
437+
assert run.status == RunStatus.RUNNING
438+
# Should not create new replica with new jobs
439+
assert len(run.jobs) == 2
440+
392441
@pytest.mark.asyncio
393442
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
394443
async def test_pending_to_submitted_adds_replicas(self, test_db, session: AsyncSession):

0 commit comments

Comments
 (0)