Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
ConfigurationWithPortsParams,
DevEnvironmentConfiguration,
PortMapping,
RunAttachConfiguration,
RunAttachParams,
RunConfigurationType,
ServiceConfiguration,
TaskConfiguration,
Expand All @@ -57,6 +59,12 @@
get_repo_creds_and_default_branch,
load_repo,
)
from dstack._internal.core.services.ssh.attach import (
SSHProxyAwsSSMConfig,
SSHProxyCommandConfig,
SSHProxyConfig,
SSHProxyJumpConfig,
)
from dstack._internal.utils.common import local_time
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -241,8 +249,28 @@
bind_address: Optional[str] = getattr(
configurator_args, _BIND_ADDRESS_ARG, None
)
# Map the attach.proxy settings to the original configuration
attach_proxy_config = SSHProxyConfig()
if isinstance(conf, RunAttachConfiguration) and conf.attach is not None:
attach: RunAttachParams = conf.attach
if attach.proxy.type == "jump":
attach_proxy_config = SSHProxyJumpConfig(attach.proxy.proxy_jump)
elif attach.proxy.type == "command":
attach_proxy_config = SSHProxyCommandConfig(attach.proxy.proxy_command)
elif attach.proxy.type == "aws-ssm":
attach_proxy_config = SSHProxyAwsSSMConfig(
profile=attach.proxy.profile,

Check failure on line 262 in src/dstack/_internal/cli/services/configurators/run.py

View workflow job for this annotation

GitHub Actions / build-artifacts / python-test (ubuntu-latest, 3.9)

No parameter named "profile" (reportCallIssue)
region=attach.proxy.region,

Check failure on line 263 in src/dstack/_internal/cli/services/configurators/run.py

View workflow job for this annotation

GitHub Actions / build-artifacts / python-test (ubuntu-latest, 3.9)

No parameter named "region" (reportCallIssue)
document_name=attach.proxy.document_name,

Check failure on line 264 in src/dstack/_internal/cli/services/configurators/run.py

View workflow job for this annotation

GitHub Actions / build-artifacts / python-test (ubuntu-latest, 3.9)

No parameter named "document_name" (reportCallIssue)
)
if attach.proxy.type != "none":
console.print(
f"Using client-side attach proxy: [code]{attach.proxy.type}[/]"
)
try:
if run.attach(bind_address=bind_address):
if run.attach(
bind_address=bind_address, attach_proxy_config=attach_proxy_config
):
for entry in run.logs():
sys.stdout.buffer.write(entry)
sys.stdout.buffer.flush()
Expand Down
74 changes: 74 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,11 +644,83 @@ def schema_extra(schema: Dict[str, Any]):
BaseRunConfigurationConfig.schema_extra(schema)


AttachProxyType = Literal["none", "jump", "command", "aws-ssm"]


class AttachProxyNone(CoreModel):
type: Annotated[
Literal["none"], Field(description="No client-side proxy for the first SSH hop")
] = "none"


class AttachProxyJump(CoreModel):
type: Annotated[
Literal["jump"],
Field(description="Use ProxyJump on the client-side for the first SSH hop"),
] = "jump"
proxy_jump: Annotated[
str,
Field(description="Host alias from ~/.ssh/config for using in ProxyJump"),
]


class AttachProxyCommand(CoreModel):
type: Annotated[
Literal["command"],
Field(description="Use ProxyCommand on the client-side for the first SSH hop"),
] = "command"
proxy_command: Annotated[
str,
Field(
description=(
"ProxyCommand string to execute for the first hop."
" The value is passed as-is to ssh_config."
" If you need stream forwarding through SSH, include '-W %h:%p' yourself."
)
),
]


class AttachProxyAwsSSM(CoreModel):
type: Annotated[
Literal["aws-ssm"], Field(description="Use AWS SSM as a proxy for the first SSH hop")
] = "aws-ssm"
profile: Annotated[Optional[str], Field(description="AWS profile name to use")] = None
region: Annotated[Optional[str], Field(description="AWS region for SSM")] = None
document_name: Annotated[
str,
Field(description="SSM document name for SSH session"),
] = "AWS-StartSSHSession"


class RunAttachParams(CoreModel):
proxy: Annotated[
Union[
AttachProxyNone,
AttachProxyJump,
AttachProxyCommand,
AttachProxyAwsSSM,
],
Field(
discriminator="type",
description="Client-side SSH transport overrides for attach",
),
] = AttachProxyNone()


class RunAttachConfiguration(CoreModel):
attach: Annotated[
RunAttachParams,
Field(description="Attach transport settings (client-side only)", exclude=True),
] = RunAttachParams()


class DevEnvironmentConfiguration(
ProfileParams,
BaseRunConfiguration,
ConfigurationWithPortsParams,
DevEnvironmentConfigurationParams,
RunAttachConfiguration,
generate_dual_core_model(DevEnvironmentConfigurationConfig),
):
type: Literal["dev-environment"] = "dev-environment"
Expand Down Expand Up @@ -680,6 +752,7 @@ class TaskConfiguration(
ConfigurationWithCommandsParams,
ConfigurationWithPortsParams,
TaskConfigurationParams,
RunAttachConfiguration,
generate_dual_core_model(TaskConfigurationConfig),
):
type: Literal["task"] = "task"
Expand Down Expand Up @@ -838,6 +911,7 @@ class ServiceConfiguration(
BaseRunConfiguration,
ConfigurationWithCommandsParams,
ServiceConfigurationParams,
RunAttachConfiguration,
generate_dual_core_model(ServiceConfigurationConfig),
):
type: Literal["service"] = "service"
Expand Down
94 changes: 93 additions & 1 deletion src/dstack/_internal/core/services/ssh/attach.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import atexit
import dataclasses
import re
import time
from pathlib import Path
from typing import Optional, Union

