Skip to content

Commit 2a3c77f

Browse files
authored
Fix services on imported fleets (#3654)
For all server-to-container and gateway-to-container SSH connections:
- Use the instance project's key to connect to the instance.
- Use the job project's key to connect to the container.
1 parent 130b22b commit 2a3c77f

File tree

14 files changed

+226
-21
lines changed

14 files changed

+226
-21
lines changed

src/dstack/_internal/proxy/gateway/routers/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ async def register_replica(
7878
ssh_destination=body.ssh_host,
7979
ssh_port=body.ssh_port,
8080
ssh_proxy=body.ssh_proxy,
81+
ssh_proxy_private_key=body.ssh_proxy_private_key,
8182
ssh_head_proxy=body.ssh_head_proxy,
8283
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
8384
internal_ip=body.internal_ip,

src/dstack/_internal/proxy/gateway/schemas/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class RegisterReplicaRequest(BaseModel):
5454
ssh_host: str
5555
ssh_port: int
5656
ssh_proxy: Optional[SSHConnectionParams]
57+
ssh_proxy_private_key: Optional[str]
5758
ssh_head_proxy: Optional[SSHConnectionParams]
5859
ssh_head_proxy_private_key: Optional[str]
5960
internal_ip: Optional[str] = None

src/dstack/_internal/proxy/gateway/services/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ async def register_replica(
134134
ssh_destination: str,
135135
ssh_port: int,
136136
ssh_proxy: Optional[SSHConnectionParams],
137+
ssh_proxy_private_key: Optional[str],
137138
ssh_head_proxy: Optional[SSHConnectionParams],
138139
ssh_head_proxy_private_key: Optional[str],
139140
repo: GatewayProxyRepo,
@@ -147,6 +148,7 @@ async def register_replica(
147148
ssh_destination=ssh_destination,
148149
ssh_port=ssh_port,
149150
ssh_proxy=ssh_proxy,
151+
ssh_proxy_private_key=ssh_proxy_private_key,
150152
ssh_head_proxy=ssh_head_proxy,
151153
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
152154
internal_ip=internal_ip,

src/dstack/_internal/proxy/lib/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class Replica(ImmutableModel):
2424
ssh_destination: str
2525
ssh_port: int
2626
ssh_proxy: Optional[SSHConnectionParams]
27+
ssh_proxy_private_key: Optional[str] = None
28+
"`None` means same as service project's key"
2729
# Optional outer proxy, a head node/bastion
2830
ssh_head_proxy: Optional[SSHConnectionParams] = None
2931
ssh_head_proxy_private_key: Optional[str] = None

src/dstack/_internal/proxy/lib/services/service_connection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None
5353
ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
5454
ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
5555
if replica.ssh_proxy is not None:
56-
ssh_proxies.append((replica.ssh_proxy, None))
56+
if replica.ssh_proxy_private_key is not None:
57+
ssh_proxies.append((replica.ssh_proxy, FileContent(replica.ssh_proxy_private_key)))
58+
else:
59+
ssh_proxies.append((replica.ssh_proxy, None))
5760
self._tunnel = SSHTunnel(
5861
destination=replica.ssh_destination,
5962
port=replica.ssh_port,

src/dstack/_internal/server/background/scheduled_tasks/probes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ async def process_probes():
6363
.joinedload(JobModel.instance)
6464
.joinedload(InstanceModel.project)
6565
)
66-
.options(joinedload(ProbeModel.job))
66+
.options(joinedload(ProbeModel.job).joinedload(JobModel.project))
6767
.execution_options(populate_existing=True)
6868
)
6969
probes = res.unique().scalars().all()

src/dstack/_internal/server/services/gateways/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def register_replica(
8585
run: Run,
8686
job_spec: JobSpec,
8787
job_submission: JobSubmission,
88+
instance_project_ssh_private_key: Optional[str],
8889
ssh_head_proxy: Optional[SSHConnectionParams],
8990
ssh_head_proxy_private_key: Optional[str],
9091
):
@@ -122,6 +123,7 @@ async def register_replica(
122123
username=jpd.username,
123124
port=jpd.ssh_port,
124125
).dict(),
126+
"ssh_proxy_private_key": instance_project_ssh_private_key,
125127
}
126128
)
127129
resp = await self._client.post(

src/dstack/_internal/server/services/proxy/repo.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pydantic
44
from sqlalchemy import select
55
from sqlalchemy.ext.asyncio import AsyncSession
6-
from sqlalchemy.orm import joinedload
6+
from sqlalchemy.orm import contains_eager, joinedload
77

88
import dstack._internal.server.services.jobs as jobs_services
99
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
@@ -30,7 +30,7 @@
3030
TGIChatModelFormat,
3131
)
3232
from dstack._internal.proxy.lib.repo import BaseProxyRepo
33-
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
33+
from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel
3434
from dstack._internal.server.services.instances import get_instance_remote_connection_info
3535
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
3636
from dstack._internal.utils.common import get_or_error
@@ -59,8 +59,9 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
5959
JobModel.job_num == 0,
6060
)
6161
.options(
62-
joinedload(JobModel.run),
63-
joinedload(JobModel.instance),
62+
contains_eager(JobModel.run),
63+
contains_eager(JobModel.project),
64+
joinedload(JobModel.instance).joinedload(InstanceModel.project),
6465
)
6566
)
6667
jobs = res.unique().scalars().all()
@@ -77,10 +78,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
7778
)
7879
assert jpd.hostname is not None
7980
assert jpd.ssh_port is not None
81+
instance = get_or_error(job.instance)
8082
if not jpd.dockerized:
8183
ssh_destination = f"{jpd.username}@{jpd.hostname}"
8284
ssh_port = jpd.ssh_port
8385
ssh_proxy = jpd.ssh_proxy
86+
ssh_proxy_private_key = None
8487
else:
8588
ssh_destination = "root@localhost"
8689
ssh_port = DSTACK_RUNNER_SSH_PORT
@@ -93,11 +96,14 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
9396
username=jpd.username,
9497
port=jpd.ssh_port,
9598
)
99+
ssh_proxy_private_key = None
100+
if job.project_id != instance.project_id:
101+
ssh_proxy_private_key = instance.project.ssh_private_key
96102
if jpd.backend == BackendType.LOCAL:
97103
ssh_proxy = None
104+
ssh_proxy_private_key = None
98105
ssh_head_proxy: Optional[SSHConnectionParams] = None
99106
ssh_head_proxy_private_key: Optional[str] = None
100-
instance = get_or_error(job.instance)
101107
rci = get_instance_remote_connection_info(instance)
102108
if rci is not None and rci.ssh_proxy is not None:
103109
ssh_head_proxy = rci.ssh_proxy
@@ -109,6 +115,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
109115
ssh_destination=ssh_destination,
110116
ssh_port=ssh_port,
111117
ssh_proxy=ssh_proxy,
118+
ssh_proxy_private_key=ssh_proxy_private_key,
112119
ssh_head_proxy=ssh_head_proxy,
113120
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
114121
internal_ip=jpd.internal_ip,

