@@ -385,15 +385,15 @@ async def _process_running_job_provisioning_state(
385385 server_ssh_private_keys ,
386386 job_provisioning_data ,
387387 None ,
388- session ,
389- context .run ,
390- context .job_model ,
391- job_provisioning_data ,
392- startup_context .volumes ,
393- context .job .job_spec .registry_auth ,
394- public_keys ,
395- ssh_user ,
396- user_ssh_key ,
388+ session = session ,
389+ run = context .run ,
390+ job_model = context .job_model ,
391+ jpd = job_provisioning_data ,
392+ volumes = startup_context .volumes ,
393+ registry_auth = context .job .job_spec .registry_auth ,
394+ public_keys = public_keys ,
395+ ssh_user = ssh_user ,
396+ ssh_key = user_ssh_key ,
397397 )
398398 else :
399399 logger .debug (
@@ -419,15 +419,15 @@ async def _process_running_job_provisioning_state(
419419 server_ssh_private_keys ,
420420 job_provisioning_data ,
421421 None ,
422- session ,
423- context .run ,
424- context .job_model ,
425- context .job ,
426- startup_context .cluster_info ,
427- code ,
428- file_archives ,
429- startup_context .secrets ,
430- startup_context .repo_creds ,
422+ session = session ,
423+ run = context .run ,
424+ job_model = context .job_model ,
425+ job = context .job ,
426+ cluster_info = startup_context .cluster_info ,
427+ code = code ,
428+ file_archives = file_archives ,
429+ secrets = startup_context .secrets ,
430+ repo_credentials = startup_context .repo_creds ,
431431 success_if_not_available = False ,
432432 )
433433
@@ -480,17 +480,17 @@ async def _process_running_job_pulling_state(
480480 server_ssh_private_keys ,
481481 job_provisioning_data ,
482482 None ,
483- session ,
484- context .run ,
485- context .job_model ,
486- context .job ,
487- startup_context .cluster_info ,
488- code ,
489- file_archives ,
490- startup_context .secrets ,
491- startup_context .repo_creds ,
492- server_ssh_private_keys ,
493- job_provisioning_data ,
483+ session = session ,
484+ run = context .run ,
485+ job_model = context .job_model ,
486+ job = context .job ,
487+ cluster_info = startup_context .cluster_info ,
488+ code = code ,
489+ file_archives = file_archives ,
490+ secrets = startup_context .secrets ,
491+ repo_credentials = startup_context .repo_creds ,
492+ server_ssh_private_keys = server_ssh_private_keys ,
493+ jpd = job_provisioning_data ,
494494 )
495495
496496 if success :
@@ -543,9 +543,9 @@ async def _process_running_job_running_state(
543543 server_ssh_private_keys ,
544544 job_provisioning_data ,
545545 context .job_submission .job_runtime_data ,
546- session ,
547- context .run_model ,
548- context .job_model ,
546+ session = session ,
547+ run_model = context .run_model ,
548+ job_model = context .job_model ,
549549 )
550550
551551 if success :
@@ -685,7 +685,7 @@ def _process_provisioning_with_shim(
685685 session : AsyncSession ,
686686 run : Run ,
687687 job_model : JobModel ,
688- job_provisioning_data : JobProvisioningData ,
688+ jpd : JobProvisioningData ,
689689 volumes : List [Volume ],
690690 registry_auth : Optional [RegistryAuth ],
691691 public_keys : List [str ],
@@ -730,13 +730,9 @@ def _process_provisioning_with_shim(
730730 for volume , volume_mount in zip (volumes , volume_mounts ):
731731 volume_mount .name = volume .name
732732
733- instance_mounts += _get_instance_specific_mounts (
734- job_provisioning_data .backend , job_provisioning_data .instance_type .name
735- )
733+ instance_mounts += _get_instance_specific_mounts (jpd .backend , jpd .instance_type .name )
736734
737- gpu_devices = _get_instance_specific_gpu_devices (
738- job_provisioning_data .backend , job_provisioning_data .instance_type .name
739- )
735+ gpu_devices = _get_instance_specific_gpu_devices (jpd .backend , jpd .instance_type .name )
740736
741737 container_user = "root"
742738
@@ -753,7 +749,7 @@ def _process_provisioning_with_shim(
753749 cpu = None
754750 memory = None
755751 network_mode = NetworkMode .HOST
756- image_name = _patch_base_image_for_aws_efa (job_spec , job_provisioning_data )
752+ image_name = _patch_base_image_for_aws_efa (job_spec , jpd )
757753 if shim_client .is_api_v2_supported ():
758754 shim_client .submit_task (
759755 task_id = job_model .id ,
@@ -775,7 +771,7 @@ def _process_provisioning_with_shim(
775771 host_ssh_user = ssh_user ,
776772 host_ssh_keys = [ssh_key ] if ssh_key else [],
777773 container_ssh_keys = public_keys ,
778- instance_id = job_provisioning_data .instance_id ,
774+ instance_id = jpd .instance_id ,
779775 )
780776 else :
781777 submitted = shim_client .submit (
@@ -792,7 +788,7 @@ def _process_provisioning_with_shim(
792788 mounts = volume_mounts ,
793789 volumes = volumes ,
794790 instance_mounts = instance_mounts ,
795- instance_id = job_provisioning_data .instance_id ,
791+ instance_id = jpd .instance_id ,
796792 )
797793 if not submitted :
798794 # This can happen when we lost connection to the runner (e.g., network issues), marked
@@ -826,7 +822,7 @@ def _process_pulling_with_shim(
826822 secrets : Dict [str , str ],
827823 repo_credentials : Optional [RemoteRepoCreds ],
828824 server_ssh_private_keys : tuple [str , Optional [str ]],
829- job_provisioning_data : JobProvisioningData ,
825+ jpd : JobProvisioningData ,
830826) -> bool :
831827 """
832828 Possible next states:
@@ -892,7 +888,7 @@ def _process_pulling_with_shim(
892888
893889 return _submit_job_to_runner (
894890 server_ssh_private_keys ,
895- job_provisioning_data ,
891+ jpd ,
896892 job_runtime_data ,
897893 session = session ,
898894 run = run ,
0 commit comments