|
1 | 1 | import asyncio |
2 | 2 | from collections.abc import Iterable |
3 | | -from datetime import timedelta |
| 3 | +from datetime import timedelta, timezone |
4 | 4 | from typing import Dict, List, Optional |
5 | 5 |
|
6 | 6 | from sqlalchemy import select |
|
71 | 71 | logger = get_logger(__name__) |
72 | 72 |
|
73 | 73 |
|
| 74 | +# Minimum time to wait before terminating an active job in case of connectivity issues. |
| 75 | +# Should be sufficient to survive most problems caused by |
| 76 | +# server network flickering and providers' glitches. |
| 77 | +JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) |
| 78 | + |
| 79 | + |
74 | 80 | async def process_running_jobs(batch_size: int = 1): |
75 | 81 | tasks = [] |
76 | 82 | for _ in range(batch_size): |
@@ -202,7 +208,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): |
202 | 208 | user_ssh_key = run.run_spec.ssh_key_pub.strip() |
203 | 209 | public_keys = [project.ssh_public_key.strip(), user_ssh_key] |
204 | 210 | if job_provisioning_data.backend == BackendType.LOCAL: |
205 | | - # No need to update ~/.ssh/authorized_keys when running shim localy |
| 211 | + # No need to update ~/.ssh/authorized_keys when running shim locally |
206 | 212 | user_ssh_key = "" |
207 | 213 | success = await common_utils.run_async( |
208 | 214 | _process_provisioning_with_shim, |
@@ -299,19 +305,38 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): |
299 | 305 | run_model, |
300 | 306 | job_model, |
301 | 307 | ) |
302 | | - if not success: |
303 | | - job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY |
304 | 308 |
|
305 | | - if not success: # kill the job |
306 | | - logger.warning( |
307 | | - "%s: failed because runner is not available or return an error, age=%s", |
308 | | - fmt(job_model), |
309 | | - job_submission.age, |
310 | | - ) |
311 | | - job_model.status = JobStatus.TERMINATING |
312 | | - if not job_model.termination_reason: |
313 | | - job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY |
314 | | - # job will be terminated and instance will be emptied by process_terminating_jobs |
| 309 | + if success: |
| 310 | + job_model.disconnected_at = None |
| 311 | + else: |
| 312 | + if job_model.termination_reason: |
| 313 | + logger.warning( |
| 314 | + "%s: failed because shim/runner returned an error, age=%s", |
| 315 | + fmt(job_model), |
| 316 | + job_submission.age, |
| 317 | + ) |
| 318 | + job_model.status = JobStatus.TERMINATING |
| 319 | + # job will be terminated and instance will be emptied by process_terminating_jobs |
| 320 | + else: |
| 321 | + # No job_model.termination_reason set means ssh connection failed |
| 322 | + if job_model.disconnected_at is None: |
| 323 | + job_model.disconnected_at = common_utils.get_current_datetime() |
| 324 | + if _should_terminate_job_due_to_disconnect(job_model): |
| 325 | + logger.warning( |
| 326 | + "%s: failed because instance is unreachable, age=%s", |
| 327 | + fmt(job_model), |
| 328 | + job_submission.age, |
| 329 | + ) |
| 330 | + # TODO: Replace with JobTerminationReason.INSTANCE_UNREACHABLE in 0.20 or |
| 331 | + # when CLI <= 0.19.8 is no longer supported |
| 332 | + job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY |
| 333 | + job_model.status = JobStatus.TERMINATING |
| 334 | + else: |
| 335 | + logger.warning( |
| 336 | + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", |
| 337 | + fmt(job_model), |
| 338 | + job_submission.age, |
| 339 | + ) |
315 | 340 |
|
316 | 341 | if ( |
317 | 342 | initial_status != job_model.status |
@@ -692,6 +717,15 @@ def _terminate_if_inactivity_duration_exceeded( |
692 | 717 | ) |
693 | 718 |
|
694 | 719 |
|
def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
    """Tell whether a disconnected job has used up its reconnection grace period.

    Returns ``False`` while ``job_model.disconnected_at`` is unset. Otherwise
    returns ``True`` only when the current time is strictly later than
    ``disconnected_at`` plus ``JOB_DISCONNECTED_RETRY_TIMEOUT``.
    """
    disconnected_at = job_model.disconnected_at
    if disconnected_at is not None:
        now = common_utils.get_current_datetime()
        # The stored timestamp is labelled as UTC before the comparison.
        # NOTE(review): presumably it is persisted without tzinfo — confirm against the model.
        deadline = disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT
        return now > deadline
    return False
| 727 | + |
| 728 | + |
695 | 729 | async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None: |
696 | 730 | policy = job.job_spec.utilization_policy |
697 | 731 | if policy is None: |
|
0 commit comments