Skip to content

Commit 8cd53b1

Browse files
authored
Use sshproxy for CLI attach if enabled (#3711)
`Run.attach()` (used internally by `dstack apply` and `dstack attach`) now prefers sshproxy over direct backend-specific SSH access if the server reports that sshproxy is enabled via new `JobConnectionInfo.sshproxy_*` fields. In addition, `DSTACK_SERVER_SSHPROXY_ENFORCED` disables direct SSH access to the instance/container by not including user's public key to the instance/container `authorized_keys`. Note, this setting renders dstack server incompatible with older clients. Part-of: #3644
1 parent 725dfe2 commit 8cd53b1

File tree

12 files changed

+454
-121
lines changed

12 files changed

+454
-121
lines changed

src/dstack/_internal/core/models/runs.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,29 @@ class JobConnectionInfo(CoreModel):
472472
)
473473
),
474474
]
475+
sshproxy_hostname: Annotated[
476+
Optional[str],
477+
Field(description="sshproxy hostname. Not set if sshproxy is not configured."),
478+
] = None
479+
sshproxy_port: Annotated[
480+
Optional[int],
481+
Field(
482+
description=(
483+
"ssproxy port. Not set if sshproxy is not configured."
484+
" May be not set if it is equal to the default SSH port 22."
485+
)
486+
),
487+
] = None
488+
sshproxy_upstream_id: Annotated[
489+
Optional[str],
490+
Field(
491+
description=(
492+
"sshproxy identifier for this job. SSH clients send this identifier as a username"
493+
" to indicate which job they wish to connect."
494+
" Not set if sshproxy is not configured."
495+
)
496+
),
497+
] = None
475498

476499

477500
class Job(CoreModel):

src/dstack/_internal/core/services/ssh/attach.py

Lines changed: 131 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,17 @@
2828
_SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P<local_port>\d+):localhost:(?P<remote_port>\d+)")
2929

3030

31-
class SSHAttach:
31+
class BaseSSHAttach:
32+
"""
33+
A base class for SSH attach implementations.
34+
35+
Child classes must populate `self.hosts` inside overridden `__init__()` with at least one host
36+
named as a `run_name` argument value.
37+
"""
38+
3239
@classmethod
3340
def get_control_sock_path(cls, run_name: str) -> Path:
34-
return ConfigManager().dstack_ssh_dir / f"%r@{run_name}.control.sock"
41+
return ConfigManager().dstack_ssh_dir / f"{run_name}.control.sock"
3542

3643
@classmethod
3744
def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]:
@@ -57,21 +64,16 @@ def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]:
5764

5865
def __init__(
5966
self,
60-
hostname: str,
61-
ssh_port: int,
62-
container_ssh_port: int,
63-
user: str,
64-
container_user: str,
65-
id_rsa_path: PathLike,
66-
ports_lock: PortsLock,
67+
*,
6768
run_name: str,
68-
dockerized: bool,
69-
ssh_proxy: Optional[SSHConnectionParams] = None,
69+
identity_path: PathLike,
70+
ports_lock: PortsLock,
71+
destination: str,
7072
service_port: Optional[int] = None,
71-
local_backend: bool = False,
7273
bind_address: Optional[str] = None,
7374
):
7475
self._attached = False
76+
self._hosts_added_to_ssh_config = False
7577
self._ports_lock = ports_lock
7678
self.ports = ports_lock.dict()
7779
self.run_name = run_name
@@ -80,9 +82,9 @@ def __init__(
8082
# Cast all path-like values used in configs to FilePath instances for automatic
8183
# path normalization in :func:`update_ssh_config`.
8284
self.control_sock_path = FilePath(control_sock_path)
83-
self.identity_file = FilePath(id_rsa_path)
85+
self.identity_file = FilePath(identity_path)
8486
self.tunnel = SSHTunnel(
85-
destination=f"root@{run_name}",
87+
destination=destination,
8688
identity=self.identity_file,
8789
forwarded_sockets=ports_to_forwarded_sockets(
8890
ports=self.ports,
@@ -94,12 +96,92 @@ def __init__(
9496
"ExitOnForwardFailure": "yes",
9597
},
9698
)
97-
self.ssh_proxy = ssh_proxy
9899
self.service_port = service_port
100+
self.hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}
101+
102+
def __enter__(self):
103+
self.attach()
104+
return self
105+
106+
def __exit__(self, exc_type, exc_val, exc_tb):
107+
self.detach()
108+
109+
def attach(self):
110+
include_ssh_config(self.ssh_config_path)
111+
self._add_hosts_to_ssh_config()
99112

100-
hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}
101-
self.hosts = hosts
113+
self._ports_lock.release()
102114

