Skip to content

Commit 1203e3e

Browse files
authored
Implement SSH connection pool for runner instances (#3936)
* Prototype InstanceConnectionPool * Make runner and shim client work over uds * Update runner and shim client call sites * Pool fixes * Revert args rename * Skip pool for container backends * Use dstack tmp dir * Drop ports from runner_ssh_tunnel * Refactor methods * Implement InstanceConnectionPool.close_all * Add DSTACK_SERVER_SSH_POOL_DISABLED * Check SSH connection health * Tweak ssh options * Drop retries from runner_ssh_tunnel * Update env docs * Clean up locks with WeakValueDictionary * Clean up control socket * Survive tmp cleanup * Make ssh pool opt-in * Fix tests * Minor fixes * Minor fixes * Fix typo * Run startup/teardown cleanup only if pool enabled * Drop connection on instance termination
1 parent 2cda42e commit 1203e3e

22 files changed

Lines changed: 625 additions & 203 deletions

File tree

mkdocs/docs/reference/env.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ For more details on the options below, refer to the [server deployment](../guide
141141
- `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY } – A default Docker registry to use for job images that do not specify an explicit registry. E.g., if set to `registry.example`, then `image: ubuntu` becomes equivalent to `image: registry.example/ubuntu`. **Note**: This setting should only be used for configuring registries that act as a pull-through cache for Docker Hub. The default `dstack` images are also pulled from the configured registry.
142142
- `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME } – Username for authenticating with the default Docker registry. See `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`.
143143
- `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD } – Password for authenticating with the default Docker registry. Applied only when the image has no explicit registry and the run configuration does not specify `registry_auth`. **Note**: The value may be visible to anyone who can SSH into instances managed by `dstack`, which usually includes all users of that `dstack` server.
144+
- `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`{ #DSTACK_SERVER_SSH_CONNECT_TIMEOUT } – The SSH `ConnectTimeout` for server-instance connections, in seconds. Defaults to `3`. Increase if there are high-latency links between the server and instances.
144145

145146
??? info "Internal environment variables"
146147
The following environment variables are intended for development purposes:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ server = [
186186
"aiorwlock",
187187
"aiocache",
188188
"httpx>=0.28.0",
189+
"requests-unixsocket>=0.4.1",
189190
"jinja2",
190191
"watchfiles",
191192
"sqlalchemy[asyncio]>=2.0.0",

src/dstack/_internal/core/services/ssh/tunnel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ async def aclose(self) -> None:
252252
proc.stdout,
253253
)
254254

255+
def check(self) -> bool:
256+
proc = subprocess.run(
257+
self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
258+
)
259+
return proc.returncode == 0
260+
255261
async def acheck(self) -> bool:
256262
proc = await asyncio.create_subprocess_exec(
257263
*self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL

src/dstack/_internal/server/app.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from dstack._internal.server.services.projects import get_or_create_default_project
5858
from dstack._internal.server.services.proxy.deps import ServerProxyDependencyInjector
5959
from dstack._internal.server.services.proxy.routers import service_proxy
60+
from dstack._internal.server.services.runner.pool import instance_connection_pool
6061
from dstack._internal.server.services.storage import init_default_storage
6162
from dstack._internal.server.services.users import get_or_create_admin_user
6263
from dstack._internal.server.settings import (
@@ -75,6 +76,7 @@
7576
get_client_version,
7677
get_server_client_error_details,
7778
)
79+
from dstack._internal.utils.common import run_async
7880
from dstack._internal.utils.logging import get_logger
7981
from dstack._internal.utils.ssh import check_required_ssh_version
8082

@@ -167,6 +169,8 @@ async def lifespan(app: FastAPI):
167169
)
168170
if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
169171
init_default_storage()
172+
if settings.SERVER_SSH_POOL_ENABLED:
173+
await run_async(instance_connection_pool.startup_cleanup)
170174
scheduler = None
171175
pipeline_manager = None
172176
if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
@@ -209,6 +213,8 @@ async def lifespan(app: FastAPI):
209213
await gateway_connections_pool.remove_all()
210214
service_conn_pool = await get_injector_from_app(app).get_service_connection_pool()
211215
await service_conn_pool.remove_all()
216+
if settings.SERVER_SSH_POOL_ENABLED:
217+
await run_async(instance_connection_pool.close_all)
212218
await get_db().engine.dispose()
213219
# Let checked-out DB connections close as dispose() only closes checked-in connections
214220
await asyncio.sleep(3)

src/dstack/_internal/server/background/pipeline_tasks/instances/check.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import uuid
3+
from collections.abc import Mapping
34
from datetime import timedelta
45
from typing import Optional
56

@@ -373,15 +374,15 @@ async def _get_backend_for_provisioning_wait(
373374
)
374375

375376

376-
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
377+
@runner_ssh_tunnel
377378
def _check_instance_inner(
378-
ports: dict[int, int],
379+
addresses: Mapping[int, runner_client.LocalAddress],
379380
*,
380381
instance: InstanceModel,
381382
check_instance_health: bool = False,
382383
) -> InstanceCheck:
383384
instance_health_response: Optional[InstanceHealthResponse] = None
384-
shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
385+
shim_client = runner_client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
385386
method = shim_client.healthcheck
386387
try:
387388
healthcheck_response = method(unmask_exceptions=True)

src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from dstack._internal.server.models import InstanceModel
1212
from dstack._internal.server.services import backends as backends_services
1313
from dstack._internal.server.services.instances import get_instance_provisioning_data
14+
from dstack._internal.server.services.runner.pool import (
15+
InstanceConnectionKey,
16+
instance_connection_pool,
17+
)
1418
from dstack._internal.utils.common import get_current_datetime, run_async
1519
from dstack._internal.utils.logging import get_logger
1620

@@ -77,6 +81,9 @@ async def terminate_instance(instance_model: InstanceModel) -> ProcessResult:
7781
exc_info=not isinstance(exc, BackendError),
7882
)
7983

84+
if job_provisioning_data is not None:
85+
instance_connection_pool.drop(InstanceConnectionKey.from_jpd(job_provisioning_data))
86+
8087
result.instance_update_map["deleted"] = True
8188
result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER
8289
result.instance_update_map["finished_at"] = NOW_PLACEHOLDER

src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import enum
33
import uuid
4+
from collections.abc import Mapping
45
from dataclasses import dataclass, field
56
from datetime import datetime, timedelta
67
from typing import Dict, Iterable, Literal, Optional, Sequence, Union
@@ -1308,9 +1309,9 @@ def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> boo
13081309
return False
13091310

13101311

1311-
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
1312+
@runner_ssh_tunnel
13121313
def _process_provisioning_with_shim(
1313-
ports: Dict[int, int],
1314+
addresses: Mapping[int, client.LocalAddress],
13141315
run: Run,
13151316
job_model: JobModel,
13161317
jrd: Optional[JobRuntimeData],
@@ -1322,7 +1323,7 @@ def _process_provisioning_with_shim(
13221323
ssh_key: Optional[str],
13231324
) -> bool:
13241325
job_spec = get_job_spec(job_model)
1325-
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
1326+
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
13261327

13271328
resp = shim_client.healthcheck()
13281329
if resp is None:
@@ -1435,21 +1436,21 @@ class _SyncShimPullingStateResult:
14351436
image_pull_progress: Optional[ImagePullProgress] = None
14361437

14371438

1438-
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
1439-
def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability:
1440-
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
1439+
@runner_ssh_tunnel
1440+
def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _RunnerAvailability:
1441+
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
14411442
if runner_client.healthcheck() is None:
14421443
return _RunnerAvailability.UNAVAILABLE
14431444
return _RunnerAvailability.AVAILABLE
14441445

14451446

1446-
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
1447+
@runner_ssh_tunnel
14471448
def _sync_shim_pulling_state(
1448-
ports: Dict[int, int],
1449+
addresses: Mapping[int, client.LocalAddress],
14491450
job_model: JobModel,
14501451
jrd: Optional[JobRuntimeData] = None,
14511452
) -> Union[_SyncShimPullingStateResult, Literal[False]]:
1452-
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
1453+
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
14531454
image_pull_progress: Optional[ImagePullProgress] = None
14541455
if shim_client.is_api_v2_supported():
14551456
task = shim_client.get_task(job_model.id)
@@ -1525,9 +1526,9 @@ class _SubmitJobToRunnerResult:
15251526
job_runtime_data: Optional[JobRuntimeData] = None
15261527

15271528

1528-
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
1529+
@runner_ssh_tunnel
15291530
def _submit_job_to_runner(
1530-
ports: Dict[int, int],
1531+
addresses: Mapping[int, client.LocalAddress],
15311532
run: Run,
15321533
job_model: JobModel,
15331534
job: Job,
@@ -1552,7 +1553,7 @@ def _submit_job_to_runner(
15521553
else:
15531554
instance_env = None
15541555

1555-
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
1556+
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
15561557
if runner_client.healthcheck() is None:
15571558
return _SubmitJobToRunnerResult(success=success_if_not_available)
15581559

@@ -1595,13 +1596,13 @@ class _ProcessRunningResult:
15951596
job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap)
15961597

15971598

1598-
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT])
1599+
@runner_ssh_tunnel
15991600
def _process_running(
1600-
ports: Dict[int, int],
1601+
addresses: Mapping[int, client.LocalAddress],
16011602
run_model: RunModel,
16021603
job_model: JobModel,
16031604
) -> Union[_ProcessRunningResult, Literal[False]]:
1604-
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
1605+
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
16051606
timestamp = job_model.runner_timestamp or 0
16061607
resp = runner_client.pull(timestamp)
16071608
logs_services.write_logs(

src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import uuid
3+
from collections.abc import Mapping
34
from dataclasses import dataclass, field
45
from datetime import datetime, timedelta
56
from typing import Optional, Sequence, TypedDict
@@ -852,9 +853,9 @@ async def _stop_container(
852853
return True
853854

854855

855-
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
856-
def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool:
857-
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
856+
@runner_ssh_tunnel
857+
def _shim_submit_stop(addresses: Mapping[int, client.LocalAddress], job_model: JobModel) -> bool:
858+
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
858859

859860
resp = shim_client.healthcheck()
860861
if resp is None:

src/dstack/_internal/server/background/scheduled_tasks/metrics.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
3-
from typing import Dict, List, Optional
3+
from collections.abc import Mapping
4+
from typing import List, Optional
45

56
from sqlalchemy import Delete, delete, select
67
from sqlalchemy.orm import joinedload
@@ -164,9 +165,9 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]
164165
)
165166

166167

167-
@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
168+
@runner_ssh_tunnel
168169
def _pull_runner_metrics(
169-
ports: Dict[int, int],
170+
addresses: Mapping[int, client.LocalAddress],
170171
) -> Optional[MetricsResponse]:
171-
runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
172+
runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT])
172173
return runner_client.get_metrics()

src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from collections.abc import Mapping
23
from datetime import datetime, timedelta
34
from typing import Optional
45

@@ -144,7 +145,9 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[str]:
144145
return res
145146

146147

147-
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
148-
def _pull_job_metrics(ports: dict[int, int], task_id: uuid.UUID) -> Optional[str]:
149-
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
148+
@runner_ssh_tunnel
149+
def _pull_job_metrics(
150+
addresses: Mapping[int, client.LocalAddress], task_id: uuid.UUID
151+
) -> Optional[str]:
152+
shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT])
150153
return shim_client.get_task_metrics(task_id)

0 commit comments

Comments
 (0)