import psutil

from dstack._internal.core.errors import SSHError
from dstack._internal.core.errors import ConfigurationError, SSHError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobProvisioningData
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.ssh.client import get_ssh_client_info
from dstack._internal.core.services.ssh.ports import PortsLock
Expand All @@ -25,6 +28,82 @@
_SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P<local_port>\d+):localhost:(?P<remote_port>\d+)")


HostConfigType = dict[str, Union[str, int, FilePath]]


class SSHProxyConfig:
"""Do nothing"""

def update_host(self, host: HostConfigType):
pass

def apply_provisioning_data(self, provisioning_data: JobProvisioningData):
pass


@dataclasses.dataclass
class SSHProxyJumpConfig(SSHProxyConfig):
"""Add ProxyJump to the given host configuration"""

jump_host: str

def update_host(self, host: HostConfigType):
host["ProxyJump"] = self.jump_host


@dataclasses.dataclass
class SSHProxyCommandConfig(SSHProxyConfig):
"""Add ProxyCommand to the given host configuration"""

command: str

def update_host(self, host: HostConfigType):
host["ProxyCommand"] = self.command


@dataclasses.dataclass(kw_only=True)

Check failure on line 64 in src/dstack/_internal/core/services/ssh/attach.py

View workflow job for this annotation

GitHub Actions / build-artifacts / python-test (ubuntu-latest, 3.9)

No overloads for "dataclass" match the provided arguments   Argument types: (Literal[True]) (reportCallIssue)
class SSHProxyAwsSSMConfig(SSHProxyConfig):
"""Add ProxyCommand to use AWS SSM"""

profile: Optional[str] = None
region: Optional[str] = None
document_name: Optional[str] = None
instance_id: Optional[str] = None
instance_region: Optional[str] = None

def update_host(self, host: HostConfigType):
if self.instance_id:
host["HostName"] = self.instance_id
region = self.region if self.region else self.instance_region
document = self.document_name if self.document_name else "AWS-StartSSHSession"
args = [
"aws",
"ssm",
"start-session",
"--target",
"%h",
"--document-name",
document,
"--parameters",
"portNumber=%p",
]
if region:
args.extend(["--region", region])
if self.profile:
args.extend(["--profile", self.profile])
command = f"sh -c '{' '.join(map(str, args))}'"
host["ProxyCommand"] = command

def apply_provisioning_data(self, provisioning_data: JobProvisioningData):
backend = provisioning_data.get_base_backend()
if backend != BackendType.AWS:
raise ConfigurationError(
"attach.proxy.type=aws-ssm is supported only for the AWS backend"
)
self.instance_id = provisioning_data.instance_id
self.instance_region = provisioning_data.region


class SSHAttach:
@classmethod
def get_control_sock_path(cls, run_name: str) -> Path:
Expand Down Expand Up @@ -67,6 +146,7 @@
service_port: Optional[int] = None,
local_backend: bool = False,
bind_address: Optional[str] = None,
proxy_config: Optional[SSHProxyConfig] = None,
):
self._ports_lock = ports_lock
self.ports = ports_lock.dict()
Expand Down Expand Up @@ -199,6 +279,18 @@
}
)

if proxy_config:
# Apply proxy configuration for the first hop connection
first_hop_key: Optional[str] = None
if f"{run_name}-jump-host" in hosts:
first_hop_key = f"{run_name}-jump-host"
elif f"{run_name}-host" in hosts:
first_hop_key = f"{run_name}-host"
else:
first_hop_key = run_name

proxy_config.update_host(hosts[first_hop_key])

def attach(self):
include_ssh_config(self.ssh_config_path)
for host, options in self.hosts.items():
Expand Down
7 changes: 6 additions & 1 deletion src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from dstack._internal.core.models.runs import Run as RunModel
from dstack._internal.core.services.logs import URLReplacer
from dstack._internal.core.services.ssh.attach import SSHAttach
from dstack._internal.core.services.ssh.attach import SSHAttach, SSHProxyConfig
from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.utils.common import get_or_error, make_proxy_url
Expand Down Expand Up @@ -259,6 +259,7 @@ def attach(
ports_overrides: Optional[List[PortMapping]] = None,
replica_num: Optional[int] = None,
job_num: int = 0,
attach_proxy_config: Optional[SSHProxyConfig] = None,
) -> bool:
"""
Establish an SSH tunnel to the instance and update SSH config
Expand Down Expand Up @@ -347,6 +348,9 @@ def attach(
if isinstance(self._run.run_spec.configuration, ServiceConfiguration):
service_port = get_service_port(job.job_spec, self._run.run_spec.configuration)

if attach_proxy_config:
attach_proxy_config.apply_provisioning_data(provisioning_data)

self._ssh_attach = SSHAttach(
hostname=provisioning_data.hostname,
ssh_port=provisioning_data.ssh_port,
Expand All @@ -361,6 +365,7 @@ def attach(
service_port=service_port,
local_backend=provisioning_data.backend == BackendType.LOCAL,
bind_address=bind_address,
proxy_config=attach_proxy_config,
)
if not ports_lock:
self._ssh_attach.attach()
Expand Down
Loading