Skip to content

Commit 376be2a

Browse files
authored
Clean up low- and high-level API clients (#3355)
* Deprecate and ignore RunCollection.get_run_plan() repo_dir argument * Ensure UsersAPIClient.get_my_user() returns UserWithCreds * Ensure UserSSHKeyManager.get_user_key() returns user SSH key stored on the server * Drop SSH key-related support code for pre-0.20.0 servers
1 parent 26016a5 commit 376be2a

File tree

11 files changed

+105
-74
lines changed

11 files changed

+105
-74
lines changed

docs/docs/reference/api/python/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ finally:
4444
!!! info "NOTE:"
4545
1. The `configuration` argument in the `apply_configuration` method can be either `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`.
4646
2. When you create `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`, you can specify the `image` argument. If `image` isn't specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument.
47-
3. The `repo` argument in the `apply_configuration` method allows the mounting of a local folder, a remote repo, or a
47+
3. The `repo` argument in the `apply_configuration` method allows the mounting of a remote repo or a
4848
programmatically created repo. In this case, the `commands` argument can refer to the files within this repo.
4949
4. The `attach` method waits for the run to start and, for `dstack.api.Task` sets up an SSH tunnel and forwards
5050
configured `ports` to `localhost`.

src/dstack/_internal/core/services/ssh/key_manager.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from dataclasses import dataclass
33
from datetime import datetime, timedelta
44
from pathlib import Path
5-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING
66

7-
from dstack._internal.core.models.users import UserWithCreds
7+
from dstack._internal.core.errors import ClientError
88

99
if TYPE_CHECKING:
1010
from dstack.api.server import APIClient
@@ -24,26 +24,25 @@ def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None:
2424
self._key_path = ssh_keys_dir / api_client.get_token_hash()
2525
self._pub_key_path = self._key_path.with_suffix(".pub")
2626

27-
def get_user_key(self) -> Optional[UserSSHKey]:
27+
def get_user_key(self) -> UserSSHKey:
2828
"""
29-
Return the up-to-date user key, or None if the user has no key (if created before 0.19.33)
29+
Return the up-to-date user key
3030
"""
3131
if (
3232
not self._key_path.exists()
3333
or not self._pub_key_path.exists()
3434
or datetime.now() - datetime.fromtimestamp(self._key_path.stat().st_mtime)
3535
> KEY_REFRESH_RATE
3636
):
37-
if not self._download_user_key():
38-
return None
37+
self._download_user_key()
3938
return UserSSHKey(
4039
public_key=self._pub_key_path.read_text(), private_key_path=self._key_path
4140
)
4241

43-
def _download_user_key(self) -> bool:
42+
def _download_user_key(self) -> None:
4443
user = self._api_client.users.get_my_user()
45-
if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key):
46-
return False
44+
if user.ssh_private_key is None or user.ssh_public_key is None:
45+
raise ClientError("Server response does not contain user SSH key")
4746

4847
def key_opener(path, flags):
4948
return os.open(path, flags, 0o600)
@@ -52,5 +51,3 @@ def key_opener(path, flags):
5251
f.write(user.ssh_private_key)
5352
with open(self._pub_key_path, "w") as f:
5453
f.write(user.ssh_public_key)
55-
56-
return True

