|
19 | 19 | from dstack._internal.core.models.configurations import RunConfigurationType |
20 | 20 | from dstack._internal.core.models.runs import ( |
21 | 21 | Job, |
| 22 | + JobConnectionInfo, |
22 | 23 | JobProvisioningData, |
23 | 24 | JobRuntimeData, |
24 | 25 | JobSpec, |
|
37 | 38 | ) |
38 | 39 | from dstack._internal.server.services import events |
39 | 40 | from dstack._internal.server.services import volumes as volumes_services |
| 41 | +from dstack._internal.server.services.ides import get_ide |
40 | 42 | from dstack._internal.server.services.instances import ( |
41 | 43 | get_instance_ssh_private_keys, |
42 | 44 | ) |
|
51 | 53 | from dstack._internal.server.services.probes import probe_model_to_probe |
52 | 54 | from dstack._internal.server.services.runner import client |
53 | 55 | from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel |
| 56 | +from dstack._internal.server.services.sshproxy import ( |
| 57 | + build_proxied_job_ssh_command, |
| 58 | + build_proxied_job_ssh_url_authority, |
| 59 | +) |
54 | 60 | from dstack._internal.utils import common |
55 | 61 | from dstack._internal.utils.common import run_async |
56 | 62 | from dstack._internal.utils.logging import get_logger |
| 63 | +from dstack._internal.utils.ssh import build_ssh_command, build_ssh_url_authority |
57 | 64 |
|
58 | 65 | logger = get_logger(__name__) |
59 | 66 |
|
@@ -490,6 +497,47 @@ def remove_job_spec_sensitive_info(spec: JobSpec): |
490 | 497 | spec.ssh_key = None |
491 | 498 |
|
492 | 499 |
|
| 500 | +def get_job_connection_info(job_model: JobModel, run_spec: RunSpec) -> JobConnectionInfo: |
| 501 | + # Run.attach() Python API method, used internally by CLI, uses the following as the Hostname |
| 502 | + # in the SSH config: |
| 503 | + # * for the (job=0 replica=0) job - run name, e.g., `my-task` |
| 504 | + # * for other jobs - job name, e.g., `my-task-0-1` |
| 505 | + attached_hostname = run_spec.run_name |
| 506 | + if job_model.job_num != 0 or job_model.replica_num != 0: |
| 507 | + attached_hostname = job_model.job_name |
| 508 | + assert attached_hostname is not None |
| 509 | + |
| 510 | + # ide_* fields are for dev-environment only |
| 511 | + ide_name: Optional[str] = None |
| 512 | + # IDE URLs are not set until the job status is switched to RUNNING, |
| 513 | + # as JobRuntimeData.working_dir, which is required to build URLs, is returned |
| 514 | + # by dstack-runner's `/api/run` method |
| 515 | + attached_ide_url: Optional[str] = None |
| 516 | + proxied_ide_url: Optional[str] = None |
| 517 | + if ( |
| 518 | + run_spec.configuration.type == RunConfigurationType.DEV_ENVIRONMENT.value |
| 519 | + and run_spec.configuration.ide is not None |
| 520 | + ): |
| 521 | + ide = get_ide(run_spec.configuration.ide) |
| 522 | + if ide is not None: |
| 523 | + ide_name = ide.name |
| 524 | + jrd = get_job_runtime_data(job_model) |
| 525 | + if jrd is not None and jrd.working_dir is not None: |
| 526 | + attached_url_authority = build_ssh_url_authority(hostname=attached_hostname) |
| 527 | + attached_ide_url = ide.get_url(attached_url_authority, jrd.working_dir) |
| 528 | + proxied_url_authority = build_proxied_job_ssh_url_authority(job_model) |
| 529 | + if proxied_url_authority is not None: |
| 530 | + proxied_ide_url = ide.get_url(proxied_url_authority, jrd.working_dir) |
| 531 | + |
| 532 | + return JobConnectionInfo( |
| 533 | + ide_name=ide_name, |
| 534 | + attached_ide_url=attached_ide_url, |
| 535 | + proxied_ide_url=proxied_ide_url, |
| 536 | + attached_ssh_command=build_ssh_command(hostname=attached_hostname), |
| 537 | + proxied_ssh_command=build_proxied_job_ssh_command(job_model), |
| 538 | + ) |
| 539 | + |
| 540 | + |
493 | 541 | def _get_job_mount_point_attached_volume( |
494 | 542 | volumes: List[Volume], |
495 | 543 | job_provisioning_data: JobProvisioningData, |
|
0 commit comments