33import pydantic
44from sqlalchemy import select
55from sqlalchemy .ext .asyncio import AsyncSession
6- from sqlalchemy .orm import joinedload
6+ from sqlalchemy .orm import contains_eager , joinedload
77
88import dstack ._internal .server .services .jobs as jobs_services
99from dstack ._internal .core .consts import DSTACK_RUNNER_SSH_PORT
3030 TGIChatModelFormat ,
3131)
3232from dstack ._internal .proxy .lib .repo import BaseProxyRepo
33- from dstack ._internal .server .models import JobModel , ProjectModel , RunModel
33+ from dstack ._internal .server .models import InstanceModel , JobModel , ProjectModel , RunModel
3434from dstack ._internal .server .services .instances import get_instance_remote_connection_info
3535from dstack ._internal .server .settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
3636from dstack ._internal .utils .common import get_or_error
@@ -59,8 +59,9 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
5959 JobModel .job_num == 0 ,
6060 )
6161 .options (
62- joinedload (JobModel .run ),
63- joinedload (JobModel .instance ),
62+ contains_eager (JobModel .run ),
63+ contains_eager (JobModel .project ),
64+ joinedload (JobModel .instance ).joinedload (InstanceModel .project ),
6465 )
6566 )
6667 jobs = res .unique ().scalars ().all ()
@@ -77,10 +78,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
7778 )
7879 assert jpd .hostname is not None
7980 assert jpd .ssh_port is not None
81+ instance = get_or_error (job .instance )
8082 if not jpd .dockerized :
8183 ssh_destination = f"{ jpd .username } @{ jpd .hostname } "
8284 ssh_port = jpd .ssh_port
8385 ssh_proxy = jpd .ssh_proxy
86+ ssh_proxy_private_key = None
8487 else :
8588 ssh_destination = "root@localhost"
8689 ssh_port = DSTACK_RUNNER_SSH_PORT
@@ -93,11 +96,14 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
9396 username = jpd .username ,
9497 port = jpd .ssh_port ,
9598 )
99+ ssh_proxy_private_key = None
100+ if job .project_id != instance .project_id :
101+ ssh_proxy_private_key = instance .project .ssh_private_key
96102 if jpd .backend == BackendType .LOCAL :
97103 ssh_proxy = None
104+ ssh_proxy_private_key = None
98105 ssh_head_proxy : Optional [SSHConnectionParams ] = None
99106 ssh_head_proxy_private_key : Optional [str ] = None
100- instance = get_or_error (job .instance )
101107 rci = get_instance_remote_connection_info (instance )
102108 if rci is not None and rci .ssh_proxy is not None :
103109 ssh_head_proxy = rci .ssh_proxy
@@ -109,6 +115,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
109115 ssh_destination = ssh_destination ,
110116 ssh_port = ssh_port ,
111117 ssh_proxy = ssh_proxy ,
118+ ssh_proxy_private_key = ssh_proxy_private_key ,
112119 ssh_head_proxy = ssh_head_proxy ,
113120 ssh_head_proxy_private_key = ssh_head_proxy_private_key ,
114121 internal_ip = jpd .internal_ip ,
0 commit comments