Skip to content

Commit 8d7a870

Browse files
committed
Add API for SSH proxy
Part-of: #3644
1 parent 684ade3 commit 8d7a870

13 files changed

Lines changed: 766 additions & 50 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ markers = [
126126
]
127127
env = [
128128
"DSTACK_CLI_RICH_FORCE_TERMINAL=0",
129+
"DSTACK_SSHPROXY_API_TOKEN=test-token",
129130
]
130131
filterwarnings = [
131132
# testcontainers modules use deprecated decorators – nothing we can do:
@@ -142,6 +143,7 @@ dev = [
142143
"pytest-httpbin>=2.1.0",
143144
"pytest-socket>=0.7.0",
144145
"pytest-env>=1.1.0",
146+
"pytest-unordered>=0.7.0",
145147
"httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3
146148
"requests-mock>=1.12.1",
147149
"openai>=1.68.2",

src/dstack/_internal/server/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
runs,
4545
secrets,
4646
server,
47+
sshproxy,
4748
templates,
4849
users,
4950
volumes,
@@ -253,6 +254,7 @@ def register_routes(app: FastAPI, ui: bool = True):
253254
app.include_router(files.router)
254255
app.include_router(events.root_router)
255256
app.include_router(templates.router)
257+
app.include_router(sshproxy.router)
256258

257259
@app.exception_handler(ForbiddenError)
258260
async def forbidden_error_handler(request: Request, exc: ForbiddenError):
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
from typing import Annotated
3+
4+
from fastapi import APIRouter, Depends
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
7+
from dstack._internal.core.errors import ResourceNotExistsError
8+
from dstack._internal.server.db import get_session
9+
from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse
10+
from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount
11+
from dstack._internal.server.services.sshproxy import get_upstream_response
12+
from dstack._internal.server.utils.routers import (
13+
CustomORJSONResponse,
14+
get_base_api_additional_responses,
15+
)
16+
17+
if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"):
18+
_auth = ServiceAccount(_token)
19+
else:
20+
_auth = AlwaysForbidden()
21+
22+
23+
router = APIRouter(
24+
prefix="/api/sshproxy",
25+
tags=["sshproxy"],
26+
responses=get_base_api_additional_responses(),
27+
dependencies=[Depends(_auth)],
28+
)
29+
30+
31+
@router.post("/get_upstream", response_model=GetUpstreamResponse)
32+
async def get_upstream(
33+
body: GetUpstreamRequest,
34+
session: Annotated[AsyncSession, Depends(get_session)],
35+
):
36+
response = await get_upstream_response(session=session, upstream_id=body.id)
37+
if response is None:
38+
raise ResourceNotExistsError()
39+
return CustomORJSONResponse(response)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Annotated
2+
3+
from pydantic import Field
4+
5+
from dstack._internal.core.models.common import CoreModel
6+
7+
8+
class GetUpstreamRequest(CoreModel):
9+
# The format of id is intentionally not limited to UUID to allow further extensions
10+
id: str
11+
12+
13+
class UpstreamHost(CoreModel):
14+
host: Annotated[str, Field(description="The hostname or IP address")]
15+
port: Annotated[int, Field(description="The SSH port")]
16+
user: Annotated[str, Field(description="The user to log in")]
17+
private_key: Annotated[str, Field(description="The private key in OpenSSH file format")]
18+
19+
20+
class GetUpstreamResponse(CoreModel):
21+
hosts: Annotated[
22+
list[UpstreamHost],
23+
Field(description="The chain of SSH hosts, the jump host(s) first, the target host last"),
24+
]
25+
authorized_keys: Annotated[
26+
list[str], Field(description="The list of authorized public keys in OpenSSH file format")
27+
]

src/dstack/_internal/server/security/permissions.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from secrets import compare_digest
12
from typing import Annotated, Optional, Tuple
23
from uuid import UUID
34

@@ -219,9 +220,23 @@ async def __call__(
219220
raise error_forbidden()
220221

221222

222-
class OptionalServiceAccount:
223+
class ServiceAccount:
224+
def __init__(self, token: str) -> None:
225+
self._token = token.encode()
226+
227+
async def __call__(
228+
self, token: Annotated[HTTPAuthorizationCredentials, Security(HTTPBearer())]
229+
) -> None:
230+
if not compare_digest(token.credentials.encode(), self._token):
231+
raise error_invalid_token()
232+
233+
234+
class OptionalServiceAccount(ServiceAccount):
235+
_token: Optional[bytes] = None
236+
223237
def __init__(self, token: Optional[str]) -> None:
224-
self._token = token
238+
if token is not None:
239+
super().__init__(token)
225240

226241
async def __call__(
227242
self,
@@ -233,8 +248,12 @@ async def __call__(
233248
return
234249
if token is None:
235250
raise error_forbidden()
236-
if token.credentials != self._token:
237-
raise error_invalid_token()
251+
await super().__call__(token)
252+
253+
254+
class AlwaysForbidden:
255+
async def __call__(self) -> None:
256+
raise error_forbidden()
238257

239258

240259
async def get_project_member(

src/dstack/_internal/server/services/jobs/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,10 @@ def get_job_runtime_data(job_model: JobModel) -> Optional[JobRuntimeData]:
267267
return JobRuntimeData.__response__.parse_raw(job_model.job_runtime_data)
268268

269269

270+
def get_job_spec(job_model: JobModel) -> JobSpec:
271+
return JobSpec.__response__.parse_raw(job_model.job_spec_data)
272+
273+
270274
def delay_job_instance_termination(job_model: JobModel):
271275
job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15)
272276

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def switch_run_status(
112112
events.emit(session, msg, actor=actor, targets=[events.Target.from_model(run_model)])
113113

114114

115+
def get_run_spec(run_model: RunModel) -> RunSpec:
116+
return RunSpec.__response__.parse_raw(run_model.run_spec)
117+
118+
115119
async def list_user_runs(
116120
session: AsyncSession,
117121
user: UserModel,
@@ -743,7 +747,7 @@ def run_model_to_run(
743747
include_sensitive=include_sensitive,
744748
)
745749

746-
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
750+
run_spec = get_run_spec(run_model)
747751

748752
latest_job_submission = None
749753
if len(jobs) > 0 and len(jobs[0].job_submissions) > 0:
Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,93 @@
11
from collections.abc import Iterable
2-
from typing import Optional
32

4-
import dstack._internal.server.services.jobs as jobs_services
53
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
64
from dstack._internal.core.models.backends.base import BackendType
75
from dstack._internal.core.models.instances import SSHConnectionParams
8-
from dstack._internal.core.models.runs import JobProvisioningData
96
from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel
107
from dstack._internal.server.models import JobModel
118
from dstack._internal.server.services.instances import get_instance_remote_connection_info
9+
from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
1210
from dstack._internal.utils.common import get_or_error
1311
from dstack._internal.utils.path import FileContent
1412

1513

16-
def container_ssh_tunnel(
17-
job: JobModel,
18-
forwarded_sockets: Iterable[SocketPair] = (),
19-
options: dict[str, str] = SSH_DEFAULT_OPTIONS,
20-
) -> SSHTunnel:
14+
def get_container_ssh_credentials(job: JobModel) -> list[tuple[SSHConnectionParams, FileContent]]:
2115
"""
22-
Build SSHTunnel for connecting to the container running the specified job.
16+
Returns the information needed to connect to the SSH server inside the job container.
17+
18+
The user of the target host (container) is set to:
19+
* VM-based backends and SSH instances: "root"
20+
* container-based backends: `JobProvisioningData.username`, which is, as of 2026-03-10,
21+
is always "root" on all supported backends (Runpod, Vast.ai, Kubernetes)
22+
23+
Args:
24+
job: `JobModel` with `instance` and `instance.project` fields loaded.
25+
26+
Returns:
27+
A list of hosts credentials as (host's `SSHConnectionParams`, private key's `FileContent`)
28+
pairs ordered from the first proxy jump (if any) to the target host (container).
2329
"""
24-
jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw(
25-
job.job_provisioning_data
26-
)
30+
hosts: list[tuple[SSHConnectionParams, FileContent]] = []
31+
32+
instance = get_or_error(job.instance)
33+
project_key = FileContent(instance.project.ssh_private_key)
34+
35+
rci = get_instance_remote_connection_info(instance)
36+
if rci is not None and (head_proxy := rci.ssh_proxy) is not None:
37+
head_key = FileContent(get_or_error(get_or_error(rci.ssh_proxy_keys)[0].private))
38+
hosts.append((head_proxy, head_key))
39+
40+
jpd = get_job_provisioning_data(job)
41+
assert jpd is not None
2742
assert jpd.hostname is not None
2843
assert jpd.ssh_port is not None
29-
if not jpd.dockerized:
30-
ssh_destination = f"{jpd.username}@{jpd.hostname}"
31-
ssh_port = jpd.ssh_port
32-
ssh_proxy = jpd.ssh_proxy
33-
else:
34-
ssh_destination = "root@localhost"
44+
45+
if jpd.dockerized:
46+
if jpd.backend != BackendType.LOCAL:
47+
instance_proxy = SSHConnectionParams(
48+
hostname=jpd.hostname,
49+
username=jpd.username,
50+
port=jpd.ssh_port,
51+
)
52+
hosts.append((instance_proxy, project_key))
3553
ssh_port = DSTACK_RUNNER_SSH_PORT
36-
job_submission = jobs_services.job_model_to_job_submission(job)
37-
jrd = job_submission.job_runtime_data
54+
jrd = get_job_runtime_data(job)
3855
if jrd is not None and jrd.ports is not None:
3956
ssh_port = jrd.ports.get(ssh_port, ssh_port)
40-
ssh_proxy = SSHConnectionParams(
57+
target_host = SSHConnectionParams(
58+
hostname="localhost",
59+
username="root",
60+
port=ssh_port,
61+
)
62+
hosts.append((target_host, project_key))
63+
else:
64+
if jpd.ssh_proxy is not None:
65+
hosts.append((jpd.ssh_proxy, project_key))
66+
target_host = SSHConnectionParams(
4167
hostname=jpd.hostname,
4268
username=jpd.username,
4369
port=jpd.ssh_port,
4470
)
45-
if jpd.backend == BackendType.LOCAL:
46-
ssh_proxy = None
47-
ssh_head_proxy: Optional[SSHConnectionParams] = None
48-
ssh_head_proxy_private_key: Optional[str] = None
49-
instance = get_or_error(job.instance)
50-
rci = get_instance_remote_connection_info(instance)
51-
if rci is not None and rci.ssh_proxy is not None:
52-
ssh_head_proxy = rci.ssh_proxy
53-
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
54-
ssh_proxies = []
55-
if ssh_head_proxy is not None:
56-
ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)
57-
ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
58-
if ssh_proxy is not None:
59-
ssh_proxies.append((ssh_proxy, None))
71+
hosts.append((target_host, project_key))
72+
73+
return hosts
74+
75+
76+
def container_ssh_tunnel(
77+
job: JobModel,
78+
forwarded_sockets: Iterable[SocketPair] = (),
79+
options: dict[str, str] = SSH_DEFAULT_OPTIONS,
80+
) -> SSHTunnel:
81+
"""
82+
Build SSHTunnel for connecting to the container running the specified job.
83+
"""
84+
hosts = get_container_ssh_credentials(job)
85+
target, identity = hosts[-1]
6086
return SSHTunnel(
61-
destination=ssh_destination,
62-
port=ssh_port,
63-
ssh_proxies=ssh_proxies,
64-
identity=FileContent(instance.project.ssh_private_key),
87+
destination=f"{target.username}@{target.hostname}",
88+
port=target.port,
89+
ssh_proxies=hosts[:-1],
90+
identity=identity,
6591
forwarded_sockets=forwarded_sockets,
6692
options=options,
6793
)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from typing import Optional
2+
from uuid import UUID
3+
4+
from sqlalchemy import select
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
from sqlalchemy.orm import joinedload
7+
8+
from dstack._internal.core.models.runs import JobStatus
9+
from dstack._internal.server.models import (
10+
InstanceModel,
11+
JobModel,
12+
ProjectModel,
13+
RunModel,
14+
UserModel,
15+
)
16+
from dstack._internal.server.schemas.sshproxy import GetUpstreamResponse, UpstreamHost
17+
from dstack._internal.server.services.jobs import get_job_runtime_data, get_job_spec
18+
from dstack._internal.server.services.runs import get_run_spec
19+
from dstack._internal.server.services.ssh import get_container_ssh_credentials
20+
21+
22+
async def get_upstream_response(
23+
session: AsyncSession,
24+
upstream_id: str,
25+
) -> Optional[GetUpstreamResponse]:
26+
# The format of upstream_id is intentionally not limited to UUID in the API schema to allow
27+
# further extensions. Currently, it's just a JobModel.id
28+
try:
29+
job_id = UUID(upstream_id)
30+
except ValueError:
31+
return None
32+
33+
res = await session.execute(
34+
select(JobModel)
35+
.where(
36+
JobModel.id == job_id,
37+
JobModel.status == JobStatus.RUNNING,
38+
)
39+
.options(
40+
(
41+
joinedload(JobModel.instance, innerjoin=True)
42+
.load_only(InstanceModel.remote_connection_info)
43+
.joinedload(InstanceModel.project, innerjoin=True)
44+
.load_only(ProjectModel.ssh_private_key)
45+
),
46+
(
47+
joinedload(JobModel.run, innerjoin=True)
48+
.load_only(RunModel.run_spec)
49+
.joinedload(RunModel.user, innerjoin=True)
50+
.load_only(UserModel.ssh_public_key)
51+
),
52+
)
53+
)
54+
job = res.scalar_one_or_none()
55+
if job is None:
56+
return None
57+
58+
hosts: list[UpstreamHost] = []
59+
for ssh_params, private_key in get_container_ssh_credentials(job):
60+
hosts.append(
61+
UpstreamHost(
62+
host=ssh_params.hostname,
63+
port=ssh_params.port,
64+
user=ssh_params.username,
65+
private_key=private_key.content,
66+
)
67+
)
68+
69+
username: Optional[str] = None
70+
if (jrd := get_job_runtime_data(job)) is not None:
71+
username = jrd.username
72+
if username is None and (job_spec_user := get_job_spec(job).user) is not None:
73+
username = job_spec_user.username
74+
if username is not None:
75+
hosts[-1].user = username
76+
77+
authorized_keys: set[str] = set()
78+
if (run_spec_key := get_run_spec(job.run).ssh_key_pub) is not None:
79+
authorized_keys.add(run_spec_key)
80+
if (user_key := job.run.user.ssh_public_key) is not None:
81+
authorized_keys.add(user_key)
82+
83+
return GetUpstreamResponse(
84+
hosts=hosts,
85+
authorized_keys=list(authorized_keys),
86+
)

0 commit comments

Comments
 (0)