Skip to content

Commit 2f1512e

Browse files
authored
Streamline InstanceModel.remote_connection_info handling (#3566)
1 parent 0155a28 commit 2f1512e

File tree

6 files changed

+38
-35
lines changed

6 files changed

+38
-35
lines changed

src/dstack/_internal/server/background/tasks/process_instances.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import datetime
33
import logging
44
from datetime import timedelta
5-
from typing import Any, Dict, Optional, cast
5+
from typing import Any, Dict, Optional
66

77
import gpuhunt
88
import requests
@@ -86,8 +86,10 @@
8686
get_instance_configuration,
8787
get_instance_profile,
8888
get_instance_provisioning_data,
89+
get_instance_remote_connection_info,
8990
get_instance_requirements,
9091
get_instance_ssh_private_keys,
92+
is_ssh_instance,
9193
remove_dangling_tasks_from_instance,
9294
switch_instance_status,
9395
)
@@ -244,7 +246,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
244246
instance = res.unique().scalar_one()
245247

246248
if instance.status == InstanceStatus.PENDING:
247-
if instance.remote_connection_info is not None:
249+
if is_ssh_instance(instance):
248250
await _add_remote(session, instance)
249251
else:
250252
await _create_instance(
@@ -323,7 +325,8 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None:
323325
return
324326

325327
try:
326-
remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
328+
remote_details = get_instance_remote_connection_info(instance)
329+
assert remote_details is not None
327330
# Prepare connection key
328331
try:
329332
pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from dstack._internal.core.models.files import FileArchiveMapping
1919
from dstack._internal.core.models.instances import (
2020
InstanceStatus,
21-
RemoteConnectionInfo,
2221
SSHConnectionParams,
2322
)
2423
from dstack._internal.core.models.metrics import Metric
@@ -54,7 +53,10 @@
5453
from dstack._internal.server.services import events, services
5554
from dstack._internal.server.services import files as files_services
5655
from dstack._internal.server.services import logs as logs_services
57-
from dstack._internal.server.services.instances import get_instance_ssh_private_keys
56+
from dstack._internal.server.services.instances import (
57+
get_instance_remote_connection_info,
58+
get_instance_ssh_private_keys,
59+
)
5860
from dstack._internal.server.services.jobs import (
5961
find_job,
6062
get_job_attached_volumes,
@@ -870,14 +872,11 @@ async def _maybe_register_replica(
870872
ssh_head_proxy: Optional[SSHConnectionParams] = None
871873
ssh_head_proxy_private_key: Optional[str] = None
872874
instance = common_utils.get_or_error(job_model.instance)
873-
if instance.remote_connection_info is not None:
874-
rci: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
875-
instance.remote_connection_info
876-
)
877-
if rci.ssh_proxy is not None:
878-
ssh_head_proxy = rci.ssh_proxy
879-
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
880-
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
875+
rci = get_instance_remote_connection_info(instance)
876+
if rci is not None and rci.ssh_proxy is not None:
877+
ssh_head_proxy = rci.ssh_proxy
878+
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
879+
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
881880
try:
882881
await services.register_replica(
883882
session,
@@ -1090,9 +1089,8 @@ def _submit_job_to_runner(
10901089
None if repo_credentials is None else repo_credentials.clone_url,
10911090
)
10921091
instance = job_model.instance
1093-
if instance is not None and instance.remote_connection_info is not None:
1094-
remote_info = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
1095-
instance_env = remote_info.env
1092+
if instance is not None and (rci := get_instance_remote_connection_info(instance)) is not None:
1093+
instance_env = rci.env
10961094
else:
10971095
instance_env = None
10981096

src/dstack/_internal/server/services/fleets.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections.abc import Callable
33
from datetime import datetime
44
from functools import wraps
5-
from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast
5+
from typing import List, Literal, Optional, Tuple, TypeVar, Union
66

77
from sqlalchemy import and_, func, or_, select
88
from sqlalchemy.ext.asyncio import AsyncSession
@@ -32,7 +32,6 @@
3232
InstanceOfferWithAvailability,
3333
InstanceStatus,
3434
InstanceTerminationReason,
35-
RemoteConnectionInfo,
3635
SSHConnectionParams,
3736
SSHKey,
3837
)
@@ -1106,9 +1105,8 @@ async def _check_ssh_hosts_not_yet_added(
11061105
# ignore instances belonging to the same fleet -- in-place update/recreate
11071106
if current_fleet_id is not None and instance.fleet_id == current_fleet_id:
11081107
continue
1109-
instance_conn_info = RemoteConnectionInfo.parse_raw(
1110-
cast(str, instance.remote_connection_info)
1111-
)
1108+
instance_conn_info = get_instance_remote_connection_info(instance)
1109+
assert instance_conn_info is not None
11121110
existing_hosts.add(instance_conn_info.host)
11131111

11141112
instances_already_in_fleet = []

src/dstack/_internal/server/services/instances.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements:
286286
return Requirements.__response__.parse_raw(instance_model.requirements)
287287

288288

289+
def is_ssh_instance(instance_model: InstanceModel) -> bool:
290+
return instance_model.remote_connection_info is not None
291+
292+
289293
def get_instance_remote_connection_info(
290294
instance_model: InstanceModel,
291295
) -> Optional[RemoteConnectionInfo]:
@@ -299,11 +303,11 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O
299303
Returns a pair of SSH private keys: host key and optional proxy jump key.
300304
"""
301305
host_private_key = instance_model.project.ssh_private_key
302-
if instance_model.remote_connection_info is None:
306+
rci = get_instance_remote_connection_info(instance_model)
307+
if rci is None:
303308
# Cloud instance
304309
return host_private_key, None
305310
# SSH instance
306-
rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info)
307311
if rci.ssh_proxy is None:
308312
return host_private_key, None
309313
if rci.ssh_proxy_keys is None:

src/dstack/_internal/server/services/proxy/repo.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
1010
from dstack._internal.core.models.backends.base import BackendType
1111
from dstack._internal.core.models.configurations import ServiceConfiguration
12-
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
12+
from dstack._internal.core.models.instances import SSHConnectionParams
1313
from dstack._internal.core.models.runs import (
1414
JobProvisioningData,
1515
JobSpec,
@@ -31,6 +31,7 @@
3131
)
3232
from dstack._internal.proxy.lib.repo import BaseProxyRepo
3333
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
34+
from dstack._internal.server.services.instances import get_instance_remote_connection_info
3435
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
3536
from dstack._internal.utils.common import get_or_error
3637

@@ -97,11 +98,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
9798
ssh_head_proxy: Optional[SSHConnectionParams] = None
9899
ssh_head_proxy_private_key: Optional[str] = None
99100
instance = get_or_error(job.instance)
100-
if instance.remote_connection_info is not None:
101-
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
102-
if rci.ssh_proxy is not None:
103-
ssh_head_proxy = rci.ssh_proxy
104-
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
101+
rci = get_instance_remote_connection_info(instance)
102+
if rci is not None and rci.ssh_proxy is not None:
103+
ssh_head_proxy = rci.ssh_proxy
104+
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
105105
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
106106
replica = Replica(
107107
id=job.id.hex,

src/dstack/_internal/server/services/ssh.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import dstack._internal.server.services.jobs as jobs_services
55
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
66
from dstack._internal.core.models.backends.base import BackendType
7-
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
7+
from dstack._internal.core.models.instances import SSHConnectionParams
88
from dstack._internal.core.models.runs import JobProvisioningData
99
from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel
1010
from dstack._internal.server.models import JobModel
11+
from dstack._internal.server.services.instances import get_instance_remote_connection_info
1112
from dstack._internal.utils.common import get_or_error
1213
from dstack._internal.utils.path import FileContent
1314

@@ -46,11 +47,10 @@ def container_ssh_tunnel(
4647
ssh_head_proxy: Optional[SSHConnectionParams] = None
4748
ssh_head_proxy_private_key: Optional[str] = None
4849
instance = get_or_error(job.instance)
49-
if instance.remote_connection_info is not None:
50-
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
51-
if rci.ssh_proxy is not None:
52-
ssh_head_proxy = rci.ssh_proxy
53-
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
50+
rci = get_instance_remote_connection_info(instance)
51+
if rci is not None and rci.ssh_proxy is not None:
52+
ssh_head_proxy = rci.ssh_proxy
53+
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
5454
ssh_proxies = []
5555
if ssh_head_proxy is not None:
5656
ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)

0 commit comments

Comments
 (0)