diff --git a/mkdocs/docs/reference/env.md b/mkdocs/docs/reference/env.md index 1b81629109..086ea80ad8 100644 --- a/mkdocs/docs/reference/env.md +++ b/mkdocs/docs/reference/env.md @@ -141,6 +141,7 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY } – A default Docker registry to use for job images that do not specify an explicit registry. E.g., if set to `registry.example`, then `image: ubuntu` becomes equivalent to `image: registry.example/ubuntu`. **Note**: This setting should only be used for configuring registries that act as a pull-through cache for Docker Hub. The default `dstack` images are also pulled from the configured registry. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME } – Username for authenticating with the default Docker registry. See `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD } – Password for authenticating with the default Docker registry. Applied only when the image has no explicit registry and the run configuration does not specify `registry_auth`. **Note**: The value may be visible to anyone who can SSH into instances managed by `dstack`, which usually includes all users of that `dstack` server. +- `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`{ #DSTACK_SERVER_SSH_CONNECT_TIMEOUT } – The SSH `ConnectTimeout` for server-instance connections, in seconds. Defaults to `3`. Increase if there are high-latency links between the server and instances. ??? info "Internal environment variables" The following environment variables are intended for development purposes: diff --git a/pyproject.toml b/pyproject.toml index 4f09349ced..2a4e8620d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,7 @@ server = [ "aiorwlock", "aiocache", "httpx>=0.28.0", + "requests-unixsocket>=0.4.1", "jinja2", "watchfiles", "sqlalchemy[asyncio]>=2.0.0", diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index 9fede91111..f4d6a17f70 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -252,6 +252,12 @@ async def aclose(self) -> None: proc.stdout, ) + def check(self) -> bool: + proc = subprocess.run( + self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return proc.returncode == 0 + async def acheck(self) -> bool: proc = await asyncio.create_subprocess_exec( *self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 7de6e74059..8b9e044776 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -57,6 +57,7 @@ from dstack._internal.server.services.projects import get_or_create_default_project from dstack._internal.server.services.proxy.deps import ServerProxyDependencyInjector from dstack._internal.server.services.proxy.routers import service_proxy +from dstack._internal.server.services.runner.pool import instance_connection_pool from dstack._internal.server.services.storage import init_default_storage from dstack._internal.server.services.users import get_or_create_admin_user from dstack._internal.server.settings import ( @@ -75,6 +76,7 @@ get_client_version, get_server_client_error_details, ) +from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import check_required_ssh_version @@ -167,6 +169,8 @@ async def lifespan(app: FastAPI): ) if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() + if settings.SERVER_SSH_POOL_ENABLED: + await run_async(instance_connection_pool.startup_cleanup) scheduler = None pipeline_manager = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: @@ -209,6 +213,8 @@ async def lifespan(app: FastAPI): await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() + if settings.SERVER_SSH_POOL_ENABLED: + await run_async(instance_connection_pool.close_all) await get_db().engine.dispose() # Let checked-out DB connections close as dispose() only closes checked-in connections await asyncio.sleep(3) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index d23d536cd1..486c83dbf6 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -1,5 +1,6 @@ import logging import uuid +from collections.abc import Mapping from datetime import timedelta from typing import Optional @@ -373,15 +374,15 @@ async def _get_backend_for_provisioning_wait( ) -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +@runner_ssh_tunnel def _check_instance_inner( - ports: dict[int, int], + addresses: Mapping[int, runner_client.LocalAddress], *, instance: InstanceModel, check_instance_health: bool = False, ) -> InstanceCheck: instance_health_response: Optional[InstanceHealthResponse] = None - shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + shim_client = runner_client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) method = shim_client.healthcheck try: healthcheck_response = method(unmask_exceptions=True) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py index eb1f3c8a39..a4bf6d3294 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py @@ -11,6 +11,10 @@ from dstack._internal.server.models import InstanceModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.instances import get_instance_provisioning_data +from dstack._internal.server.services.runner.pool import ( + InstanceConnectionKey, + instance_connection_pool, +) from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -77,6 +81,9 @@ async def terminate_instance(instance_model: InstanceModel) -> ProcessResult: exc_info=not isinstance(exc, BackendError), ) + if job_provisioning_data is not None: + instance_connection_pool.drop(InstanceConnectionKey.from_jpd(job_provisioning_data)) + result.instance_update_map["deleted"] = True result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER result.instance_update_map["finished_at"] = NOW_PLACEHOLDER diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 068add9a63..98e5967cb8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1,6 +1,7 @@ import asyncio import enum import uuid +from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Dict, Iterable, Literal, Optional, Sequence, Union @@ -1308,9 +1309,9 @@ def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> boo return False -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +@runner_ssh_tunnel def _process_provisioning_with_shim( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], run: Run, job_model: JobModel, jrd: Optional[JobRuntimeData], @@ -1322,7 +1323,7 @@ def _process_provisioning_with_shim( ssh_key: Optional[str], ) -> bool: job_spec = get_job_spec(job_model) - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) resp = shim_client.healthcheck() if resp is None: @@ -1435,21 +1436,21 @@ class _SyncShimPullingStateResult: image_pull_progress: Optional[ImagePullProgress] = None -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) -def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) +@runner_ssh_tunnel +def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _RunnerAvailability: + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: return _RunnerAvailability.UNAVAILABLE return _RunnerAvailability.AVAILABLE -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +@runner_ssh_tunnel def _sync_shim_pulling_state( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], job_model: JobModel, jrd: Optional[JobRuntimeData] = None, ) -> Union[_SyncShimPullingStateResult, Literal[False]]: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) image_pull_progress: Optional[ImagePullProgress] = None if shim_client.is_api_v2_supported(): task = shim_client.get_task(job_model.id) @@ -1525,9 +1526,9 @@ class _SubmitJobToRunnerResult: job_runtime_data: Optional[JobRuntimeData] = None -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +@runner_ssh_tunnel def _submit_job_to_runner( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], run: Run, job_model: JobModel, job: Job, @@ -1552,7 +1553,7 @@ def _submit_job_to_runner( else: instance_env = None - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: return _SubmitJobToRunnerResult(success=success_if_not_available) @@ -1595,13 +1596,13 @@ class _ProcessRunningResult: job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) +@runner_ssh_tunnel def _process_running( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], run_model: RunModel, job_model: JobModel, ) -> Union[_ProcessRunningResult, Literal[False]]: - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) timestamp = job_model.runner_timestamp or 0 resp = runner_client.pull(timestamp) logs_services.write_logs( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index e15c24db57..fe2e64ca4e 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -1,5 +1,6 @@ import asyncio import uuid +from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Optional, Sequence, TypedDict @@ -852,9 +853,9 @@ async def _stop_container( return True -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) +@runner_ssh_tunnel +def _shim_submit_stop(addresses: Mapping[int, client.LocalAddress], job_model: JobModel) -> bool: + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) resp = shim_client.healthcheck() if resp is None: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py index f75c5f3eae..1febe7fa52 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py @@ -1,6 +1,7 @@ import asyncio import json -from typing import Dict, List, Optional +from collections.abc import Mapping +from typing import List, Optional from sqlalchemy import Delete, delete, select from sqlalchemy.orm import joinedload @@ -164,9 +165,9 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint] ) -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +@runner_ssh_tunnel def _pull_runner_metrics( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], ) -> Optional[MetricsResponse]: - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) return runner_client.get_metrics() diff --git a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py index 5b039fe2ec..96b8cb7742 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py @@ -1,4 +1,5 @@ import uuid +from collections.abc import Mapping from datetime import datetime, timedelta from typing import Optional @@ -144,7 +145,9 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[str]: return res -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) -def _pull_job_metrics(ports: dict[int, int], task_id: uuid.UUID) -> Optional[str]: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) +@runner_ssh_tunnel +def _pull_job_metrics( + addresses: Mapping[int, client.LocalAddress], task_id: uuid.UUID +) -> Optional[str]: + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) return shim_client.get_task_metrics(task_id) diff --git a/src/dstack/_internal/server/services/gateways/connection.py b/src/dstack/_internal/server/services/gateways/connection.py index b8df322a1d..dada5bea64 100644 --- a/src/dstack/_internal/server/services/gateways/connection.py +++ b/src/dstack/_internal/server/services/gateways/connection.py @@ -1,9 +1,7 @@ import contextlib import shutil import uuid -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import AsyncIterator, Optional, Tuple +from typing import AsyncIterator, Optional import aiorwlock @@ -22,7 +20,7 @@ from dstack._internal.server.services.gateways.client import GatewayClient from dstack._internal.server.settings import SERVER_DIR_PATH from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.path import FileContent +from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir logger = get_logger(__name__) CONNECTIONS_DIR = SERVER_DIR_PATH / "gateway-connections" @@ -47,7 +45,9 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): self.connection_dir = CONNECTIONS_DIR / ip_address # connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink - self.temp_dir, self.connection_symlink_dir = self._init_symlink_dir(self.connection_dir) + self.temp_dir, self.connection_symlink_dir = make_tmp_symlink_to_dir( + self.connection_dir, "connection" + ) self.gateway_socket_path = self.connection_symlink_dir / "gateway.sock" self.tunnel = SSHTunnel( destination=f"ubuntu@{ip_address}", @@ -69,13 +69,6 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): self.tunnel_id = uuid.uuid4() self._client = GatewayClient(uds=str(self.gateway_socket_path)) - @staticmethod - def _init_symlink_dir(connection_dir: Path) -> Tuple[TemporaryDirectory, Path]: - temp_dir = TemporaryDirectory() - symlink_dir = Path(temp_dir.name) / "connection" - symlink_dir.symlink_to(connection_dir, target_is_directory=True) - return temp_dir, symlink_dir - async def check_or_restart(self) -> bool: async with self._lock.writer_lock: if not await self.tunnel.acheck(): diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 5dc0699113..2d149ab77b 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -347,13 +347,13 @@ async def stop_runner(job_model: JobModel, instance_model: InstanceModel): logger.debug("%s: failed to stop runner", fmt(job_model)) -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) +@runner_ssh_tunnel def _stop_runner( - ports: dict[int, int], + addresses: Mapping[int, client.LocalAddress], job_model: JobModel, ): logger.debug("%s: stopping runner", fmt(job_model)) - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) try: runner_client.stop() except requests.RequestException: diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 6a1c541856..7ccc2b1af7 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -1,11 +1,14 @@ +import urllib.parse import uuid from collections.abc import Generator from http import HTTPStatus +from pathlib import Path from typing import BinaryIO, Dict, List, Literal, Optional, TypeVar, Union, overload import packaging.version import requests import requests.exceptions +import requests_unixsocket from typing_extensions import Self from dstack._internal.core.errors import DstackError @@ -42,12 +45,16 @@ ) from dstack._internal.utils.common import get_or_error from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.path import PathLike REQUEST_TIMEOUT = 9 UPLOAD_CODE_REQUEST_TIMEOUT = 60 logger = get_logger(__name__) +LocalAddress = Union[int, Path] +"""A local TCP port or a Unix domain socket path the client connects to.""" + class RunnerClient: # `/api/upload_code` call is not required if there is no code @@ -59,12 +66,20 @@ class RunnerClient: def __init__( self, - port: int, + port: Optional[int] = None, hostname: str = "localhost", + uds: Optional[PathLike] = None, ): - self.secure = False - self.hostname = hostname - self.port = port + self._session, self._base_url = _make_session_and_base_url(port, hostname, uds) + + @classmethod + def from_address(cls, address: LocalAddress) -> Self: + """ + Builds a client from a TCP port (`int`) or a Unix domain socket path (`Path`). + """ + if isinstance(address, int): + return cls(port=address) + return cls(uds=address) def get_version_string(self) -> str: if not self._negotiated: @@ -90,7 +105,7 @@ def healthcheck(self) -> Optional[HealthcheckResponse]: return healthcheck_response def get_metrics(self) -> Optional[MetricsResponse]: - resp = requests.get(self._url("/api/metrics"), timeout=REQUEST_TIMEOUT) + resp = self._session.get(self._url("/api/metrics"), timeout=REQUEST_TIMEOUT) if resp.status_code == 404: return None resp.raise_for_status() @@ -134,7 +149,7 @@ def submit_job( log_quota_hour=quota if quota > 0 else None, run_spec=run.run_spec, ) - resp = requests.post( + resp = self._session.post( # use .json() to encode enums self._url("/api/submit"), data=body.json(), @@ -144,7 +159,7 @@ def submit_job( resp.raise_for_status() def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]): - resp = requests.post( + resp = self._session.post( self._url("/api/upload_archive"), files={"archive": (str(id), file)}, timeout=UPLOAD_CODE_REQUEST_TIMEOUT, @@ -152,13 +167,13 @@ def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]): resp.raise_for_status() def upload_code(self, file: Union[BinaryIO, bytes]): - resp = requests.post( + resp = self._session.post( self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT ) resp.raise_for_status() def run_job(self) -> Optional[JobInfoResponse]: - resp = requests.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT) + resp = self._session.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() if not _is_json_response(resp): # Old runner or runner failed to get job info @@ -166,21 +181,21 @@ def run_job(self) -> Optional[JobInfoResponse]: return JobInfoResponse.__response__.parse_obj(resp.json()) def pull(self, timestamp: int) -> PullResponse: - resp = requests.get( + resp = self._session.get( self._url("/api/pull"), params={"timestamp": timestamp}, timeout=REQUEST_TIMEOUT ) resp.raise_for_status() return PullResponse.__response__.parse_obj(resp.json()) def stop(self): - resp = requests.post(self._url("/api/stop"), timeout=REQUEST_TIMEOUT) + resp = self._session.post(self._url("/api/stop"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() def _url(self, path: str) -> str: - return f"{'https' if self.secure else 'http'}://{self.hostname}:{self.port}/{path.lstrip('/')}" + return f"{self._base_url}/{path.lstrip('/')}" def _healthcheck(self) -> HealthcheckResponse: - resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) + resp = self._session.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() return HealthcheckResponse.__response__.parse_obj(resp.json()) @@ -302,11 +317,20 @@ class ShimClient: def __init__( self, - port: int, + port: Optional[int] = None, hostname: str = "localhost", + uds: Optional[PathLike] = None, ): - self._session = requests.Session() - self._base_url = f"http://{hostname}:{port}" + self._session, self._base_url = _make_session_and_base_url(port, hostname, uds) + + @classmethod + def from_address(cls, address: LocalAddress) -> Self: + """ + Builds a client from a TCP port (`int`) or a Unix domain socket path (`Path`). + """ + if isinstance(address, int): + return cls(port=address) + return cls(uds=address) # Methods shared by all API versions @@ -626,6 +650,24 @@ def _get_restart_safe_task_statuses(self) -> list[TaskStatus]: return [TaskStatus.TERMINATED] +def _make_session_and_base_url( + port: Optional[int], hostname: str, uds: Optional[PathLike] +) -> tuple[requests.Session, str]: + """ + Builds a session and base URL for HTTP over TCP (`port`) or over + a Unix domain socket (`uds`). Exactly one of the two must be specified. + """ + if (port is None) == (uds is None): + raise ValueError("Either port or uds must be specified, not both") + session = requests.Session() + if uds is not None: + base_url = f"http+unix://{urllib.parse.quote(str(uds), safe='')}" + session.mount("http+unix://", requests_unixsocket.UnixAdapter()) + else: + base_url = f"http://{hostname}:{port}" + return session, base_url + + def healthcheck_response_to_instance_check( response: HealthcheckResponse, instance_health_response: Optional[InstanceHealthResponse] = None, diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py new file mode 100644 index 0000000000..b91d1d3125 --- /dev/null +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -0,0 +1,361 @@ +import os +import shutil +import threading +import time +from dataclasses import dataclass +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Collection, Optional, Union +from weakref import WeakValueDictionary + +from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import SSHError +from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData +from dstack._internal.core.services.ssh.tunnel import ( + SSH_DEFAULT_OPTIONS, + IPSocket, + SocketPair, + SSHTunnel, + UnixSocket, +) +from dstack._internal.server.settings import ( + SERVER_DIR_PATH, + SERVER_SSH_CONNECT_TIMEOUT, +) +from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir + +logger = get_logger(__name__) + +PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] +"""A host private key or pair of (host private key, optional proxy jump private key)""" + +CONNECTIONS_DIR = SERVER_DIR_PATH / "instance-connections" + +MIN_ALIVE_CHECK_INTERVAL = 30 +"""How often (at most) `InstanceConnection.is_alive()` runs `ssh -O check`, in seconds.""" + + +@dataclass(frozen=True) +class InstanceConnectionKey: + hostname: str + port: int + ports_to_forward: tuple[int, ...] + + @staticmethod + def from_jpd( + jpd: JobProvisioningData, jrd: Optional[JobRuntimeData] = None + ) -> "InstanceConnectionKey": + assert jpd.hostname is not None and jpd.ssh_port is not None + container_to_host_port_map = InstanceConnection.get_container_to_host_port_map(jpd, jrd) + return InstanceConnectionKey( + hostname=jpd.hostname, + port=jpd.ssh_port, + ports_to_forward=tuple(container_to_host_port_map.values()), + ) + + +# InstanceConnectionPool has sync interface because runner/shim clients and all the callers are sync. +# TODO: Consider moving all of them to async for consistency with other pools/clients. +class InstanceConnectionPool: + """ + A pool of SSH connections to instances' host sshd (VM-based) + or runner sshd (container-based) for forwarding shim and runner ports. + + NOTE: The pool is not currently intended for arbitrary ports forwarding, only for shim and runner ports. + E.g. it cannot be used to forward services ports for probes or router-worker communication. + This simplified model allows forwarding the same ports for the given host:port and reusing the connection across all calls. + TODO: Generalize to support arbitrary ports forwarding incl. job's ports. + + Incompatible with multiple server processes sharing the same server dir: + connection dirs and control sockets are assumed to be owned by a single process. + """ + + def __init__(self): + self._connections: dict[InstanceConnectionKey, InstanceConnection] = {} + # Use `WeakValueDictionary` for automatic GC of unused locks and avoid manual refcounting. + # A lock is expected to exist only while a thread holds a strong reference to it. + self._access_locks: WeakValueDictionary[InstanceConnectionKey, threading.Lock] = ( + WeakValueDictionary() + ) + self._access_locks_lock = threading.Lock() + self._closed = False + + def get_or_open( + self, + ssh_private_key: PrivateKeyOrPair, + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ) -> Optional["InstanceConnection"]: + """ + Starts a new SSH connection or returns an existing one. + Existing connections are checked for health periodically + so that subsequent calls to `get_or_open()` eventually return a healthy connection. + """ + key = InstanceConnectionKey.from_jpd(jpd, jrd) + lock = self._get_access_lock(key) + with lock: + if self._closed: + return None + conn = self._connections.get(key) + if conn is not None: + if conn.is_alive(): + return conn + # The master process is gone — evict and reopen. + logger.debug("Instance connection %s is dead, reopening", key) + self._connections.pop(key) + try: + conn.close() + except Exception: + logger.exception("Failed to close instance connection %s", key) + try: + conn = InstanceConnection(ssh_private_key, jpd, jrd) + conn.open() + except SSHError: + # error logged in tunnel + return None + self._connections[key] = conn + return conn + + def drop(self, key: InstanceConnectionKey) -> None: + lock = self._get_access_lock(key) + with lock: + try: + conn = self._connections.pop(key) + except KeyError: + return + try: + conn.close() + except Exception: + logger.exception("Failed to close instance connection %s", key) + + def startup_cleanup(self) -> None: + """ + Removes connection dirs left by a previous server process (e.g. after SIGKILL). + Must be called on server startup before the pool is used. + Leftover live masters are reaped by `ControlPersist`. + """ + shutil.rmtree(CONNECTIONS_DIR, ignore_errors=True) + + def close_all(self) -> None: + """ + Closes all connections and prevents new ones from being opened. + Safe to call concurrently with in-flight `get_or_open()` calls. + `get_or_open()` will return `None` after `close_all()`. + """ + with self._access_locks_lock: + self._closed = True + # self._connections holds cached connections, and + # self._access_locks may hold mid-open connections not yet cached. + keys = set(self._connections) | set(self._access_locks.keys()) + logger.debug("Closing %d instance connection(s)", len(keys)) + for key in keys: + self.drop(key) + + def _get_access_lock(self, key: InstanceConnectionKey) -> threading.Lock: + with self._access_locks_lock: + lock = self._access_locks.get(key) + if lock is not None: + return lock + lock = threading.Lock() + self._access_locks[key] = lock + return lock + + +instance_connection_pool = InstanceConnectionPool() + + +class InstanceConnection: + """ + An SSH connection to instance's host sshd (VM-based) + or runner sshd (container-based) for forwarding shim and runner ports. + + The same control socket is used for all connections to the same hostname:port, + unless jrd overrides the runner port mapped on host (blocks case). + In case of blocks, each job establishes a separate connection with a different runner port forwarded. + TODO: Re-use the same SSH connection for all blocks via `-O forward`/`-O cancel`. + """ + + def __init__( + self, + ssh_private_key: PrivateKeyOrPair, + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ephemeral: bool = False, + ) -> None: + """ + Args: + ephemeral: Creates a unique tmp dir for the UDS. Use when connection re-use is not needed. + """ + self._key = InstanceConnectionKey.from_jpd(jpd, jrd) + self._ephemeral = ephemeral + self._last_verified_at: float = 0.0 + self._temp_dir, self._effective_conn_dir, self._real_conn_dir = ( + InstanceConnection._resolve_conn_dir(self._key, ephemeral) + ) + self._control_socket_path = self._effective_conn_dir / "control.sock" + self._real_control_socket_path = self._real_conn_dir / "control.sock" + self._container_to_host_port_map = InstanceConnection.get_container_to_host_port_map( + jpd, jrd + ) + self._host_port_to_uds_map = InstanceConnection._get_host_port_to_uds_map( + conn_dir=self._effective_conn_dir, + ports_to_forward=self._key.ports_to_forward, + ) + self._tunnel = SSHTunnel( + destination=f"{jpd.username}@{jpd.hostname}", + port=jpd.ssh_port, + identity=InstanceConnection._get_identity(ssh_private_key), + control_sock_path=self._control_socket_path, + forwarded_sockets=self._get_forwarded_sockets(self._host_port_to_uds_map), + ssh_proxies=InstanceConnection._get_proxies(ssh_private_key, jpd), + options={ + **SSH_DEFAULT_OPTIONS, + "ConnectTimeout": str(SERVER_SSH_CONNECT_TIMEOUT), + # Auto-close half-opened connections (the instance not responding). + "ServerAliveInterval": "10", + "ServerAliveCountMax": "3", + # Set ControlPersist to auto-close orphaned background ssh process + # in case dstack server shutdown is not graceful. + "ControlPersist": "2m", + }, + batch_mode=True, + ) + + def open(self) -> None: + # A control socket left by a killed master or by a master that exited after + # its tmp symlink was deleted prevents ssh from becoming a mux master + # ("ControlSocket ... already exists, disabling multiplexing"). + # Remove it unless it's served by a live master (then open() attaches to it). + if self._real_control_socket_path.exists() and not self._tunnel.check(): + self._real_control_socket_path.unlink(missing_ok=True) + self._tunnel.open() + self._last_verified_at = time.monotonic() + + def is_alive(self) -> bool: + """ + Verifies that the connection's SSH master process is alive: + + 1. The control socket exists (a stat). Catches cleanly exited masters (incl. ControlPersist). + 2. `ssh -O check`. Catches killed masters that left a stale socket file behind. + Rate-limited to once per `MIN_ALIVE_CHECK_INTERVAL`. + + Does not detect half-open TCP (ServerAliveInterval converts it into a clean exit) + or mid-request deaths (handled by the callers' drop-on-error pattern). + """ + if not self._control_socket_path.exists(): + return False + now = time.monotonic() + if now - self._last_verified_at < MIN_ALIVE_CHECK_INTERVAL: + return True + if not self._tunnel.check(): + return False + # Keep the symlink fresh so that age-based /tmp cleanup is less likely to remove it. + try: + os.utime(self._effective_conn_dir, follow_symlinks=False) + except OSError: + pass + self._last_verified_at = now + return True + + def forwarded_paths(self) -> dict[int, Path]: + """Returns a mapping from container port to the local UDS path.""" + return { + container_port: self._host_port_to_uds_map[host_port] + for container_port, host_port in self._container_to_host_port_map.items() + } + + def close(self) -> None: + self._tunnel.close() + # Remove a stale control.sock left by a killed master, forwarded UDS files + # (ssh does not unlink them on exit), and the dir itself, so that + # CONNECTIONS_DIR does not accumulate dirs of gone instances. + # A master that survives close() because it is unreachable via a deleted + # symlink is reaped by ControlPersist. + shutil.rmtree(self._real_conn_dir, ignore_errors=True) + + @property + def key(self) -> InstanceConnectionKey: + return self._key + + @staticmethod + def get_container_to_host_port_map( + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ) -> dict[int, int]: + runner_host_port = DSTACK_RUNNER_HTTP_PORT + if jrd is not None and jrd.ports is not None: + runner_host_port = jrd.ports.get(DSTACK_RUNNER_HTTP_PORT, runner_host_port) + port_map = {DSTACK_RUNNER_HTTP_PORT: runner_host_port} + if jpd.dockerized: + port_map[DSTACK_SHIM_HTTP_PORT] = DSTACK_SHIM_HTTP_PORT + return port_map + + @staticmethod + def _resolve_conn_dir( + key: InstanceConnectionKey, ephemeral: bool + ) -> tuple[TemporaryDirectory, Path, Path]: + """ + Returns (temp dir to retain, dir to be used by ssh, real conn dir). + """ + if ephemeral: + temp_dir = TemporaryDirectory() + path = Path(temp_dir.name) + return temp_dir, path, path + + conn_dir = ( + CONNECTIONS_DIR + / f"{key.hostname}:{key.port},{','.join(map(str, key.ports_to_forward))}" + ) + conn_dir.mkdir(parents=True, exist_ok=True) + # Connection_dir can have a long path that won't be accepted by the ssh command, + # so we create a short temporary symlink. + # The symlink may be removed by age-based /tmp cleanup while the connection is still alive. + # The connection dir will be removed and the connection is re-opened. + temp_dir, conn_symlink_dir = make_tmp_symlink_to_dir( + dirpath=conn_dir, + symlink_dirname="connection", + ) + return temp_dir, conn_symlink_dir, conn_dir + + @staticmethod + def _get_host_port_to_uds_map( + conn_dir: Path, + ports_to_forward: Collection[int], + ) -> dict[int, Path]: + return {port: conn_dir / f"{port}.sock" for port in ports_to_forward} + + @staticmethod + def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> list[SocketPair]: + return [ + SocketPair( + local=UnixSocket(path=path), + remote=IPSocket(host="localhost", port=port), + ) + for port, path in host_port_to_uds_map.items() + ] + + @staticmethod + def _get_identity(ssh_private_key: PrivateKeyOrPair) -> FileContent: + if isinstance(ssh_private_key, tuple): + ssh_private_key, _ = ssh_private_key + return FileContent(ssh_private_key) + + @staticmethod + def _get_proxies( + ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData + ) -> list[tuple[SSHConnectionParams, FileContent]]: + if jpd.ssh_proxy is None: + return [] + + if isinstance(ssh_private_key, str): + ssh_proxy_private_key = ssh_private_key + else: + ssh_proxy_private_key = ssh_private_key[1] + if ssh_proxy_private_key is None: + # In case proxy key is None, fallback to main key (k8s case). + ssh_proxy_private_key = ssh_private_key[0] + + proxy_identity = FileContent(ssh_proxy_private_key) + return [(jpd.ssh_proxy, proxy_identity)] diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index a4ef986862..b1430fba6f 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -1,8 +1,6 @@ import functools -import socket -import time -from collections.abc import Iterable -from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union +from collections.abc import Mapping +from typing import Callable, Literal, Optional, TypeVar, Union import requests from typing_extensions import Concatenate, ParamSpec @@ -10,120 +8,100 @@ from dstack._internal.core.errors import DstackError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData -from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets -from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.path import FileContent +from dstack._internal.server import settings +from dstack._internal.server.services.runner.client import LocalAddress +from dstack._internal.server.services.runner.pool import ( + InstanceConnection, + PrivateKeyOrPair, + instance_connection_pool, +) -logger = get_logger(__name__) P = ParamSpec("P") R = TypeVar("R") -# A host private key or pair of (host private key, optional proxy jump private key) -PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] def runner_ssh_tunnel( - ports: List[int], retries: int = 3, retry_interval: float = 1 + func: Callable[Concatenate[Mapping[int, LocalAddress], P], R], ) -> Callable[ - [Callable[Concatenate[Dict[int, int], P], R]], - Callable[ - Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], - Union[Literal[False], R], - ], + Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], + Union[Literal[False], R], ]: """ - A decorator that opens an SSH tunnel to the runner. + A decorator that opens an SSH tunnel to the runner instance for port forwarding. - NOTE: connections from dstack-server to running jobs are expected to be short. - The runner uses a heuristic to differentiate dstack-server connections from - client connections based on their duration. See `ConnectionTracker` for details. - """ - - def decorator( - func: Callable[Concatenate[Dict[int, int], P], R], - ) -> Callable[ - Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], - Union[Literal[False], R], - ]: - @functools.wraps(func) - def wrapper( - ssh_private_key: PrivateKeyOrPair, - job_provisioning_data: JobProvisioningData, - job_runtime_data: Optional[JobRuntimeData], - *args: P.args, - **kwargs: P.kwargs, - ) -> Union[Literal[False], R]: - """ - Returns: - is successful - """ - # container:host mapping - container_ports_map = {port: port for port in ports} - if job_runtime_data is not None and job_runtime_data.ports is not None: - container_ports_map.update(job_runtime_data.ports) - - if job_provisioning_data.backend == BackendType.LOCAL: - # without SSH - return func(container_ports_map, *args, **kwargs) + Forwarded ports: + * VM-based backends: forward the shim and runner ports. + * Container-based backends: forward only the runner port. + * `jrd.ports` may remap the runner port (blocks case). - if isinstance(ssh_private_key, str): - ssh_proxy_private_key = None - else: - ssh_private_key, ssh_proxy_private_key = ssh_private_key - identity = FileContent(ssh_private_key) - if ssh_proxy_private_key is not None: - proxy_identity = FileContent(ssh_proxy_private_key) - else: - proxy_identity = None + Always forwards the same ports for the given instance/job so that connection is reused across all calls. + In case of blocks, each job uses a separate connection as the runner host port differs. - ssh_proxies = [] - if job_provisioning_data.ssh_proxy is not None: - ssh_proxies.append((job_provisioning_data.ssh_proxy, proxy_identity)) - - for attempt in range(retries): - last = attempt == retries - 1 - # remote_host:local mapping - tunnel_ports_map = _reserve_ports(container_ports_map.values()) - runner_ports_map = { - container_port: tunnel_ports_map[host_port] - for container_port, host_port in container_ports_map.items() - } - try: - with SSHTunnel( - destination=( - f"{job_provisioning_data.username}@{job_provisioning_data.hostname}" - ), - port=job_provisioning_data.ssh_port, - forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map), - identity=identity, - ssh_proxies=ssh_proxies, - batch_mode=True, - ): - return func(runner_ports_map, *args, **kwargs) - except SSHError: - pass # error is logged in the tunnel - except (DstackError, requests.RequestException) as e: - if last: - logger.debug( - "Cannot connect to %s's API: %s", job_provisioning_data.hostname, e - ) - if not last: - time.sleep(retry_interval) - return False + There are no retries: a transient transport failure fails the call, + and the callers must retry. In high-latency setups, tune `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`. + """ - return wrapper + @functools.wraps(func) + def wrapper( + ssh_private_key: PrivateKeyOrPair, + job_provisioning_data: JobProvisioningData, + job_runtime_data: Optional[JobRuntimeData], + *args: P.args, + **kwargs: P.kwargs, + ) -> Union[Literal[False], R]: + """ + Returns: + is successful + """ + if job_provisioning_data.backend == BackendType.LOCAL: + # without SSH + port_map = InstanceConnection.get_container_to_host_port_map( + job_provisioning_data, job_runtime_data + ) + return func(port_map, *args, **kwargs) - return decorator + if not settings.SERVER_SSH_POOL_ENABLED or not job_provisioning_data.dockerized: + # Connections from dstack-server to runner's sshd are expected to be short + # as the `inactivity_duration` feature distinguishes user and server connections based on duration. + # Do not re-use SSH connections for container-based backends. + # TODO: Drop `inactivity_duration` dependence on connection duration and re-use connections. + try: + conn = InstanceConnection( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ephemeral=True, + ) + conn.open() + except SSHError: + return False + try: + return func(conn.forwarded_paths(), *args, **kwargs) + except (DstackError, requests.RequestException): + return False + finally: + conn.close() + # First try a cached connection and, if it's dead, a new connection. + # Connections already cover against + # a) cleanly-exited master (ControlPersist reap); and + # b) stale control socket file left by killed master. + # (Because we cannot rely solely on connection errors from `func` – it may swallow the errors.) + # but we still want a fast retry in case master dies mid-request. + for _ in range(2): + conn = instance_connection_pool.get_or_open( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ) + if conn is None: + return False # couldn't establish at all + try: + return func(conn.forwarded_paths(), *args, **kwargs) + except (SSHError, requests.ConnectionError): + instance_connection_pool.drop(conn.key) # dead ssh connection, re-open + except (DstackError, requests.RequestException): + return False # reached runner, app-level fail; don't re-open ssh connection + return False -def _reserve_ports(ports: Iterable[int]) -> dict[int, int]: - sockets = [] - try: - for port in ports: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("localhost", 0)) # Bind to a free port provided by the host - sockets.append((port, s)) - return {port: s.getsockname()[1] for port, s in sockets} - finally: - for _, s in sockets: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.close() + return wrapper diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index f90aee339d..2845687e23 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -20,6 +20,7 @@ SERVER_DATA_DIR_PATH = SERVER_DIR_PATH / "data" SERVER_DATA_DIR_PATH.mkdir(parents=True, exist_ok=True) + DATABASE_URL = os.getenv( "DSTACK_DATABASE_URL", f"sqlite+aiosqlite:///{str(SERVER_DATA_DIR_PATH.absolute())}/sqlite.db" ) @@ -148,6 +149,11 @@ os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB ) +# TODO: Replace DSTACK_SERVER_SSH_POOL_ENABLED with DSTACK_SERVER_SSH_POOL_DISABLED +# as pool becomes opt-out and document the env var. +SERVER_SSH_POOL_ENABLED = os.getenv("DSTACK_SERVER_SSH_POOL_ENABLED") is not None +SERVER_SSH_CONNECT_TIMEOUT = int(os.getenv("DSTACK_SERVER_SSH_CONNECT_TIMEOUT", 3)) + # Development settings SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None diff --git a/src/dstack/_internal/utils/path.py b/src/dstack/_internal/utils/path.py index 18e0b7c812..07b8fdd664 100644 --- a/src/dstack/_internal/utils/path.py +++ b/src/dstack/_internal/utils/path.py @@ -1,6 +1,7 @@ import os from dataclasses import dataclass from pathlib import Path, PurePath, PurePosixPath +from tempfile import TemporaryDirectory from typing import Union PathLike = Union[str, os.PathLike] @@ -55,3 +56,12 @@ def is_absolute_posix_path(path: PathLike) -> bool: if str(path).startswith("~"): return True return PurePosixPath(path).is_absolute() + + +def make_tmp_symlink_to_dir( + dirpath: PathLike, symlink_dirname: str +) -> tuple[TemporaryDirectory, Path]: + temp_dir = TemporaryDirectory() + symlink_dir = Path(temp_dir.name) / symlink_dirname + symlink_dir.symlink_to(dirpath, target_is_directory=True) + return temp_dir, symlink_dir diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py index b555556881..33e57df016 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -543,7 +543,7 @@ def shim_client_mock( mock.list_tasks.return_value = TaskListResponse(tasks=[]) mock.is_safe_to_restart.return_value = False monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", + "dstack._internal.server.services.runner.client.ShimClient.from_address", Mock(return_value=mock), ) return mock diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 35a129e0f6..e308b89ce8 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -36,6 +36,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint, VolumeStatus +from dstack._internal.core.services.ssh.tunnel import SSHTunnel from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.jobs_running import ( ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS, @@ -61,7 +62,6 @@ TaskStatus, ) from dstack._internal.server.services.runner.client import RunnerClient, ShimClient -from dstack._internal.server.services.runner.ssh import SSHTunnel from dstack._internal.server.services.runs.replicas import RouterEnvStatus from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( @@ -116,7 +116,7 @@ def worker() -> JobRunningWorker: @pytest.fixture def ssh_tunnel_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: mock = MagicMock(spec_set=SSHTunnel) - monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", mock) + monkeypatch.setattr("dstack._internal.server.services.runner.pool.SSHTunnel", mock) return mock @@ -126,7 +126,8 @@ def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: mock.healthcheck.return_value = HealthcheckResponse(service="dstack-shim", version="latest") mock.get_task.return_value.image_pull_progress = None monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) + "dstack._internal.server.services.runner.client.ShimClient.from_address", + Mock(return_value=mock), ) return mock @@ -138,7 +139,8 @@ def runner_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: service="dstack-runner", version="0.0.1.dev2" ) monkeypatch.setattr( - "dstack._internal.server.services.runner.client.RunnerClient", Mock(return_value=mock) + "dstack._internal.server.services.runner.client.RunnerClient.from_address", + Mock(return_value=mock), ) return mock @@ -481,9 +483,9 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, patch( "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", @@ -561,7 +563,7 @@ async def test_runs_provisioning_job( before_processed_at = job.last_processed_at with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch.object(RunnerClient, "_healthcheck") as healthcheck_mock, patch.object(RunnerClient, "submit_job") as submit_job_mock, patch.object(RunnerClient, "upload_code") as upload_code_mock, @@ -1067,14 +1069,13 @@ async def test_pulling_shim_failed( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, ): from dstack._internal.core.errors import SSHError ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) events = await list_events(session) @@ -1084,15 +1085,14 @@ async def test_pulling_shim_failed( assert events[0].message == "Job became unreachable" with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, freeze_time(job.disconnected_at + timedelta(minutes=5)), ): from dstack._internal.core.errors import SSHError ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) assert job.status == JobStatus.TERMINATING @@ -1168,11 +1168,12 @@ async def test_provisioning_shim_force_stop_if_already_running_api_v1( instance_assigned=True, ) monkeypatch.setattr( - "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock()) + "dstack._internal.server.services.runner.pool.SSHTunnel", + Mock(return_value=MagicMock()), ) shim_client_mock = Mock() monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", + "dstack._internal.server.services.runner.client.ShimClient.from_address", Mock(return_value=shim_client_mock), ) shim_client_mock.healthcheck.return_value = HealthcheckResponse( @@ -1243,9 +1244,9 @@ async def test_master_job_waits_for_workers( await session.commit() with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel"), patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1342,9 +1343,9 @@ async def test_updates_running_job( with ( patch.object(server_settings, "SERVER_DIR_PATH", tmp_path), - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1365,9 +1366,9 @@ async def test_updates_running_job( await session.commit() with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1411,12 +1412,11 @@ async def test_running_job_disconnect_retries_then_terminates( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, ): ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) events = await list_events(session) @@ -1426,13 +1426,12 @@ async def test_running_job_disconnect_retries_then_terminates( assert events[0].message == "Job became unreachable" with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, freeze_time(job.disconnected_at + timedelta(minutes=5)), ): ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) assert job.status == JobStatus.TERMINATING @@ -1537,9 +1536,9 @@ async def test_inactivity_duration( instance_assigned=True, ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1649,9 +1648,9 @@ async def test_gpu_utilization( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -2127,7 +2126,8 @@ async def test_does_not_terminate_job_when_instance_access_is_valid( session=session, run=run, status=job_status, - job_provisioning_data=get_job_provisioning_data(dockerized=False), + # dockerized=True so that the shim port is forwarded for the PULLING case + job_provisioning_data=get_job_provisioning_data(dockerized=True), instance=instance, instance_assigned=True, ) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py index 6bc3a433ba..5ec769519b 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -384,8 +384,10 @@ async def test_terminates_job( await session.commit() with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as SSHTunnelMock, + patch( + "dstack._internal.server.services.runner.client.ShimClient.from_address" + ) as ShimClientMock, ): shim_client_mock = ShimClientMock.return_value await worker.process(_job_to_pipeline_item(job)) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py index df52dd88e2..1e3900a449 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py @@ -64,9 +64,9 @@ async def test_collects_metrics(self, test_db, session: AsyncSession): instance=instance, ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as SSHTunnelMock, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as RunnerClientMock, ): runner_client_mock = RunnerClientMock.return_value diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py index 80961d5c10..0775723b4d 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py @@ -73,12 +73,14 @@ async def job(self, request: pytest.FixtureRequest, session: AsyncSession) -> Jo @pytest.fixture def ssh_tunnel_mock(self) -> Generator[Mock, None, None]: - with patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock: + with patch("dstack._internal.server.services.runner.pool.SSHTunnel") as SSHTunnelMock: yield SSHTunnelMock @pytest.fixture def shim_client_mock(self) -> Generator[Mock, None, None]: - with patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock: + with patch( + "dstack._internal.server.services.runner.client.ShimClient.from_address" + ) as ShimClientMock: yield ShimClientMock.return_value @freeze_time(datetime(2023, 1, 2, 3, 5, 20, tzinfo=timezone.utc))