src/dstack/_internal/server/routers/runs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def get_plan(
118118
"""
119119
user, project = user_project
120120
if not user.ssh_public_key and not body.run_spec.ssh_key_pub:
121-
await users.refresh_ssh_key(session=session, user=user, username=user.name)
121+
await users.refresh_ssh_key(session=session, user=user)
122122
run_plan = await runs.get_plan(
123123
session=session,
124124
project=project,
@@ -148,7 +148,7 @@ async def apply_plan(
148148
"""
149149
user, project = user_project
150150
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
151-
await users.refresh_ssh_key(session=session, user=user, username=user.name)
151+
await users.refresh_ssh_key(session=session, user=user)
152152
return CustomORJSONResponse(
153153
await runs.apply_plan(
154154
session=session,

src/dstack/_internal/server/routers/users.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ async def get_my_user(
4343
):
4444
if user.ssh_private_key is None or user.ssh_public_key is None:
4545
# Generate keys for pre-0.19.33 users
46-
updated_user = await users.refresh_ssh_key(session=session, user=user, username=user.name)
47-
if updated_user is None:
48-
raise ResourceNotExistsError()
49-
user = updated_user
46+
await users.refresh_ssh_key(session=session, user=user)
5047
return CustomORJSONResponse(users.user_model_to_user_with_creds(user))
5148

5249

src/dstack/_internal/server/services/users.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ async def update_user(
147147
async def refresh_ssh_key(
148148
session: AsyncSession,
149149
user: UserModel,
150-
username: str,
150+
username: Optional[str] = None,
151151
) -> Optional[UserModel]:
152+
if username is None:
153+
username = user.name
152154
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
153155
if user.global_role != GlobalRole.ADMIN and user.name != username:
154156
raise error_forbidden()

src/dstack/_internal/server/testing/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def get_run_spec(
277277
configuration_path: str = "dstack.yaml",
278278
profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"),
279279
configuration: Optional[AnyRunConfiguration] = None,
280+
ssh_key_pub: Optional[str] = "user_ssh_key",
280281
) -> RunSpec:
281282
if callable(profile):
282283
profile = profile()
@@ -288,7 +289,7 @@ def get_run_spec(
288289
configuration_path=configuration_path,
289290
configuration=configuration or DevEnvironmentConfiguration(ide="vscode"),
290291
profile=profile,
291-
ssh_key_pub="user_ssh_key",
292+
ssh_key_pub=ssh_key_pub,
292293
)
293294

294295

src/dstack/api/_public/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import enum
2+
3+
4+
class Deprecated(enum.Enum):
5+
PLACEHOLDER = "DEPRECATED"

src/dstack/api/_public/runs.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from copy import copy
1010
from datetime import datetime
1111
from pathlib import Path
12-
from typing import BinaryIO, Dict, Iterable, List, Optional
12+
from typing import BinaryIO, Dict, Iterable, List, Optional, Union
1313
from urllib.parse import urlencode, urlparse
1414

1515
from websocket import WebSocketApp
@@ -46,10 +46,10 @@
4646
from dstack._internal.core.services.ssh.ports import PortsLock
4747
from dstack._internal.server.schemas.logs import PollLogsRequest
4848
from dstack._internal.utils.common import get_or_error, make_proxy_url
49-
from dstack._internal.utils.crypto import generate_rsa_key_pair
5049
from dstack._internal.utils.files import create_file_archive
5150
from dstack._internal.utils.logging import get_logger
5251
from dstack._internal.utils.path import PathLike
52+
from dstack.api._public.common import Deprecated
5353
from dstack.api.server import APIClient
5454

5555
logger = get_logger(__name__)
@@ -278,13 +278,11 @@ def attach(
278278
if not ssh_identity_file:
279279
config_manager = ConfigManager()
280280
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
281-
if (
282-
user_key := key_manager.get_user_key()
283-
) and user_key.public_key == self._run.run_spec.ssh_key_pub:
281+
user_key = key_manager.get_user_key()
282+
if user_key.public_key == self._run.run_spec.ssh_key_pub:
284283
ssh_identity_file = user_key.private_key_path
285284
else:
286285
if config_manager.dstack_key_path.exists():
287-
# TODO: Remove since 0.19.40
288286
logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].")
289287
ssh_identity_file = config_manager.dstack_key_path
290288
else:
@@ -451,7 +449,7 @@ def get_run_plan(
451449
repo: Optional[Repo] = None,
452450
profile: Optional[Profile] = None,
453451
configuration_path: Optional[str] = None,
454-
repo_dir: Optional[str] = None,
452+
repo_dir: Union[Deprecated, str, None] = Deprecated.PLACEHOLDER,
455453
ssh_identity_file: Optional[PathLike] = None,
456454
) -> RunPlan:
457455
"""
@@ -465,9 +463,10 @@ def get_run_plan(
465463
profile: The profile to use for the run.
466464
configuration_path: The path to the configuration file. Omit if the configuration
467465
is not loaded from a file.
468-
repo_dir: The path of the cloned repo inside the run container. If not set,
469-
defaults first to the `repos[0].path` property of the configuration (for remote
470-
repos only).
466+
ssh_identity_file: Path to the private SSH key file. The corresponding public key
467+
(`.pub` file) is read and included in the run plan, allowing SSH access to the instances.
468+
If the `.pub` file does not exist, it is generated automatically.
469+
If ssh_identity_file is not specified, the user key is used.
471470
472471
Returns:
473472
Run plan.
@@ -479,8 +478,15 @@ def get_run_plan(
479478
with _prepare_code_file(repo) as (_, repo_code_hash):
480479
pass
481480

482-
if repo_dir is None and configuration.repos:
481+
if repo_dir is not Deprecated.PLACEHOLDER:
482+
logger.warning(
483+
"The repo_dir argument is deprecated, ignored, and will be removed soon."
484+
" Remove it and use the repos[].path configuration property instead."
485+
)
486+
if configuration.repos:
483487
repo_dir = configuration.repos[0].path
488+
else:
489+
repo_dir = None
484490

485491
self._validate_configuration_files(configuration, configuration_path)
486492
file_archives: list[FileArchiveMapping] = []
@@ -497,20 +503,7 @@ def get_run_plan(
497503
if ssh_identity_file:
498504
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
499505
else:
500-
config_manager = ConfigManager()
501-
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
502-
if key_manager.get_user_key():
503-
ssh_key_pub = None # using the server-managed user key
504-
else:
505-
if not config_manager.dstack_key_path.exists():
506-
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
507-
logger.warning(
508-
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
509-
" You will only be able to attach to the run from this client."
510-
" Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
511-
" automatically replicated to all clients.",
512-
)
513-
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
506+
ssh_key_pub = None # using the server-managed user key
514507
run_spec = RunSpec(
515508
run_name=configuration.name,
516509
repo_id=repo.repo_id,
@@ -587,6 +580,10 @@ def apply_configuration(
587580
profile: The profile to use for the run.
588581
configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file.
589582
reserve_ports: Reserve local ports before applying. Use if you'll attach to the run.
583+
ssh_identity_file: Path to the private SSH key file. The corresponding public key
584+
(`.pub` file) is read and included in the run plan, allowing SSH access to the instances.
585+
If the `.pub` file does not exist, it is generated automatically.
586+
If ssh_identity_file is not specified, the user key is used.
590587
591588
Returns:
592589
Submitted run.

src/dstack/api/server/_users.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import List
22

3-
from pydantic import ValidationError, parse_obj_as
3+
from pydantic import parse_obj_as
44

55
from dstack._internal.core.models.users import GlobalRole, User, UserWithCreds
66
from dstack._internal.server.schemas.users import (
@@ -17,24 +17,9 @@ def list(self) -> List[User]:
1717
resp = self._request("/api/users/list")
1818
return parse_obj_as(List[User.__response__], resp.json())
1919

20-
def get_my_user(self) -> User:
21-
"""
22-
Returns `User` with pre-0.19.33 servers, or `UserWithCreds` with newer servers.
23-
"""
24-
20+
def get_my_user(self) -> UserWithCreds:
2521
resp = self._request("/api/users/get_my_user")
26-
try:
27-
return parse_obj_as(UserWithCreds.__response__, resp.json())
28-
except ValidationError as e:
29-
# Compatibility with pre-0.19.33 server
30-
if (
31-
len(e.errors()) == 1
32-
and e.errors()[0]["loc"] == ("__root__", "creds")
33-
and e.errors()[0]["type"] == "value_error.missing"
34-
):
35-
return parse_obj_as(User.__response__, resp.json())
36-
else:
37-
raise
22+
return parse_obj_as(UserWithCreds.__response__, resp.json())
3823

3924
def get_user(self, username: str) -> User:
4025
body = GetUserRequest(username=username)

src/tests/_internal/core/services/ssh/test_key_manager.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,6 @@ def set_mtime(path: Path, ts: float):
4444
os.utime(path, (ts, ts))
4545

4646

47-
def test_get_user_key_returns_none_when_no_user_creds(tmp_path: Path):
48-
api_client = make_api_client(
49-
user=User.__response__.parse_obj(SAMPLE_USER.dict()), token_hash=SAMPLE_USER_TOKEN_HASH
50-
)
51-
manager = UserSSHKeyManager(api_client, tmp_path)
52-
53-
assert manager.get_user_key() is None
54-
assert not (tmp_path / SAMPLE_USER_TOKEN_HASH).exists()
55-
assert not (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").exists()
56-
57-
5847
def test_get_user_key_downloads_keys(tmp_path: Path):
5948
api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH)
6049
manager = UserSSHKeyManager(api_client, tmp_path)

0 commit comments

Comments
 (0)