Skip to content

Commit f975c2d

Browse files
committed
Defer running job payload loading
1 parent cdc67fe commit f975c2d

2 files changed

Lines changed: 200 additions & 103 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 105 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import asyncio
2+
import enum
23
import re
34
import uuid
45
from collections.abc import Iterable
56
from dataclasses import dataclass
67
from datetime import timedelta
7-
from typing import Dict, List, Optional
8+
from typing import Dict, List, Literal, Optional, Union
89

910
from sqlalchemy import and_, func, select
1011
from sqlalchemy.ext.asyncio import AsyncSession
@@ -367,6 +368,7 @@ async def _process_running_job_provisioning_state(
367368
return
368369

369370
# fails are acceptable until timeout is exceeded
371+
success = False
370372
if job_provisioning_data.dockerized:
371373
logger.debug(
372374
"%s: process provisioning job with shim, age=%s",
@@ -401,35 +403,40 @@ async def _process_running_job_provisioning_state(
401403
fmt(context.job_model),
402404
context.job_submission.age,
403405
)
404-
# FIXME: downloading file archives and code here is a waste of time if
405-
# the runner is not ready yet
406-
file_archives = await _get_job_file_archives(
407-
session=session,
408-
archive_mappings=context.job.job_spec.file_archives,
409-
user=context.run_model.user,
410-
)
411-
code = await _get_job_code(
412-
session=session,
413-
project=context.project,
414-
repo=context.repo_model,
415-
code_hash=_get_repo_code_hash(context.run, context.job),
416-
)
417-
success = await common_utils.run_async(
418-
_submit_job_to_runner,
406+
runner_availability = await common_utils.run_async(
407+
_get_runner_availability,
419408
server_ssh_private_keys,
420409
job_provisioning_data,
421410
None,
422-
session=session,
423-
run=context.run,
424-
job_model=context.job_model,
425-
job=context.job,
426-
cluster_info=startup_context.cluster_info,
427-
code=code,
428-
file_archives=file_archives,
429-
secrets=startup_context.secrets,
430-
repo_credentials=startup_context.repo_creds,
431-
success_if_not_available=False,
432411
)
412+
if runner_availability == _RunnerAvailability.AVAILABLE:
413+
file_archives = await _get_job_file_archives(
414+
session=session,
415+
archive_mappings=context.job.job_spec.file_archives,
416+
user=context.run_model.user,
417+
)
418+
code = await _get_job_code(
419+
session=session,
420+
project=context.project,
421+
repo=context.repo_model,
422+
code_hash=_get_repo_code_hash(context.run, context.job),
423+
)
424+
success = await common_utils.run_async(
425+
_submit_job_to_runner,
426+
server_ssh_private_keys,
427+
job_provisioning_data,
428+
None,
429+
session=session,
430+
run=context.run,
431+
job_model=context.job_model,
432+
job=context.job,
433+
cluster_info=startup_context.cluster_info,
434+
code=code,
435+
file_archives=file_archives,
436+
secrets=startup_context.secrets,
437+
repo_credentials=startup_context.repo_creds,
438+
success_if_not_available=False,
439+
)
433440

434441
if success:
435442
return
@@ -462,41 +469,60 @@ async def _process_running_job_pulling_state(
462469
fmt(context.job_model),
463470
context.job_submission.age,
464471
)
465-
# FIXME: downloading file archives and code here is a waste of time if
466-
# the runner is not ready yet
467-
file_archives = await _get_job_file_archives(
468-
session=session,
469-
archive_mappings=context.job.job_spec.file_archives,
470-
user=context.run_model.user,
471-
)
472-
code = await _get_job_code(
473-
session=session,
474-
project=context.project,
475-
repo=context.repo_model,
476-
code_hash=_get_repo_code_hash(context.run, context.job),
477-
)
478-
success = await common_utils.run_async(
479-
_process_pulling_with_shim,
472+
shim_state = await common_utils.run_async(
473+
_get_shim_pulling_state,
480474
server_ssh_private_keys,
481475
job_provisioning_data,
482476
None,
483-
session=session,
484-
run=context.run,
485477
job_model=context.job_model,
486-
job=context.job,
487-
cluster_info=startup_context.cluster_info,
488-
code=code,
489-
file_archives=file_archives,
490-
secrets=startup_context.secrets,
491-
repo_credentials=startup_context.repo_creds,
492-
server_ssh_private_keys=server_ssh_private_keys,
493-
jpd=job_provisioning_data,
494478
)
495-
496-
if success:
479+
if shim_state == _ShimPullingState.WAITING:
497480
_reset_disconnected_at(session, context.job_model)
498481
return
499482

483+
if shim_state == _ShimPullingState.READY:
484+
runner_availability = await common_utils.run_async(
485+
_get_runner_availability,
486+
server_ssh_private_keys,
487+
job_provisioning_data,
488+
None,
489+
)
490+
if runner_availability == _RunnerAvailability.UNAVAILABLE:
491+
_reset_disconnected_at(session, context.job_model)
492+
return
493+
494+
if runner_availability == _RunnerAvailability.AVAILABLE:
495+
file_archives = await _get_job_file_archives(
496+
session=session,
497+
archive_mappings=context.job.job_spec.file_archives,
498+
user=context.run_model.user,
499+
)
500+
code = await _get_job_code(
501+
session=session,
502+
project=context.project,
503+
repo=context.repo_model,
504+
code_hash=_get_repo_code_hash(context.run, context.job),
505+
)
506+
success = await common_utils.run_async(
507+
_submit_job_to_runner,
508+
server_ssh_private_keys,
509+
job_provisioning_data,
510+
None,
511+
session=session,
512+
run=context.run,
513+
job_model=context.job_model,
514+
job=context.job,
515+
cluster_info=startup_context.cluster_info,
516+
code=code,
517+
file_archives=file_archives,
518+
secrets=startup_context.secrets,
519+
repo_credentials=startup_context.repo_creds,
520+
success_if_not_available=True,
521+
)
522+
if success:
523+
_reset_disconnected_at(session, context.job_model)
524+
return
525+
500526
if context.job_model.termination_reason:
501527
logger.warning(
502528
"%s: failed due to %s, age=%s",
@@ -562,6 +588,7 @@ async def _process_running_job_running_state(
562588
switch_job_status(session, context.job_model, JobStatus.TERMINATING)
563589
# job will be terminated and instance will be emptied by process_terminating_jobs
564590
return
591+
565592
# No job_model.termination_reason set means ssh connection failed
566593
_set_disconnected_at_now(session, context.job_model)
567594
if not _should_terminate_job_due_to_disconnect(context.job_model):
@@ -571,6 +598,7 @@ async def _process_running_job_running_state(
571598
context.job_submission.age,
572599
)
573600
return
601+
574602
if job_provisioning_data.instance_type.resources.spot:
575603
context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
576604
else:
@@ -809,31 +837,30 @@ def _process_provisioning_with_shim(
809837
return True
810838

811839

840+
class _RunnerAvailability(enum.Enum):
841+
AVAILABLE = "available"
842+
UNAVAILABLE = "unavailable"
843+
844+
845+
class _ShimPullingState(enum.Enum):
846+
WAITING = "waiting"
847+
READY = "ready"
848+
849+
850+
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
851+
def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability:
852+
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
853+
if runner_client.healthcheck() is None:
854+
return _RunnerAvailability.UNAVAILABLE
855+
return _RunnerAvailability.AVAILABLE
856+
857+
812858
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
813-
def _process_pulling_with_shim(
859+
def _get_shim_pulling_state(
814860
ports: Dict[int, int],
815-
session: AsyncSession,
816-
run: Run,
817861
job_model: JobModel,
818-
job: Job,
819-
cluster_info: ClusterInfo,
820-
code: bytes,
821-
file_archives: Iterable[tuple[uuid.UUID, bytes]],
822-
secrets: Dict[str, str],
823-
repo_credentials: Optional[RemoteRepoCreds],
824-
server_ssh_private_keys: tuple[str, Optional[str]],
825-
jpd: JobProvisioningData,
826-
) -> bool:
827-
"""
828-
Possible next states:
829-
- JobStatus.RUNNING if runner is available
830-
- JobStatus.TERMINATING if shim is not available
831-
832-
Returns:
833-
is successful
834-
"""
862+
) -> Union[Literal[False], _ShimPullingState]:
835863
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
836-
job_runtime_data = None
837864
if shim_client.is_api_v2_supported(): # raises error if shim is down, causes retry
838865
task = shim_client.get_task(job_model.id)
839866

@@ -851,18 +878,17 @@ def _process_pulling_with_shim(
851878
return False
852879

853880
if task.status != TaskStatus.RUNNING:
854-
return True
881+
return _ShimPullingState.WAITING
855882

856883
job_runtime_data = get_job_runtime_data(job_model)
857884
# should check for None, as there may be older jobs submitted before
858885
# JobRuntimeData was introduced
859886
if job_runtime_data is not None:
860887
# port mapping is not yet available, waiting
861888
if task.ports is None:
862-
return True
889+
return _ShimPullingState.WAITING
863890
job_runtime_data.ports = {pm.container: pm.host for pm in task.ports}
864891
job_model.job_runtime_data = job_runtime_data.json()
865-
866892
else:
867893
shim_status = shim_client.pull() # raises error if shim is down, causes retry
868894

@@ -884,23 +910,9 @@ def _process_pulling_with_shim(
884910
return False
885911

886912
if shim_status.state in ("pulling", "creating"):
887-
return True
913+
return _ShimPullingState.WAITING
888914

889-
return _submit_job_to_runner(
890-
server_ssh_private_keys,
891-
jpd,
892-
job_runtime_data,
893-
session=session,
894-
run=run,
895-
job_model=job_model,
896-
job=job,
897-
cluster_info=cluster_info,
898-
code=code,
899-
file_archives=file_archives,
900-
secrets=secrets,
901-
repo_credentials=repo_credentials,
902-
success_if_not_available=True,
903-
)
915+
return _ShimPullingState.READY
904916

905917

906918
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT])

0 commit comments

Comments
 (0)