Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mkdocs/docs/reference/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ For more details on the options below, refer to the [server deployment](../guide
- `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY } – A default Docker registry to use for job images that do not specify an explicit registry. E.g., if set to `registry.example`, then `image: ubuntu` becomes equivalent to `image: registry.example/ubuntu`. **Note**: This setting should only be used for configuring registries that act as a pull-through cache for Docker Hub. The default `dstack` images are also pulled from the configured registry.
- `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME } – Username for authenticating with the default Docker registry. See `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`.
- `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD } – Password for authenticating with the default Docker registry. Applied only when the image has no explicit registry and the run configuration does not specify `registry_auth`. **Note**: The value may be visible to anyone who can SSH into instances managed by `dstack`, which usually includes all users of that `dstack` server.
- `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`{ #DSTACK_SERVER_SSH_CONNECT_TIMEOUT } – The SSH `ConnectTimeout` for server-instance connections, in seconds. Defaults to `3`. Increase if there are high-latency links between the server and instances.

??? info "Internal environment variables"
The following environment variables are intended for development purposes:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ server = [
"aiorwlock",
"aiocache",
"httpx>=0.28.0",
"requests-unixsocket>=0.4.1",
"jinja2",
"watchfiles",
"sqlalchemy[asyncio]>=2.0.0",
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/services/ssh/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ async def aclose(self) -> None:
proc.stdout,
)

# Blocking variant of the check: runs the command returned by
# `self.check_command()` as a subprocess, discarding its stdout and stderr,
# and returns True only when the process exits with status 0.
# `acheck` below launches the same command with the same redirections
# through asyncio.
def check(self) -> bool:
proc = subprocess.run(
self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
# Only the exit status is inspected; any non-zero code yields False.
return proc.returncode == 0

async def acheck(self) -> bool:
proc = await asyncio.create_subprocess_exec(
*self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from dstack._internal.server.services.projects import get_or_create_default_project
from dstack._internal.server.services.proxy.deps import ServerProxyDependencyInjector
from dstack._internal.server.services.proxy.routers import service_proxy
from dstack._internal.server.services.runner.pool import instance_connection_pool
from dstack._internal.server.services.storage import init_default_storage
from dstack._internal.server.services.users import get_or_create_admin_user
from dstack._internal.server.settings import (
Expand All @@ -75,6 +76,7 @@
get_client_version,
get_server_client_error_details,
)
from dstack._internal.utils.common import run_async
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import check_required_ssh_version

Expand Down Expand Up @@ -167,6 +169,8 @@ async def lifespan(app: FastAPI):
)
if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
init_default_storage()
if settings.SERVER_SSH_POOL_ENABLED:
await run_async(instance_connection_pool.startup_cleanup)
scheduler = None
pipeline_manager = None
if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
Expand Down Expand Up @@ -209,6 +213,8 @@ async def lifespan(app: FastAPI):
await gateway_connections_pool.remove_all()
service_conn_pool = await get_injector_from_app(app).get_service_connection_pool()
await service_conn_pool.remove_all()
if settings.SERVER_SSH_POOL_ENABLED:
await run_async(instance_connection_pool.close_all)
await get_db().engine.dispose()
# Let checked-out DB connections close as dispose() only closes checked-in connections
await asyncio.sleep(3)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import uuid
from collections.abc import Mapping
from datetime import timedelta
from typing import Optional

Expand Down Expand Up @@ -373,15 +374,15 @@ async def _get_backend_for_provisioning_wait(
)


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
@runner_ssh_tunnel
def _check_instance_inner(
ports: dict[int, int],
addresses: Mapping[int, runner_client.LocalAddress],
*,
instance: InstanceModel,
check_instance_health: bool = False,
) -> InstanceCheck:
instance_health_response: Optional[InstanceHealthResponse] = None
shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
shim_client = runner_client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
method = shim_client.healthcheck
try:
healthcheck_response = method(unmask_exceptions=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from dstack._internal.server.models import InstanceModel
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services.instances import get_instance_provisioning_data
from dstack._internal.server.services.runner.pool import (
InstanceConnectionKey,
instance_connection_pool,
)
from dstack._internal.utils.common import get_current_datetime, run_async
from dstack._internal.utils.logging import get_logger

Expand Down Expand Up @@ -77,6 +81,9 @@ async def terminate_instance(instance_model: InstanceModel) -> ProcessResult:
exc_info=not isinstance(exc, BackendError),
)

if job_provisioning_data is not None:
instance_connection_pool.drop(InstanceConnectionKey.from_jpd(job_provisioning_data))

result.instance_update_map["deleted"] = True
result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER
result.instance_update_map["finished_at"] = NOW_PLACEHOLDER
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import enum
import uuid
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, Iterable, Literal, Optional, Sequence, Union
Expand Down Expand Up @@ -1308,9 +1309,9 @@ def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> boo
return False


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
@runner_ssh_tunnel
def _process_provisioning_with_shim(
ports: Dict[int, int],
addresses: Mapping[int, client.LocalAddress],
run: Run,
job_model: JobModel,
jrd: Optional[JobRuntimeData],
Expand All @@ -1322,7 +1323,7 @@ def _process_provisioning_with_shim(
ssh_key: Optional[str],
) -> bool:
job_spec = get_job_spec(job_model)
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])

resp = shim_client.healthcheck()
if resp is None:
Expand Down Expand Up @@ -1435,21 +1436,21 @@ class _SyncShimPullingStateResult:
image_pull_progress: Optional[ImagePullProgress] = None


@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability:
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
@runner_ssh_tunnel
def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _RunnerAvailability:
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
if runner_client.healthcheck() is None:
return _RunnerAvailability.UNAVAILABLE
return _RunnerAvailability.AVAILABLE


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
@runner_ssh_tunnel
def _sync_shim_pulling_state(
ports: Dict[int, int],
addresses: Mapping[int, client.LocalAddress],
job_model: JobModel,
jrd: Optional[JobRuntimeData] = None,
) -> Union[_SyncShimPullingStateResult, Literal[False]]:
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
image_pull_progress: Optional[ImagePullProgress] = None
if shim_client.is_api_v2_supported():
task = shim_client.get_task(job_model.id)
Expand Down Expand Up @@ -1525,9 +1526,9 @@ class _SubmitJobToRunnerResult:
job_runtime_data: Optional[JobRuntimeData] = None


@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
@runner_ssh_tunnel
def _submit_job_to_runner(
ports: Dict[int, int],
addresses: Mapping[int, client.LocalAddress],
run: Run,
job_model: JobModel,
job: Job,
Expand All @@ -1552,7 +1553,7 @@ def _submit_job_to_runner(
else:
instance_env = None

runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
if runner_client.healthcheck() is None:
return _SubmitJobToRunnerResult(success=success_if_not_available)

Expand Down Expand Up @@ -1595,13 +1596,13 @@ class _ProcessRunningResult:
job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap)


@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT])
@runner_ssh_tunnel
def _process_running(
ports: Dict[int, int],
addresses: Mapping[int, client.LocalAddress],
run_model: RunModel,
job_model: JobModel,
) -> Union[_ProcessRunningResult, Literal[False]]:
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
timestamp = job_model.runner_timestamp or 0
resp = runner_client.pull(timestamp)
logs_services.write_logs(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import uuid
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional, Sequence, TypedDict
Expand Down Expand Up @@ -852,9 +853,9 @@ async def _stop_container(
return True


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool:
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
@runner_ssh_tunnel
def _shim_submit_stop(addresses: Mapping[int, client.LocalAddress], job_model: JobModel) -> bool:
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])

resp = shim_client.healthcheck()
if resp is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
from typing import Dict, List, Optional
from collections.abc import Mapping
from typing import List, Optional

from sqlalchemy import Delete, delete, select
from sqlalchemy.orm import joinedload
Expand Down Expand Up @@ -164,9 +165,9 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]
)


@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
@runner_ssh_tunnel
def _pull_runner_metrics(
ports: Dict[int, int],
addresses: Mapping[int, client.LocalAddress],
) -> Optional[MetricsResponse]:
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
return runner_client.get_metrics()
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Optional

Expand Down Expand Up @@ -144,7 +145,9 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[str]:
return res


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
def _pull_job_metrics(ports: dict[int, int], task_id: uuid.UUID) -> Optional[str]:
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
@runner_ssh_tunnel
def _pull_job_metrics(
addresses: Mapping[int, client.LocalAddress], task_id: uuid.UUID
) -> Optional[str]:
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
return shim_client.get_task_metrics(task_id)
17 changes: 5 additions & 12 deletions src/dstack/_internal/server/services/gateways/connection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import contextlib
import shutil
import uuid
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import AsyncIterator, Optional, Tuple
from typing import AsyncIterator, Optional

import aiorwlock

Expand All @@ -22,7 +20,7 @@
from dstack._internal.server.services.gateways.client import GatewayClient
from dstack._internal.server.settings import SERVER_DIR_PATH
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import FileContent
from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir

logger = get_logger(__name__)
CONNECTIONS_DIR = SERVER_DIR_PATH / "gateway-connections"
Expand All @@ -47,7 +45,9 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
self.connection_dir = CONNECTIONS_DIR / ip_address
# connection_dir can have a long path that won't be accepted by the ssh command,
# so we create a short temporary symlink
self.temp_dir, self.connection_symlink_dir = self._init_symlink_dir(self.connection_dir)
self.temp_dir, self.connection_symlink_dir = make_tmp_symlink_to_dir(
self.connection_dir, "connection"
)
self.gateway_socket_path = self.connection_symlink_dir / "gateway.sock"
self.tunnel = SSHTunnel(
destination=f"ubuntu@{ip_address}",
Expand All @@ -69,13 +69,6 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
self.tunnel_id = uuid.uuid4()
self._client = GatewayClient(uds=str(self.gateway_socket_path))

# Creates a new temporary directory and, inside it, a symlink named
# "connection" that points at `connection_dir`.
# Returns a tuple of (the TemporaryDirectory object, the symlink path).
# The caller in `__init__` stores the TemporaryDirectory on the instance as
# `self.temp_dir` - presumably so the directory is not cleaned up while the
# connection is in use; confirm against the rest of the class.
@staticmethod
def _init_symlink_dir(connection_dir: Path) -> Tuple[TemporaryDirectory, Path]:
temp_dir = TemporaryDirectory()
symlink_dir = Path(temp_dir.name) / "connection"
# target_is_directory only has an effect on Windows; elsewhere it is ignored.
symlink_dir.symlink_to(connection_dir, target_is_directory=True)
return temp_dir, symlink_dir

async def check_or_restart(self) -> bool:
async with self._lock.writer_lock:
if not await self.tunnel.acheck():
Expand Down
6 changes: 3 additions & 3 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,13 @@ async def stop_runner(job_model: JobModel, instance_model: InstanceModel):
logger.debug("%s: failed to stop runner", fmt(job_model))


@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT])
@runner_ssh_tunnel
def _stop_runner(
ports: dict[int, int],
addresses: Mapping[int, client.LocalAddress],
job_model: JobModel,
):
logger.debug("%s: stopping runner", fmt(job_model))
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
try:
runner_client.stop()
except requests.RequestException:
Expand Down
Loading
Loading