src/dstack/_internal/server/services/services/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,18 @@ async def register_replica(
313313
if gateway_id is not None:
314314
gateway, conn = await get_or_add_gateway_connection(session, gateway_id)
315315
job_submission = jobs_services.job_model_to_job_submission(job_model)
316+
assert job_model.instance is not None
317+
instance_project_ssh_private_key = None
318+
if job_model.project_id != job_model.instance.project_id:
319+
instance_project_ssh_private_key = job_model.instance.project.ssh_private_key
316320
try:
317321
logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex)
318322
async with conn.client() as client:
319323
await client.register_replica(
320324
run=run,
321325
job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data),
322326
job_submission=job_submission,
327+
instance_project_ssh_private_key=instance_project_ssh_private_key,
323328
ssh_head_proxy=ssh_head_proxy,
324329
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
325330
)

src/dstack/_internal/server/services/ssh.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ def container_ssh_tunnel(
2626
)
2727
assert jpd.hostname is not None
2828
assert jpd.ssh_port is not None
29+
instance = get_or_error(job.instance)
2930
if not jpd.dockerized:
3031
ssh_destination = f"{jpd.username}@{jpd.hostname}"
3132
ssh_port = jpd.ssh_port
3233
ssh_proxy = jpd.ssh_proxy
34+
ssh_proxy_private_key = None
3335
else:
3436
ssh_destination = "root@localhost"
3537
ssh_port = DSTACK_RUNNER_SSH_PORT
@@ -42,11 +44,14 @@ def container_ssh_tunnel(
4244
username=jpd.username,
4345
port=jpd.ssh_port,
4446
)
47+
ssh_proxy_private_key = None
48+
if job.project_id != instance.project_id:
49+
ssh_proxy_private_key = FileContent(instance.project.ssh_private_key)
4550
if jpd.backend == BackendType.LOCAL:
4651
ssh_proxy = None
52+
ssh_proxy_private_key = None
4753
ssh_head_proxy: Optional[SSHConnectionParams] = None
4854
ssh_head_proxy_private_key: Optional[str] = None
49-
instance = get_or_error(job.instance)
5055
rci = get_instance_remote_connection_info(instance)
5156
if rci is not None and rci.ssh_proxy is not None:
5257
ssh_head_proxy = rci.ssh_proxy
@@ -56,12 +61,12 @@ def container_ssh_tunnel(
5661
ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)
5762
ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
5863
if ssh_proxy is not None:
59-
ssh_proxies.append((ssh_proxy, None))
64+
ssh_proxies.append((ssh_proxy, ssh_proxy_private_key))
6065
return SSHTunnel(
6166
destination=ssh_destination,
6267
port=ssh_port,
6368
ssh_proxies=ssh_proxies,
64-
identity=FileContent(instance.project.ssh_private_key),
69+
identity=FileContent(job.project.ssh_private_key),
6570
forwarded_sockets=forwarded_sockets,
6671
options=options,
6772
)

0 commit comments

Comments
 (0)