Skip to content

Commit 44c46ea

Browse files
authored
Use server-managed user SSH keys for new runs (#3216)
This commit updates the CLI and the server to use server-managed user SSH keys when starting new runs. This allows users to attach to the run from different machines, since the SSH key is automatically replicated to all clients. Implementation details: - Server: - If the user key is missing, generate it when the user first calls `/get_my_user`. - Client: - Before applying or getting a run plan, call `/get_my_user` to check if the user key is available. If it is, use it. - Cache the downloaded keys in `~/.dstack/ssh` to avoid repeated `/get_my_user` calls. - Switch from `warn` to logger messages, since this code is part of the Python API, so its output should be configurable.
1 parent 95e5536 commit 44c46ea

File tree

10 files changed

+242
-46
lines changed

10 files changed

+242
-46
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
from dataclasses import dataclass
3+
from datetime import datetime, timedelta
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING, Optional
6+
7+
from dstack._internal.core.models.users import UserWithCreds
8+
9+
if TYPE_CHECKING:
10+
from dstack.api.server import APIClient
11+
12+
KEY_REFRESH_RATE = timedelta(minutes=10)  # redownload the key periodically in case it was rotated


@dataclass
class UserSSHKey:
    """A user SSH key as stored in the local key cache."""

    # Contents of the cached public key file
    public_key: str
    # Path to the cached private key file
    private_key_path: Path


class UserSSHKeyManager:
    """
    Provides the current user's SSH key, fetching it through the API client
    and caching it as a pair of files in `ssh_keys_dir`.
    """

    def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None:
        """
        Args:
            api_client: client used to fetch the current user and to name the cache files.
            ssh_keys_dir: directory where the key files are cached.
        """
        self._api_client = api_client
        # The cache file name is the client's token hash
        self._key_path = ssh_keys_dir / api_client.get_token_hash()
        self._pub_key_path = self._key_path.with_suffix(".pub")

    def get_user_key(self) -> Optional[UserSSHKey]:
        """
        Return the up-to-date user key, or None if the user has no key (if created before 0.19.33)
        """
        if self._is_cache_stale() and not self._download_user_key():
            return None
        return UserSSHKey(
            public_key=self._pub_key_path.read_text(), private_key_path=self._key_path
        )

    def _is_cache_stale(self) -> bool:
        """Return True if a cached file is missing or the private key is older than KEY_REFRESH_RATE."""
        try:
            # stat() directly instead of exists() + stat(), so that a file removed
            # in between is reported as a stale cache rather than raising
            key_mtime = self._key_path.stat().st_mtime
        except OSError:
            return True
        if not self._pub_key_path.exists():
            return True
        return datetime.now() - datetime.fromtimestamp(key_mtime) > KEY_REFRESH_RATE

    def _download_user_key(self) -> bool:
        """Fetch the user from the server and cache the key. Return False if the user has no key."""
        user = self._api_client.users.get_my_user()
        if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key):
            return False
        self._write_key_files(user.ssh_private_key, user.ssh_public_key)
        return True

    def _write_key_files(self, private_key: str, public_key: str) -> None:
        """Write the key pair to the cache files, creating the cache directory if needed."""
        # The cache directory may not exist yet (e.g. on the first use on this machine)
        self._key_path.parent.mkdir(parents=True, exist_ok=True)

        def key_opener(path, flags):
            # Create the private key file as readable and writable by the owner only
            return os.open(path, flags, 0o600)

        with open(self._key_path, "w", opener=key_opener) as f:
            f.write(private_key)
        # The public key is written last: if this step fails, the missing file
        # makes the cache stale and the next call downloads the key again
        with open(self._pub_key_path, "w") as f:
            f.write(public_key)

src/dstack/_internal/server/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ class UserModel(BaseModel):
190190
# deactivated users cannot access API
191191
active: Mapped[bool] = mapped_column(Boolean, default=True)
192192

193+
# SSH keys can be null for users created before 0.19.33.
194+
# Keys for those users are being gradually generated on /get_my_user calls.
195+
# TODO: make keys required in a future version.
193196
ssh_private_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
194197
ssh_public_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
195198