115+
max_retries = 10
116+
for i in range(max_retries):
117+
try:
118+
self.tunnel.open()
119+
self._attached = True
120+
atexit.register(self.detach)
121+
return
122+
except SSHError:
123+
if i < max_retries - 1:
124+
time.sleep(1)
125+
self._remove_hosts_from_ssh_config()
126+
raise SSHError("Can't connect to the remote host")
127+
128+
def detach(self):
129+
self._remove_hosts_from_ssh_config()
130+
if not self._attached:
131+
logger.debug("Not attached")
132+
return
133+
self.tunnel.close()
134+
self._attached = False
135+
logger.debug("Detached")
136+
137+
def _add_hosts_to_ssh_config(self):
138+
if self._hosts_added_to_ssh_config:
139+
return
140+
for host, options in self.hosts.items():
141+
update_ssh_config(self.ssh_config_path, host, options)
142+
self._hosts_added_to_ssh_config = True
143+
144+
def _remove_hosts_from_ssh_config(self):
145+
if not self._hosts_added_to_ssh_config:
146+
return
147+
for host in self.hosts:
148+
update_ssh_config(self.ssh_config_path, host, {})
149+
self._hosts_added_to_ssh_config = False
150+
151+
152+
class SSHAttach(BaseSSHAttach):
153+
"""
154+
`SSHAttach` attaches to a job directly, via a backend-specific chain of hosts.
155+
156+
Used when `dstack-sshproxy` is not configured on the server.
157+
"""
158+
159+
def __init__(
160+
self,
161+
*,
162+
run_name: str,
163+
identity_path: PathLike,
164+
ports_lock: PortsLock,
165+
hostname: str,
166+
ssh_port: int,
167+
container_ssh_port: int,
168+
user: str,
169+
container_user: str,
170+
dockerized: bool,
171+
ssh_proxy: Optional[SSHConnectionParams] = None,
172+
local_backend: bool = False,
173+
service_port: Optional[int] = None,
174+
bind_address: Optional[str] = None,
175+
):
176+
super().__init__(
177+
run_name=run_name,
178+
identity_path=identity_path,
179+
ports_lock=ports_lock,
180+
destination=f"root@{run_name}",
181+
service_port=service_port,
182+
bind_address=bind_address,
183+
)
184+
hosts = self.hosts
103185
if local_backend:
104186
hosts[run_name] = {
105187
"HostName": hostname,
@@ -195,47 +277,39 @@ def __init__(
195277
"StrictHostKeyChecking": "no",
196278
"UserKnownHostsFile": "/dev/null",
197279
}
198-
if get_ssh_client_info().supports_multiplexing:
199-
hosts[run_name].update(
200-
{
201-
"ControlMaster": "auto",
202-
"ControlPath": self.control_sock_path,
203-
}
204-
)
205280

206-
def attach(self):
207-
include_ssh_config(self.ssh_config_path)
208-
for host, options in self.hosts.items():
209-
update_ssh_config(self.ssh_config_path, host, options)
210281

211-
max_retries = 10
212-
self._ports_lock.release()
213-
for i in range(max_retries):
214-
try:
215-
self.tunnel.open()
216-
self._attached = True
217-
atexit.register(self.detach)
218-
break
219-
except SSHError:
220-
if i < max_retries - 1:
221-
time.sleep(1)
222-
else:
223-
self.detach()
224-
raise SSHError("Can't connect to the remote host")
282+
class SSHProxyAttach(BaseSSHAttach):
283+
"""
284+
`SSHProxyAttach` attaches to a job via `dstack-sshproxy`.
225285
226-
def detach(self):
227-
if not self._attached:
228-
logger.debug("Not attached")
229-
return
230-
self.tunnel.close()
231-
for host in self.hosts:
232-
update_ssh_config(self.ssh_config_path, host, {})
233-
self._attached = False
234-
logger.debug("Detached")
286+
Used when `dstack-sshproxy` is configured on the server.
287+
"""
235288

236-
def __enter__(self):
237-
self.attach()
238-
return self
239-
240-
def __exit__(self, exc_type, exc_val, exc_tb):
241-
self.detach()
289+
def __init__(
290+
self,
291+
*,
292+
run_name: str,
293+
identity_path: PathLike,
294+
ports_lock: PortsLock,
295+
hostname: str,
296+
upstream_id: str,
297+
port: Optional[int] = None,
298+
service_port: Optional[int] = None,
299+
bind_address: Optional[str] = None,
300+
):
301+
super().__init__(
302+
run_name=run_name,
303+
identity_path=identity_path,
304+
ports_lock=ports_lock,
305+
destination=f"{upstream_id}_root@{run_name}",
306+
service_port=service_port,
307+
bind_address=bind_address,
308+
)
309+
self.hosts[run_name] = {
310+
"HostName": hostname,
311+
"Port": port or 22,
312+
"User": upstream_id,
313+
"IdentityFile": self.identity_file,
314+
"IdentitiesOnly": "yes",
315+
}

src/dstack/_internal/core/services/ssh/tunnel.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import subprocess
66
import tempfile
77
from dataclasses import dataclass
8-
from typing import Dict, Iterable, List, Literal, Optional, Union
8+
from typing import Dict, Iterable, List, Literal, NoReturn, Optional, Union
99

1010
from dstack._internal.core.errors import SSHError
1111
from dstack._internal.core.models.instances import SSHConnectionParams
@@ -199,9 +199,8 @@ def open(self) -> None:
199199
raise SSHError(msg) from e
200200
if r.returncode == 0:
201201
return
202-
stderr = self._read_log_file()
203-
logger.debug("SSH tunnel failed: %s", stderr)
204-
raise get_ssh_error(stderr)
202+
log_output = self._read_log_file()
203+
self._raise_ssh_error_from_log_output(log_output)
205204

206205
async def aopen(self) -> None:
207206
await run_async(self._remove_log_file)
@@ -217,9 +216,8 @@ async def aopen(self) -> None:
217216
raise SSHError(msg) from e
218217
if proc.returncode == 0:
219218
return
220-
stderr = await run_async(self._read_log_file)
221-
logger.debug("SSH tunnel failed: %s", stderr)
222-
raise get_ssh_error(stderr)
219+
log_output = await run_async(self._read_log_file)
220+
self._raise_ssh_error_from_log_output(log_output)
223221

224222
def close(self) -> None:
225223
if not os.path.exists(self.control_sock_path):
@@ -325,9 +323,13 @@ def _build_proxy_command(
325323
]
326324
return "ProxyCommand=" + shlex.join(command)
327325

328-
def _read_log_file(self) -> bytes:
329-
with open(self.log_path, "rb") as f:
330-
return f.read()
326+
def _read_log_file(self) -> Optional[bytes]:
327+
try:
328+
with open(self.log_path, "rb") as f:
329+
return f.read()
330+
except OSError as e:
331+
logger.debug("Failed to read SSH tunnel log file %s: %s", self.log_path, e)
332+
return None
331333

332334
def _remove_log_file(self) -> None:
333335
try:
@@ -337,6 +339,16 @@ def _remove_log_file(self) -> None:
337339
except OSError as e:
338340
logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e)
339341

