11import asyncio
2+ import enum
23import re
34import uuid
45from collections .abc import Iterable
56from dataclasses import dataclass
67from datetime import timedelta
7- from typing import Dict , List , Optional
8+ from typing import Dict , List , Literal , Optional , Union
89
910from sqlalchemy import and_ , func , select
1011from sqlalchemy .ext .asyncio import AsyncSession
@@ -367,6 +368,7 @@ async def _process_running_job_provisioning_state(
367368 return
368369
369370 # fails are acceptable until timeout is exceeded
371+ success = False
370372 if job_provisioning_data .dockerized :
371373 logger .debug (
372374 "%s: process provisioning job with shim, age=%s" ,
@@ -401,35 +403,40 @@ async def _process_running_job_provisioning_state(
401403 fmt (context .job_model ),
402404 context .job_submission .age ,
403405 )
404- # FIXME: downloading file archives and code here is a waste of time if
405- # the runner is not ready yet
406- file_archives = await _get_job_file_archives (
407- session = session ,
408- archive_mappings = context .job .job_spec .file_archives ,
409- user = context .run_model .user ,
410- )
411- code = await _get_job_code (
412- session = session ,
413- project = context .project ,
414- repo = context .repo_model ,
415- code_hash = _get_repo_code_hash (context .run , context .job ),
416- )
417- success = await common_utils .run_async (
418- _submit_job_to_runner ,
406+ runner_availability = await common_utils .run_async (
407+ _get_runner_availability ,
419408 server_ssh_private_keys ,
420409 job_provisioning_data ,
421410 None ,
422- session = session ,
423- run = context .run ,
424- job_model = context .job_model ,
425- job = context .job ,
426- cluster_info = startup_context .cluster_info ,
427- code = code ,
428- file_archives = file_archives ,
429- secrets = startup_context .secrets ,
430- repo_credentials = startup_context .repo_creds ,
431- success_if_not_available = False ,
432411 )
412+ if runner_availability == _RunnerAvailability .AVAILABLE :
413+ file_archives = await _get_job_file_archives (
414+ session = session ,
415+ archive_mappings = context .job .job_spec .file_archives ,
416+ user = context .run_model .user ,
417+ )
418+ code = await _get_job_code (
419+ session = session ,
420+ project = context .project ,
421+ repo = context .repo_model ,
422+ code_hash = _get_repo_code_hash (context .run , context .job ),
423+ )
424+ success = await common_utils .run_async (
425+ _submit_job_to_runner ,
426+ server_ssh_private_keys ,
427+ job_provisioning_data ,
428+ None ,
429+ session = session ,
430+ run = context .run ,
431+ job_model = context .job_model ,
432+ job = context .job ,
433+ cluster_info = startup_context .cluster_info ,
434+ code = code ,
435+ file_archives = file_archives ,
436+ secrets = startup_context .secrets ,
437+ repo_credentials = startup_context .repo_creds ,
438+ success_if_not_available = False ,
439+ )
433440
434441 if success :
435442 return
@@ -462,41 +469,60 @@ async def _process_running_job_pulling_state(
462469 fmt (context .job_model ),
463470 context .job_submission .age ,
464471 )
465- # FIXME: downloading file archives and code here is a waste of time if
466- # the runner is not ready yet
467- file_archives = await _get_job_file_archives (
468- session = session ,
469- archive_mappings = context .job .job_spec .file_archives ,
470- user = context .run_model .user ,
471- )
472- code = await _get_job_code (
473- session = session ,
474- project = context .project ,
475- repo = context .repo_model ,
476- code_hash = _get_repo_code_hash (context .run , context .job ),
477- )
478- success = await common_utils .run_async (
479- _process_pulling_with_shim ,
472+ shim_state = await common_utils .run_async (
473+ _get_shim_pulling_state ,
480474 server_ssh_private_keys ,
481475 job_provisioning_data ,
482476 None ,
483- session = session ,
484- run = context .run ,
485477 job_model = context .job_model ,
486- job = context .job ,
487- cluster_info = startup_context .cluster_info ,
488- code = code ,
489- file_archives = file_archives ,
490- secrets = startup_context .secrets ,
491- repo_credentials = startup_context .repo_creds ,
492- server_ssh_private_keys = server_ssh_private_keys ,
493- jpd = job_provisioning_data ,
494478 )
495-
496- if success :
479+ if shim_state == _ShimPullingState .WAITING :
497480 _reset_disconnected_at (session , context .job_model )
498481 return
499482
483+ if shim_state == _ShimPullingState .READY :
484+ runner_availability = await common_utils .run_async (
485+ _get_runner_availability ,
486+ server_ssh_private_keys ,
487+ job_provisioning_data ,
488+ None ,
489+ )
490+ if runner_availability == _RunnerAvailability .UNAVAILABLE :
491+ _reset_disconnected_at (session , context .job_model )
492+ return
493+
494+ if runner_availability == _RunnerAvailability .AVAILABLE :
495+ file_archives = await _get_job_file_archives (
496+ session = session ,
497+ archive_mappings = context .job .job_spec .file_archives ,
498+ user = context .run_model .user ,
499+ )
500+ code = await _get_job_code (
501+ session = session ,
502+ project = context .project ,
503+ repo = context .repo_model ,
504+ code_hash = _get_repo_code_hash (context .run , context .job ),
505+ )
506+ success = await common_utils .run_async (
507+ _submit_job_to_runner ,
508+ server_ssh_private_keys ,
509+ job_provisioning_data ,
510+ None ,
511+ session = session ,
512+ run = context .run ,
513+ job_model = context .job_model ,
514+ job = context .job ,
515+ cluster_info = startup_context .cluster_info ,
516+ code = code ,
517+ file_archives = file_archives ,
518+ secrets = startup_context .secrets ,
519+ repo_credentials = startup_context .repo_creds ,
520+ success_if_not_available = True ,
521+ )
522+ if success :
523+ _reset_disconnected_at (session , context .job_model )
524+ return
525+
500526 if context .job_model .termination_reason :
501527 logger .warning (
502528 "%s: failed due to %s, age=%s" ,
@@ -562,6 +588,7 @@ async def _process_running_job_running_state(
562588 switch_job_status (session , context .job_model , JobStatus .TERMINATING )
563589 # job will be terminated and instance will be emptied by process_terminating_jobs
564590 return
591+
565592 # No job_model.termination_reason set means ssh connection failed
566593 _set_disconnected_at_now (session , context .job_model )
567594 if not _should_terminate_job_due_to_disconnect (context .job_model ):
@@ -571,6 +598,7 @@ async def _process_running_job_running_state(
571598 context .job_submission .age ,
572599 )
573600 return
601+
574602 if job_provisioning_data .instance_type .resources .spot :
575603 context .job_model .termination_reason = JobTerminationReason .INTERRUPTED_BY_NO_CAPACITY
576604 else :
@@ -809,31 +837,30 @@ def _process_provisioning_with_shim(
809837 return True
810838
811839
class _RunnerAvailability(enum.Enum):
    """Outcome of probing the runner, as reported by ``_get_runner_availability``.

    Callers in this module fetch file archives and code and submit the job
    only when the probe yields ``AVAILABLE``.
    """

    # The runner healthcheck returned a non-None result.
    AVAILABLE = "available"
    # The runner healthcheck returned None.
    UNAVAILABLE = "unavailable"
843+
844+
class _ShimPullingState(enum.Enum):
    """Progress reported by ``_get_shim_pulling_state`` while a job is pulling.

    That function may also return ``False`` instead of a member of this enum.
    """

    # Returned when the task is not RUNNING yet, when its port mapping is not
    # yet available, or when the shim reports "pulling"/"creating".
    WAITING = "waiting"
    # Returned when none of the waiting conditions applies; the caller then
    # goes on to probe the runner.
    READY = "ready"
848+
849+
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability:
    """Probe the runner's HTTP endpoint and report whether it answered.

    Args:
        ports: Mapping that is indexed by ``DSTACK_RUNNER_HTTP_PORT`` to get
            the port the runner client connects to. Presumably supplied by
            the ``runner_ssh_tunnel`` decorator (callers in this module pass
            SSH keys and provisioning data, not ``ports``) -- TODO confirm
            against the decorator.

    Returns:
        ``_RunnerAvailability.UNAVAILABLE`` when ``healthcheck()`` returns
        ``None``, ``_RunnerAvailability.AVAILABLE`` otherwise.
    """
    runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
    if runner_client.healthcheck() is None:
        return _RunnerAvailability.UNAVAILABLE
    return _RunnerAvailability.AVAILABLE
856+
857+
812858@runner_ssh_tunnel (ports = [DSTACK_SHIM_HTTP_PORT ])
813- def _process_pulling_with_shim (
859+ def _get_shim_pulling_state (
814860 ports : Dict [int , int ],
815- session : AsyncSession ,
816- run : Run ,
817861 job_model : JobModel ,
818- job : Job ,
819- cluster_info : ClusterInfo ,
820- code : bytes ,
821- file_archives : Iterable [tuple [uuid .UUID , bytes ]],
822- secrets : Dict [str , str ],
823- repo_credentials : Optional [RemoteRepoCreds ],
824- server_ssh_private_keys : tuple [str , Optional [str ]],
825- jpd : JobProvisioningData ,
826- ) -> bool :
827- """
828- Possible next states:
829- - JobStatus.RUNNING if runner is available
830- - JobStatus.TERMINATING if shim is not available
831-
832- Returns:
833- is successful
834- """
862+ ) -> Union [Literal [False ], _ShimPullingState ]:
835863 shim_client = client .ShimClient (port = ports [DSTACK_SHIM_HTTP_PORT ])
836- job_runtime_data = None
837864 if shim_client .is_api_v2_supported (): # raises error if shim is down, causes retry
838865 task = shim_client .get_task (job_model .id )
839866
@@ -851,18 +878,17 @@ def _process_pulling_with_shim(
851878 return False
852879
853880 if task .status != TaskStatus .RUNNING :
854- return True
881+ return _ShimPullingState . WAITING
855882
856883 job_runtime_data = get_job_runtime_data (job_model )
857884 # should check for None, as there may be older jobs submitted before
858885 # JobRuntimeData was introduced
859886 if job_runtime_data is not None :
860887 # port mapping is not yet available, waiting
861888 if task .ports is None :
862- return True
889+ return _ShimPullingState . WAITING
863890 job_runtime_data .ports = {pm .container : pm .host for pm in task .ports }
864891 job_model .job_runtime_data = job_runtime_data .json ()
865-
866892 else :
867893 shim_status = shim_client .pull () # raises error if shim is down, causes retry
868894
@@ -884,23 +910,9 @@ def _process_pulling_with_shim(
884910 return False
885911
886912 if shim_status .state in ("pulling" , "creating" ):
887- return True
913+ return _ShimPullingState . WAITING
888914
889- return _submit_job_to_runner (
890- server_ssh_private_keys ,
891- jpd ,
892- job_runtime_data ,
893- session = session ,
894- run = run ,
895- job_model = job_model ,
896- job = job ,
897- cluster_info = cluster_info ,
898- code = code ,
899- file_archives = file_archives ,
900- secrets = secrets ,
901- repo_credentials = repo_credentials ,
902- success_if_not_available = True ,
903- )
915+ return _ShimPullingState .READY
904916
905917
906918@runner_ssh_tunnel (ports = [DSTACK_RUNNER_HTTP_PORT ])
0 commit comments