src/dstack/_internal/server/routers/users.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,15 @@ async def list_users(
3838

3939
@router.post("/get_my_user", response_model=UserWithCreds)
async def get_my_user(
    session: AsyncSession = Depends(get_session),
    user: UserModel = Depends(Authenticated()),
):
    # Users created before 0.19.33 may lack one or both SSH keys;
    # generate a key pair for them before returning the credentials.
    has_key_pair = user.ssh_private_key is not None and user.ssh_public_key is not None
    if not has_key_pair:
        refreshed_user = await users.refresh_ssh_key(
            session=session,
            user=user,
            username=user.name,
        )
        if refreshed_user is None:
            raise ResourceNotExistsError()
        user = refreshed_user
    user_with_creds = users.user_model_to_user_with_creds(user)
    return CustomORJSONResponse(user_with_creds)
4451

4552

src/dstack/_internal/server/services/users.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from dstack._internal.server.models import DecryptedString, UserModel
2121
from dstack._internal.server.services.permissions import get_default_permissions
2222
from dstack._internal.server.utils.routers import error_forbidden
23+
from dstack._internal.utils import crypto
2324
from dstack._internal.utils.common import run_async
24-
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
2525
from dstack._internal.utils.logging import get_logger
2626

2727
logger = get_logger(__name__)
@@ -88,7 +88,7 @@ async def create_user(
8888
raise ResourceExistsError()
8989
if token is None:
9090
token = str(uuid.uuid4())
91-
private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username)
91+
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
9292
user = UserModel(
9393
id=uuid.uuid4(),
9494
name=username,
@@ -135,7 +135,7 @@ async def refresh_ssh_key(
135135
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
136136
if user.global_role != GlobalRole.ADMIN and user.name != username:
137137
raise error_forbidden()
138-
private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username)
138+
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
139139
await session.execute(
140140
update(UserModel)
141141
.where(UserModel.name == username)

src/dstack/_internal/server/testing/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ async def create_user(
126126
global_role: GlobalRole = GlobalRole.ADMIN,
127127
token: Optional[str] = None,
128128
email: Optional[str] = None,
129+
ssh_public_key: Optional[str] = None,
130+
ssh_private_key: Optional[str] = None,
129131
active: bool = True,
130132
) -> UserModel:
131133
if token is None:
@@ -137,6 +139,8 @@ async def create_user(
137139
token=DecryptedString(plaintext=token),
138140
token_hash=get_token_hash(token),
139141
email=email,
142+
ssh_public_key=ssh_public_key,
143+
ssh_private_key=ssh_private_key,
140144
active=active,
141145
)
142146
session.add(user)

src/dstack/api/_public/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dstack._internal.utils.path import PathLike as PathLike
77
from dstack.api._public.backends import BackendCollection
88
from dstack.api._public.repos import RepoCollection
9-
from dstack.api._public.runs import RunCollection, warn
9+
from dstack.api._public.runs import RunCollection
1010
from dstack.api.server import APIClient
1111

1212
logger = get_logger(__name__)
@@ -42,7 +42,7 @@ def __init__(
4242
self._backends = BackendCollection(api_client, project_name)
4343
self._runs = RunCollection(api_client, project_name, self)
4444
if ssh_identity_file is not None:
45-
warn(
45+
logger.warning(
4646
"[code]ssh_identity_file[/code] in [code]Client[/code] is deprecated and ignored; will be removed"
4747
" since 0.19.40"
4848
)

src/dstack/api/_public/runs.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import base64
2-
import hashlib
3-
import os
42
import queue
53
import tempfile
64
import threading
@@ -17,7 +15,6 @@
1715
from websocket import WebSocketApp
1816

1917
import dstack.api as api
20-
from dstack._internal.cli.utils.common import warn
2118
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT
2219
from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError
2320
from dstack._internal.core.models.backends.base import BackendType
@@ -48,10 +45,10 @@
4845
get_service_port,
4946
)
5047
from dstack._internal.core.models.runs import Run as RunModel
51-
from dstack._internal.core.models.users import UserWithCreds
5248
from dstack._internal.core.services.configs import ConfigManager
5349
from dstack._internal.core.services.logs import URLReplacer
5450
from dstack._internal.core.services.ssh.attach import SSHAttach
51+
from dstack._internal.core.services.ssh.key_manager import UserSSHKeyManager
5552
from dstack._internal.core.services.ssh.ports import PortsLock
5653
from dstack._internal.server.schemas.logs import PollLogsRequest
5754
from dstack._internal.utils.common import get_or_error, make_proxy_url
@@ -88,7 +85,7 @@ def __init__(
8885
self._ports_lock: Optional[PortsLock] = ports_lock
8986
self._ssh_attach: Optional[SSHAttach] = None
9087
if ssh_identity_file is not None:
91-
warn(
88+
logger.warning(
9289
"[code]ssh_identity_file[/code] in [code]Run[/code] is deprecated and ignored; will be removed"
9390
" since 0.19.40"
9491
)
@@ -281,31 +278,20 @@ def attach(
281278
dstack.api.PortUsedError: If ports are in use or the run is attached by another process.
282279
"""
283280
if not ssh_identity_file:
284-
user = self._api_client.users.get_my_user()
285-
run_ssh_key_pub = self._run.run_spec.ssh_key_pub
286281
config_manager = ConfigManager()
287-
if isinstance(user, UserWithCreds) and user.ssh_public_key == run_ssh_key_pub:
288-
token_hash = hashlib.sha1(user.creds.token.encode()).hexdigest()[:8]
289-
config_manager.dstack_ssh_dir.mkdir(parents=True, exist_ok=True)
290-
ssh_identity_file = config_manager.dstack_ssh_dir / token_hash
291-
292-
def key_opener(path, flags):
293-
return os.open(path, flags, 0o600)
294-
295-
with open(ssh_identity_file, "wb", opener=key_opener) as f:
296-
assert user.ssh_private_key
297-
f.write(user.ssh_private_key.encode())
282+
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
283+
if (
284+
user_key := key_manager.get_user_key()
285+
) and user_key.public_key == self._run.run_spec.ssh_key_pub:
286+
ssh_identity_file = user_key.private_key_path
298287
else:
299288
if config_manager.dstack_key_path.exists():
300289
# TODO: Remove since 0.19.40
301-
warn(
302-
f"Using legacy [code]{config_manager.dstack_key_path}[/code]."
303-
" Future versions will use the user SSH key from the server.",
304-
)
290+
logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].")
305291
ssh_identity_file = config_manager.dstack_key_path
306292
else:
307293
raise ConfigurationError(
308-
f"User SSH key doen't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist"
294+
f"User SSH key doesn't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist"
309295
)
310296
ssh_identity_file = str(ssh_identity_file)
311297

@@ -504,15 +490,19 @@ def get_run_plan(
504490
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
505491
else:
506492
config_manager = ConfigManager()
507-
if not config_manager.dstack_key_path.exists():
508-
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
509-
warn(
510-
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
511-
" Future versions will use the user SSH key from the server.",
512-
)
513-
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
514-
# TODO: Uncomment after 0.19.40
515-
# ssh_key_pub = None
493+
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
494+
if key_manager.get_user_key():
495+
ssh_key_pub = None # using the server-managed user key
496+
else:
497+
if not config_manager.dstack_key_path.exists():
498+
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
499+
logger.warning(
500+
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
501+
" You will only be able to attach to the run from this client."
502+
" Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
503+
" automatically replicated to all clients.",
504+
)
505+
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
516506
run_spec = RunSpec(
517507
run_name=configuration.name,
518508
repo_id=repo.repo_id,
@@ -760,12 +750,19 @@ def get_plan(
760750
idle_duration=idle_duration, # type: ignore[assignment]
761751
)
762752
config_manager = ConfigManager()
763-
if not config_manager.dstack_key_path.exists():
764-
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
765-
warn(
766-
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
767-
" Future versions will use the user SSH key from the server.",
768-
)
753+
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
754+
if key_manager.get_user_key():
755+
ssh_key_pub = None # using the server-managed user key
756+
else:
757+
if not config_manager.dstack_key_path.exists():
758+
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
759+
logger.warning(
760+
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
761+
" You will only be able to attach to the run from this client."
762+
" Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
763+
" automatically replicated to all clients.",
764+
)
765+
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
769766
run_spec = RunSpec(
770767
run_name=run_name,
771768
repo_id=repo.repo_id,
@@ -775,7 +772,7 @@ def get_plan(
775772
configuration_path=configuration_path,
776773
configuration=configuration,
777774
profile=profile,
778-
ssh_key_pub=config_manager.dstack_key_path.with_suffix(".pub").read_text(),
775+
ssh_key_pub=ssh_key_pub,
779776
)
780777
logger.debug("Getting run plan")
781778
run_plan = self._api_client.runs.get_plan(self._project, run_spec)

src/dstack/api/server/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import os
23
import pprint
34
import time
@@ -121,6 +122,9 @@ def volumes(self) -> VolumesAPIClient:
121122
def files(self) -> FilesAPIClient:
122123
return FilesAPIClient(self._request, self._logger)
123124

125+
def get_token_hash(self) -> str:
    """Return the first 8 hex characters of the SHA-1 digest of the client's token."""
    token_bytes = self._token.encode()
    hex_digest = hashlib.sha1(token_bytes).hexdigest()
    return hex_digest[:8]
127+
124128
def _request(
125129
self,
126130
path: str,

0 commit comments

Comments
 (0)