Skip to content

Commit 4ff4d63

Browse files
authored
Introduce JOB_DISCONNECTED_RETRY_TIMEOUT (#2627)
* Introduce JOB_DISCONNECTED_RETRY_TIMEOUT * Fix tests
1 parent c521a57 commit 4ff4d63

File tree

6 files changed

+168
-16
lines changed

6 files changed

+168
-16
lines changed

src/dstack/_internal/core/models/runs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class JobTerminationReason(str, Enum):
104104
# Set by the server
105105
FAILED_TO_START_DUE_TO_NO_CAPACITY = "failed_to_start_due_to_no_capacity"
106106
INTERRUPTED_BY_NO_CAPACITY = "interrupted_by_no_capacity"
107+
INSTANCE_UNREACHABLE = "instance_unreachable"
107108
WAITING_INSTANCE_LIMIT_EXCEEDED = "waiting_instance_limit_exceeded"
108109
WAITING_RUNNER_LIMIT_EXCEEDED = "waiting_runner_limit_exceeded"
109110
TERMINATED_BY_USER = "terminated_by_user"
@@ -126,6 +127,7 @@ def to_status(self) -> JobStatus:
126127
mapping = {
127128
self.FAILED_TO_START_DUE_TO_NO_CAPACITY: JobStatus.FAILED,
128129
self.INTERRUPTED_BY_NO_CAPACITY: JobStatus.FAILED,
130+
self.INSTANCE_UNREACHABLE: JobStatus.FAILED,
129131
self.WAITING_INSTANCE_LIMIT_EXCEEDED: JobStatus.FAILED,
130132
self.WAITING_RUNNER_LIMIT_EXCEEDED: JobStatus.FAILED,
131133
self.TERMINATED_BY_USER: JobStatus.TERMINATED,

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from collections.abc import Iterable
3-
from datetime import timedelta
3+
from datetime import timedelta, timezone
44
from typing import Dict, List, Optional
55

66
from sqlalchemy import select
@@ -71,6 +71,12 @@
7171
logger = get_logger(__name__)
7272

7373

74+
# Minimum time before terminating active job in case of connectivity issues.
75+
# Should be sufficient to survive most problems caused by
76+
# the server network flickering and providers' glitches.
77+
JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2)
78+
79+
7480
async def process_running_jobs(batch_size: int = 1):
7581
tasks = []
7682
for _ in range(batch_size):
@@ -202,7 +208,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
202208
user_ssh_key = run.run_spec.ssh_key_pub.strip()
203209
public_keys = [project.ssh_public_key.strip(), user_ssh_key]
204210
if job_provisioning_data.backend == BackendType.LOCAL:
205-
# No need to update ~/.ssh/authorized_keys when running shim localy
211+
# No need to update ~/.ssh/authorized_keys when running shim locally
206212
user_ssh_key = ""
207213
success = await common_utils.run_async(
208214
_process_provisioning_with_shim,
@@ -299,19 +305,38 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
299305
run_model,
300306
job_model,
301307
)
302-
if not success:
303-
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
304308

305-
if not success: # kill the job
306-
logger.warning(
307-
"%s: failed because runner is not available or return an error, age=%s",
308-
fmt(job_model),
309-
job_submission.age,
310-
)
311-
job_model.status = JobStatus.TERMINATING
312-
if not job_model.termination_reason:
313-
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
314-
# job will be terminated and instance will be emptied by process_terminating_jobs
309+
if success:
310+
job_model.disconnected_at = None
311+
else:
312+
if job_model.termination_reason:
313+
logger.warning(
314+
"%s: failed because shim/runner returned an error, age=%s",
315+
fmt(job_model),
316+
job_submission.age,
317+
)
318+
job_model.status = JobStatus.TERMINATING
319+
# job will be terminated and instance will be emptied by process_terminating_jobs
320+
else:
321+
# No job_model.termination_reason set means ssh connection failed
322+
if job_model.disconnected_at is None:
323+
job_model.disconnected_at = common_utils.get_current_datetime()
324+
if _should_terminate_job_due_to_disconnect(job_model):
325+
logger.warning(
326+
"%s: failed because instance is unreachable, age=%s",
327+
fmt(job_model),
328+
job_submission.age,
329+
)
330+
# TODO: Replace with JobTerminationReason.INSTANCE_UNREACHABLE in 0.20 or
331+
# when CLI <= 0.19.8 is no longer supported
332+
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
333+
job_model.status = JobStatus.TERMINATING
334+
else:
335+
logger.warning(
336+
"%s: is unreachable, waiting for the instance to become reachable again, age=%s",
337+
fmt(job_model),
338+
job_submission.age,
339+
)
315340

316341
if (
317342
initial_status != job_model.status
@@ -692,6 +717,15 @@ def _terminate_if_inactivity_duration_exceeded(
692717
)
693718

694719

720+
def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
721+
if job_model.disconnected_at is None:
722+
return False
723+
return (
724+
common_utils.get_current_datetime()
725+
> job_model.disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT
726+
)
727+
728+
695729
async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None:
696730
policy = job.job_spec.utilization_policy
697731
if policy is None:
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Add JobModel.disconnected_at
2+
3+
Revision ID: 20166748b60c
4+
Revises: 6c1a9d6530ee
5+
Create Date: 2025-05-13 16:24:32.496578
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
from alembic_postgresql_enum import TableReference
12+
13+
import dstack._internal.server.models
14+
15+
# revision identifiers, used by Alembic.
16+
revision = "20166748b60c"
17+
down_revision = "6c1a9d6530ee"
18+
branch_labels = None
19+
depends_on = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
with op.batch_alter_table("jobs", schema=None) as batch_op:
25+
batch_op.add_column(
26+
sa.Column(
27+
"disconnected_at", dstack._internal.server.models.NaiveDateTime(), nullable=True
28+
)
29+
)
30+
31+
op.sync_enum_values(
32+
enum_schema="public",
33+
enum_name="jobterminationreason",
34+
new_values=[
35+
"FAILED_TO_START_DUE_TO_NO_CAPACITY",
36+
"INTERRUPTED_BY_NO_CAPACITY",
37+
"INSTANCE_UNREACHABLE",
38+
"WAITING_INSTANCE_LIMIT_EXCEEDED",
39+
"WAITING_RUNNER_LIMIT_EXCEEDED",
40+
"TERMINATED_BY_USER",
41+
"VOLUME_ERROR",
42+
"GATEWAY_ERROR",
43+
"SCALED_DOWN",
44+
"DONE_BY_RUNNER",
45+
"ABORTED_BY_USER",
46+
"TERMINATED_BY_SERVER",
47+
"INACTIVITY_DURATION_EXCEEDED",
48+
"TERMINATED_DUE_TO_UTILIZATION_POLICY",
49+
"CONTAINER_EXITED_WITH_ERROR",
50+
"PORTS_BINDING_FAILED",
51+
"CREATING_CONTAINER_ERROR",
52+
"EXECUTOR_ERROR",
53+
"MAX_DURATION_EXCEEDED",
54+
],
55+
affected_columns=[
56+
TableReference(
57+
table_schema="public", table_name="jobs", column_name="termination_reason"
58+
)
59+
],
60+
enum_values_to_rename=[],
61+
)
62+
# ### end Alembic commands ###
63+
64+
65+
def downgrade() -> None:
66+
# ### commands auto generated by Alembic - please adjust! ###
67+
op.sync_enum_values(
68+
enum_schema="public",
69+
enum_name="jobterminationreason",
70+
new_values=[
71+
"FAILED_TO_START_DUE_TO_NO_CAPACITY",
72+
"INTERRUPTED_BY_NO_CAPACITY",
73+
"WAITING_INSTANCE_LIMIT_EXCEEDED",
74+
"WAITING_RUNNER_LIMIT_EXCEEDED",
75+
"TERMINATED_BY_USER",
76+
"VOLUME_ERROR",
77+
"GATEWAY_ERROR",
78+
"SCALED_DOWN",
79+
"DONE_BY_RUNNER",
80+
"ABORTED_BY_USER",
81+
"TERMINATED_BY_SERVER",
82+
"INACTIVITY_DURATION_EXCEEDED",
83+
"TERMINATED_DUE_TO_UTILIZATION_POLICY",
84+
"CONTAINER_EXITED_WITH_ERROR",
85+
"PORTS_BINDING_FAILED",
86+
"CREATING_CONTAINER_ERROR",
87+
"EXECUTOR_ERROR",
88+
"MAX_DURATION_EXCEEDED",
89+
],
90+
affected_columns=[
91+
TableReference(
92+
table_schema="public", table_name="jobs", column_name="termination_reason"
93+
)
94+
],
95+
enum_values_to_rename=[],
96+
)
97+
with op.batch_alter_table("jobs", schema=None) as batch_op:
98+
batch_op.drop_column("disconnected_at")
99+
100+
# ### end Alembic commands ###

src/dstack/_internal/server/models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ class JobModel(BaseModel):
382382
Enum(JobTerminationReason)
383383
)
384384
termination_reason_message: Mapped[Optional[str]] = mapped_column(Text)
385+
# `disconnected_at` stores the first time of connectivity issues with the instance.
386+
# Resets every time connectivity is restored.
387+
disconnected_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
385388
exit_status: Mapped[Optional[int]] = mapped_column(Integer)
386389
job_spec_data: Mapped[str] = mapped_column(Text)
387390
job_provisioning_data: Mapped[Optional[str]] = mapped_column(Text)
@@ -391,7 +394,7 @@ class JobModel(BaseModel):
391394
remove_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
392395
volumes_detached_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
393396
# `instance_assigned` means instance assignment was done.
394-
# if `instance_assigned` is True and `instance` is None, no instance was assiged.
397+
# if `instance_assigned` is True and `instance` is None, no instance was assigned.
395398
instance_assigned: Mapped[bool] = mapped_column(Boolean, default=False)
396399
instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(
397400
ForeignKey("instances.id", ondelete="CASCADE")

src/dstack/_internal/server/testing/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ async def create_job(
302302
job_num: int = 0,
303303
replica_num: int = 0,
304304
instance_assigned: bool = False,
305+
disconnected_at: Optional[datetime] = None,
305306
) -> JobModel:
306307
run_spec = RunSpec.parse_raw(run.run_spec)
307308
job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0]
@@ -323,6 +324,7 @@ async def create_job(
323324
instance=instance,
324325
instance_assigned=instance_assigned,
325326
used_instance_id=instance.id if instance is not None else None,
327+
disconnected_at=disconnected_at,
326328
)
327329
session.add(job)
328330
await session.commit()

src/tests/_internal/server/background/tasks/test_process_running_jobs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from datetime import datetime, timezone
1+
from datetime import datetime, timedelta, timezone
22
from pathlib import Path
33
from typing import Optional
44
from unittest.mock import MagicMock, Mock, patch
@@ -490,6 +490,17 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession):
490490
assert SSHTunnelMock.call_count == 3
491491
await session.refresh(job)
492492
assert job is not None
493+
assert job.disconnected_at is not None
494+
assert job.status == JobStatus.PULLING
495+
with (
496+
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
497+
patch("dstack._internal.server.services.runner.ssh.time.sleep"),
498+
freeze_time(job.disconnected_at + timedelta(minutes=5)),
499+
):
500+
SSHTunnelMock.side_effect = SSHError
501+
await process_running_jobs()
502+
assert SSHTunnelMock.call_count == 3
503+
await session.refresh(job)
493504
assert job.status == JobStatus.TERMINATING
494505
assert job.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
495506
assert job.remove_at is None

0 commit comments

Comments
 (0)