Skip to content

Commit cdc67fe

Browse files
committed
Use keyword args for running job RPCs
1 parent 17cf5e8 commit cdc67fe

1 file changed

Lines changed: 40 additions & 44 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -385,15 +385,15 @@ async def _process_running_job_provisioning_state(
385385
server_ssh_private_keys,
386386
job_provisioning_data,
387387
None,
388-
session,
389-
context.run,
390-
context.job_model,
391-
job_provisioning_data,
392-
startup_context.volumes,
393-
context.job.job_spec.registry_auth,
394-
public_keys,
395-
ssh_user,
396-
user_ssh_key,
388+
session=session,
389+
run=context.run,
390+
job_model=context.job_model,
391+
jpd=job_provisioning_data,
392+
volumes=startup_context.volumes,
393+
registry_auth=context.job.job_spec.registry_auth,
394+
public_keys=public_keys,
395+
ssh_user=ssh_user,
396+
ssh_key=user_ssh_key,
397397
)
398398
else:
399399
logger.debug(
@@ -419,15 +419,15 @@ async def _process_running_job_provisioning_state(
419419
server_ssh_private_keys,
420420
job_provisioning_data,
421421
None,
422-
session,
423-
context.run,
424-
context.job_model,
425-
context.job,
426-
startup_context.cluster_info,
427-
code,
428-
file_archives,
429-
startup_context.secrets,
430-
startup_context.repo_creds,
422+
session=session,
423+
run=context.run,
424+
job_model=context.job_model,
425+
job=context.job,
426+
cluster_info=startup_context.cluster_info,
427+
code=code,
428+
file_archives=file_archives,
429+
secrets=startup_context.secrets,
430+
repo_credentials=startup_context.repo_creds,
431431
success_if_not_available=False,
432432
)
433433

@@ -480,17 +480,17 @@ async def _process_running_job_pulling_state(
480480
server_ssh_private_keys,
481481
job_provisioning_data,
482482
None,
483-
session,
484-
context.run,
485-
context.job_model,
486-
context.job,
487-
startup_context.cluster_info,
488-
code,
489-
file_archives,
490-
startup_context.secrets,
491-
startup_context.repo_creds,
492-
server_ssh_private_keys,
493-
job_provisioning_data,
483+
session=session,
484+
run=context.run,
485+
job_model=context.job_model,
486+
job=context.job,
487+
cluster_info=startup_context.cluster_info,
488+
code=code,
489+
file_archives=file_archives,
490+
secrets=startup_context.secrets,
491+
repo_credentials=startup_context.repo_creds,
492+
server_ssh_private_keys=server_ssh_private_keys,
493+
jpd=job_provisioning_data,
494494
)
495495

496496
if success:
@@ -543,9 +543,9 @@ async def _process_running_job_running_state(
543543
server_ssh_private_keys,
544544
job_provisioning_data,
545545
context.job_submission.job_runtime_data,
546-
session,
547-
context.run_model,
548-
context.job_model,
546+
session=session,
547+
run_model=context.run_model,
548+
job_model=context.job_model,
549549
)
550550

551551
if success:
@@ -685,7 +685,7 @@ def _process_provisioning_with_shim(
685685
session: AsyncSession,
686686
run: Run,
687687
job_model: JobModel,
688-
job_provisioning_data: JobProvisioningData,
688+
jpd: JobProvisioningData,
689689
volumes: List[Volume],
690690
registry_auth: Optional[RegistryAuth],
691691
public_keys: List[str],
@@ -730,13 +730,9 @@ def _process_provisioning_with_shim(
730730
for volume, volume_mount in zip(volumes, volume_mounts):
731731
volume_mount.name = volume.name
732732

733-
instance_mounts += _get_instance_specific_mounts(
734-
job_provisioning_data.backend, job_provisioning_data.instance_type.name
735-
)
733+
instance_mounts += _get_instance_specific_mounts(jpd.backend, jpd.instance_type.name)
736734

737-
gpu_devices = _get_instance_specific_gpu_devices(
738-
job_provisioning_data.backend, job_provisioning_data.instance_type.name
739-
)
735+
gpu_devices = _get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name)
740736

741737
container_user = "root"
742738

@@ -753,7 +749,7 @@ def _process_provisioning_with_shim(
753749
cpu = None
754750
memory = None
755751
network_mode = NetworkMode.HOST
756-
image_name = _patch_base_image_for_aws_efa(job_spec, job_provisioning_data)
752+
image_name = _patch_base_image_for_aws_efa(job_spec, jpd)
757753
if shim_client.is_api_v2_supported():
758754
shim_client.submit_task(
759755
task_id=job_model.id,
@@ -775,7 +771,7 @@ def _process_provisioning_with_shim(
775771
host_ssh_user=ssh_user,
776772
host_ssh_keys=[ssh_key] if ssh_key else [],
777773
container_ssh_keys=public_keys,
778-
instance_id=job_provisioning_data.instance_id,
774+
instance_id=jpd.instance_id,
779775
)
780776
else:
781777
submitted = shim_client.submit(
@@ -792,7 +788,7 @@ def _process_provisioning_with_shim(
792788
mounts=volume_mounts,
793789
volumes=volumes,
794790
instance_mounts=instance_mounts,
795-
instance_id=job_provisioning_data.instance_id,
791+
instance_id=jpd.instance_id,
796792
)
797793
if not submitted:
798794
# This can happen when we lost connection to the runner (e.g., network issues), marked
@@ -826,7 +822,7 @@ def _process_pulling_with_shim(
826822
secrets: Dict[str, str],
827823
repo_credentials: Optional[RemoteRepoCreds],
828824
server_ssh_private_keys: tuple[str, Optional[str]],
829-
job_provisioning_data: JobProvisioningData,
825+
jpd: JobProvisioningData,
830826
) -> bool:
831827
"""
832828
Possible next states:
@@ -892,7 +888,7 @@ def _process_pulling_with_shim(
892888

893889
return _submit_job_to_runner(
894890
server_ssh_private_keys,
895-
job_provisioning_data,
891+
jpd,
896892
job_runtime_data,
897893
session=session,
898894
run=run,

0 commit comments

Comments
 (0)