diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 587697a21c..c1e035aba1 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -104,6 +104,7 @@ class JobTerminationReason(str, Enum): # Set by the server FAILED_TO_START_DUE_TO_NO_CAPACITY = "failed_to_start_due_to_no_capacity" INTERRUPTED_BY_NO_CAPACITY = "interrupted_by_no_capacity" + INSTANCE_UNREACHABLE = "instance_unreachable" WAITING_INSTANCE_LIMIT_EXCEEDED = "waiting_instance_limit_exceeded" WAITING_RUNNER_LIMIT_EXCEEDED = "waiting_runner_limit_exceeded" TERMINATED_BY_USER = "terminated_by_user" @@ -126,6 +127,7 @@ def to_status(self) -> JobStatus: mapping = { self.FAILED_TO_START_DUE_TO_NO_CAPACITY: JobStatus.FAILED, self.INTERRUPTED_BY_NO_CAPACITY: JobStatus.FAILED, + self.INSTANCE_UNREACHABLE: JobStatus.FAILED, self.WAITING_INSTANCE_LIMIT_EXCEEDED: JobStatus.FAILED, self.WAITING_RUNNER_LIMIT_EXCEEDED: JobStatus.FAILED, self.TERMINATED_BY_USER: JobStatus.TERMINATED, diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 72de9e1c8a..ea048383c3 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -1,6 +1,6 @@ import asyncio from collections.abc import Iterable -from datetime import timedelta +from datetime import timedelta, timezone from typing import Dict, List, Optional from sqlalchemy import select @@ -71,6 +71,12 @@ logger = get_logger(__name__) +# Minimum time before terminating active job in case of connectivity issues. +# Should be sufficient to survive most problems caused by +# the server network flickering and providers' glitches. +JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) + + async def process_running_jobs(batch_size: int = 1): tasks = [] for _ in range(batch_size): @@ -202,7 +208,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): user_ssh_key = run.run_spec.ssh_key_pub.strip() public_keys = [project.ssh_public_key.strip(), user_ssh_key] if job_provisioning_data.backend == BackendType.LOCAL: - # No need to update ~/.ssh/authorized_keys when running shim localy + # No need to update ~/.ssh/authorized_keys when running shim locally user_ssh_key = "" success = await common_utils.run_async( _process_provisioning_with_shim, @@ -299,19 +305,38 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): run_model, job_model, ) - if not success: - job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY - if not success: # kill the job - logger.warning( - "%s: failed because runner is not available or return an error, age=%s", - fmt(job_model), - job_submission.age, - ) - job_model.status = JobStatus.TERMINATING - if not job_model.termination_reason: - job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY - # job will be terminated and instance will be emptied by process_terminating_jobs + if success: + job_model.disconnected_at = None + else: + if job_model.termination_reason: + logger.warning( + "%s: failed because shim/runner returned an error, age=%s", + fmt(job_model), + job_submission.age, + ) + job_model.status = JobStatus.TERMINATING + # job will be terminated and instance will be emptied by process_terminating_jobs + else: + # No job_model.termination_reason set means ssh connection failed + if job_model.disconnected_at is None: + job_model.disconnected_at = common_utils.get_current_datetime() + if _should_terminate_job_due_to_disconnect(job_model): + logger.warning( + "%s: failed because instance is unreachable, age=%s", + fmt(job_model), + job_submission.age, + ) + # TODO: Replace with JobTerminationReason.INSTANCE_UNREACHABLE in 0.20 or + # when CLI <= 0.19.8 is no longer supported + job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + job_model.status = JobStatus.TERMINATING + else: + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(job_model), + job_submission.age, + ) if ( initial_status != job_model.status @@ -692,6 +717,15 @@ def _terminate_if_inactivity_duration_exceeded( ) +def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool: + if job_model.disconnected_at is None: + return False + return ( + common_utils.get_current_datetime() + > job_model.disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT + ) + + async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None: policy = job.job_spec.utilization_policy if policy is None: diff --git a/src/dstack/_internal/server/migrations/versions/20166748b60c_add_jobmodel_disconnected_at.py b/src/dstack/_internal/server/migrations/versions/20166748b60c_add_jobmodel_disconnected_at.py new file mode 100644 index 0000000000..cc1c33e254 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/20166748b60c_add_jobmodel_disconnected_at.py @@ -0,0 +1,100 @@ +"""Add JobModel.disconnected_at + +Revision ID: 20166748b60c +Revises: 6c1a9d6530ee +Create Date: 2025-05-13 16:24:32.496578 + +""" + +import sqlalchemy as sa +from alembic import op +from alembic_postgresql_enum import TableReference + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "20166748b60c" +down_revision = "6c1a9d6530ee" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "disconnected_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + + op.sync_enum_values( + enum_schema="public", + enum_name="jobterminationreason", + new_values=[ + "FAILED_TO_START_DUE_TO_NO_CAPACITY", + "INTERRUPTED_BY_NO_CAPACITY", + "INSTANCE_UNREACHABLE", + "WAITING_INSTANCE_LIMIT_EXCEEDED", + "WAITING_RUNNER_LIMIT_EXCEEDED", + "TERMINATED_BY_USER", + "VOLUME_ERROR", + "GATEWAY_ERROR", + "SCALED_DOWN", + "DONE_BY_RUNNER", + "ABORTED_BY_USER", + "TERMINATED_BY_SERVER", + "INACTIVITY_DURATION_EXCEEDED", + "TERMINATED_DUE_TO_UTILIZATION_POLICY", + "CONTAINER_EXITED_WITH_ERROR", + "PORTS_BINDING_FAILED", + "CREATING_CONTAINER_ERROR", + "EXECUTOR_ERROR", + "MAX_DURATION_EXCEEDED", + ], + affected_columns=[ + TableReference( + table_schema="public", table_name="jobs", column_name="termination_reason" + ) + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="jobterminationreason", + new_values=[ + "FAILED_TO_START_DUE_TO_NO_CAPACITY", + "INTERRUPTED_BY_NO_CAPACITY", + "WAITING_INSTANCE_LIMIT_EXCEEDED", + "WAITING_RUNNER_LIMIT_EXCEEDED", + "TERMINATED_BY_USER", + "VOLUME_ERROR", + "GATEWAY_ERROR", + "SCALED_DOWN", + "DONE_BY_RUNNER", + "ABORTED_BY_USER", + "TERMINATED_BY_SERVER", + "INACTIVITY_DURATION_EXCEEDED", + "TERMINATED_DUE_TO_UTILIZATION_POLICY", + "CONTAINER_EXITED_WITH_ERROR", + "PORTS_BINDING_FAILED", + "CREATING_CONTAINER_ERROR", + "EXECUTOR_ERROR", + "MAX_DURATION_EXCEEDED", + ], + affected_columns=[ + TableReference( + table_schema="public", table_name="jobs", column_name="termination_reason" + ) + ], + enum_values_to_rename=[], + ) + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("disconnected_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index abcbc2529c..322f2163bb 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -382,6 +382,9 @@ class JobModel(BaseModel): Enum(JobTerminationReason) ) termination_reason_message: Mapped[Optional[str]] = mapped_column(Text) + # `disconnected_at` stores the first time of connectivity issues with the instance. + # Resets every time connectivity is restored. + disconnected_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) exit_status: Mapped[Optional[int]] = mapped_column(Integer) job_spec_data: Mapped[str] = mapped_column(Text) job_provisioning_data: Mapped[Optional[str]] = mapped_column(Text) @@ -391,7 +394,7 @@ class JobModel(BaseModel): remove_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) volumes_detached_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) # `instance_assigned` means instance assignment was done. - # if `instance_assigned` is True and `instance` is None, no instance was assiged. + # if `instance_assigned` is True and `instance` is None, no instance was assigned. instance_assigned: Mapped[bool] = mapped_column(Boolean, default=False) instance_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("instances.id", ondelete="CASCADE") diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 03e95a57f3..6d6f37e9f5 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -302,6 +302,7 @@ async def create_job( job_num: int = 0, replica_num: int = 0, instance_assigned: bool = False, + disconnected_at: Optional[datetime] = None, ) -> JobModel: run_spec = RunSpec.parse_raw(run.run_spec) job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0] @@ -323,6 +324,7 @@ async def create_job( instance=instance, instance_assigned=instance_assigned, used_instance_id=instance.id if instance is not None else None, + disconnected_at=disconnected_at, ) session.add(job) await session.commit() diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 666c385fa9..ab5123d53b 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Optional from unittest.mock import MagicMock, Mock, patch @@ -490,6 +490,17 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): assert SSHTunnelMock.call_count == 3 await session.refresh(job) assert job is not None + assert job.disconnected_at is not None + assert job.status == JobStatus.PULLING + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + freeze_time(job.disconnected_at + timedelta(minutes=5)), + ): + SSHTunnelMock.side_effect = SSHError + await process_running_jobs() + assert SSHTunnelMock.call_count == 3 + await session.refresh(job) assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY assert job.remove_at is None