342+
def _raise_ssh_error_from_log_output(self, output: Optional[bytes]) -> NoReturn:
343+
if output is None:
344+
msg = "(no log file)"
345+
ssh_error = SSHError()
346+
else:
347+
msg = output
348+
ssh_error = get_ssh_error(output)
349+
logger.debug("SSH tunnel failed: %s", msg)
350+
raise ssh_error
351+
340352
def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike:
341353
if isinstance(identity, FilePath):
342354
return identity.path

src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
RunStatus,
3535
)
3636
from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
37+
from dstack._internal.server import settings as server_settings
3738
from dstack._internal.server.background.pipeline_tasks.base import (
3839
Fetcher,
3940
Heartbeater,
@@ -578,12 +579,16 @@ async def _process_provisioning_status(
578579
fmt(context.job_model),
579580
context.job_submission.age,
580581
)
581-
ssh_user = job_provisioning_data.username
582-
assert context.run.run_spec.ssh_key_pub is not None
583-
user_ssh_key = context.run.run_spec.ssh_key_pub.strip()
584-
public_keys = [context.project.ssh_public_key.strip(), user_ssh_key]
582+
public_keys = [context.project.ssh_public_key.strip()]
583+
ssh_user: Optional[str] = None
584+
user_ssh_key: Optional[str] = None
585+
if not server_settings.SSHPROXY_ENFORCED:
586+
ssh_user = job_provisioning_data.username
587+
assert context.run.run_spec.ssh_key_pub is not None
588+
user_ssh_key = context.run.run_spec.ssh_key_pub.strip()
589+
public_keys.append(user_ssh_key)
585590
if job_provisioning_data.backend == BackendType.LOCAL:
586-
user_ssh_key = ""
591+
user_ssh_key = None
587592
success = await run_async(
588593
_process_provisioning_with_shim,
589594
server_ssh_private_keys,
@@ -1118,8 +1123,8 @@ def _process_provisioning_with_shim(
11181123
volumes: list[Volume],
11191124
registry_auth: Optional[RegistryAuth],
11201125
public_keys: list[str],
1121-
ssh_user: str,
1122-
ssh_key: str,
1126+
ssh_user: Optional[str],
1127+
ssh_key: Optional[str],
11231128
) -> bool:
11241129
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
11251130
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
@@ -1181,7 +1186,7 @@ def _process_provisioning_with_shim(
11811186
volume_mounts=volume_mounts,
11821187
instance_mounts=instance_mounts,
11831188
gpu_devices=gpu_devices,
1184-
host_ssh_user=ssh_user,
1189+
host_ssh_user=ssh_user or "",
11851190
host_ssh_keys=[ssh_key] if ssh_key else [],
11861191
container_ssh_keys=public_keys,
11871192
instance_id=jpd.instance_id,
@@ -1196,8 +1201,8 @@ def _process_provisioning_with_shim(
11961201
container_user=container_user,
11971202
shm_size=job_spec.requirements.resources.shm_size,
11981203
public_keys=public_keys,
1199-
ssh_user=ssh_user,
1200-
ssh_key=ssh_key,
1204+
ssh_user=ssh_user or "",
1205+
ssh_key=ssh_key or "",
12011206
mounts=volume_mounts,
12021207
volumes=volumes,
12031208
instance_mounts=instance_mounts,

0 commit comments

Comments
 (0)