Skip to content

Commit 4b4c55e

Browse files
authored
Add shell run property (#2542)
Closes: #2530
1 parent 2102b1b commit 4b4c55e

6 files changed

Lines changed: 139 additions & 11 deletions

File tree

src/dstack/_internal/core/models/configurations.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
from collections import Counter
33
from enum import Enum
4+
from pathlib import PurePosixPath
45
from typing import Any, Dict, List, Optional, Union
56

67
from pydantic import Field, ValidationError, conint, constr, root_validator, validator
@@ -210,6 +211,16 @@ class BaseRunConfiguration(CoreModel):
210211
Env,
211212
Field(description="The mapping or the list of environment variables"),
212213
] = Env()
214+
shell: Annotated[
215+
Optional[str],
216+
Field(
217+
description=(
218+
"The shell used to run commands."
219+
" Allowed values are `sh`, `bash`, or an absolute path, e.g., `/usr/bin/zsh`."
220+
" Defaults to `/bin/sh` if the `image` is specified, `/bin/bash` otherwise"
221+
)
222+
),
223+
] = None
213224
# deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
214225
setup: CommandsList = []
215226
resources: Annotated[
@@ -244,6 +255,17 @@ def validate_user(cls, v) -> Optional[str]:
244255
UnixUser.parse(v)
245256
return v
246257

258+
@validator("shell")
259+
def validate_shell(cls, v) -> Optional[str]:
260+
if v is None:
261+
return None
262+
if v in ["sh", "bash"]:
263+
return v
264+
path = PurePosixPath(v)
265+
if path.is_absolute():
266+
return v
267+
raise ValueError("The value must be `sh`, `bash`, or an absolute path")
268+
247269

248270
class BaseRunConfigurationWithPorts(BaseRunConfiguration):
249271
ports: Annotated[
@@ -261,7 +283,7 @@ def convert_ports(cls, v) -> PortMapping:
261283

262284

263285
class BaseRunConfigurationWithCommands(BaseRunConfiguration):
264-
commands: Annotated[CommandsList, Field(description="The bash commands to run")] = []
286+
commands: Annotated[CommandsList, Field(description="The shell commands to run")] = []
265287

266288
@root_validator
267289
def check_image_or_commands_present(cls, values):
@@ -276,7 +298,7 @@ class DevEnvironmentConfigurationParams(CoreModel):
276298
Field(description="The IDE to run. Supported values include `vscode` and `cursor`"),
277299
]
278300
version: Annotated[Optional[str], Field(description="The version of the IDE")] = None
279-
init: Annotated[CommandsList, Field(description="The bash commands to run on startup")] = []
301+
init: Annotated[CommandsList, Field(description="The shell commands to run on startup")] = []
280302
inactivity_duration: Annotated[
281303
Optional[Union[Literal["off"], int, bool, str]],
282304
Field(

src/dstack/_internal/server/services/jobs/configurators/base.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import shlex
22
import sys
33
from abc import ABC, abstractmethod
4+
from pathlib import PurePosixPath
45
from typing import Dict, List, Optional, Union
56

67
from cachetools import TTLCache, cached
@@ -131,16 +132,24 @@ async def _get_job_spec(
131132
)
132133
return job_spec
133134

135+
def _shell(self) -> str:
136+
shell = self.run_spec.configuration.shell
137+
if shell is not None:
138+
path = PurePosixPath(shell)
139+
if path.is_absolute():
140+
return shell
141+
return str("/bin" / path)
142+
if self.run_spec.configuration.image is None: # dstackai/base
143+
return "/bin/bash"
144+
return "/bin/sh"
145+
134146
async def _commands(self) -> List[str]:
135147
if self.run_spec.configuration.entrypoint is not None: # docker-like format
136148
entrypoint = shlex.split(self.run_spec.configuration.entrypoint)
137149
commands = self.run_spec.configuration.commands
138-
elif self.run_spec.configuration.image is None: # dstackai/base
139-
entrypoint = ["/bin/bash", "-i", "-c"]
140-
commands = [_join_shell_commands(self._shell_commands())]
141-
elif self._shell_commands(): # custom docker image with shell commands
142-
entrypoint = ["/bin/sh", "-i", "-c"]
143-
commands = [_join_shell_commands(self._shell_commands())]
150+
elif shell_commands := self._shell_commands():
151+
entrypoint = [self._shell(), "-i", "-c"]
152+
commands = [_join_shell_commands(shell_commands)]
144153
else: # custom docker image without commands
145154
image_config = await self._get_image_config()
146155
entrypoint = image_config.entrypoint or []

src/dstack/api/server/_runs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
118118
profile_excludes.add("tags")
119119
if isinstance(configuration, ServiceConfiguration) and not configuration.rate_limits:
120120
configuration_excludes["rate_limits"] = True
121+
if configuration.shell is None:
122+
configuration_excludes["shell"] = True
121123

122124
if configuration_excludes:
123125
spec_excludes["configuration"] = configuration_excludes

src/tests/_internal/core/models/test_configurations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,26 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None):
5252
)
5353
)
5454

55+
@pytest.mark.parametrize("shell", [None, "sh", "bash", "/usr/bin/zsh"])
56+
def test_shell_valid(self, shell: Optional[str]):
57+
conf = {
58+
"type": "task",
59+
"shell": shell,
60+
"commands": ["sleep inf"],
61+
}
62+
assert parse_run_configuration(conf).shell == shell
63+
64+
def test_shell_invalid(self):
65+
conf = {
66+
"type": "task",
67+
"shell": "zsh",
68+
"commands": ["sleep inf"],
69+
}
70+
with pytest.raises(
71+
ConfigurationError, match="The value must be `sh`, `bash`, or an absolute path"
72+
):
73+
parse_run_configuration(conf)
74+
5575

5676
def test_registry_auth_hashable():
5777
"""

src/tests/_internal/server/routers/test_runs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def get_dev_env_run_plan_dict(
9393
"version": None,
9494
"image": None,
9595
"user": None,
96+
"shell": None,
9697
"privileged": privileged,
9798
"init": [],
9899
"ports": [],
@@ -247,6 +248,7 @@ def get_dev_env_run_dict(
247248
"version": None,
248249
"image": None,
249250
"user": None,
251+
"shell": None,
250252
"privileged": privileged,
251253
"init": [],
252254
"ports": [],

src/tests/_internal/server/services/jobs/configurators/test_task.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
from typing import Optional
12
from unittest.mock import patch
23

34
import pytest
45

56
from dstack._internal.core.models.configurations import TaskConfiguration
67
from dstack._internal.core.models.runs import JobSSHKey
8+
from dstack._internal.server.services.docker import ImageConfig
79
from dstack._internal.server.services.jobs.configurators.task import TaskJobConfigurator
810
from dstack._internal.server.testing.common import get_run_spec
911

1012

1113
@pytest.mark.asyncio
1214
@pytest.mark.usefixtures("image_config_mock")
13-
class TestTaskJobConfigurator:
14-
async def test_ssh_key_single_node(self):
15+
class TestSSHKey:
16+
async def test_single_node(self):
1517
configuration = TaskConfiguration(nodes=1, image="debian")
1618
run_spec = get_run_spec(run_name="run", repo_id="id", configuration=configuration)
1719
configurator = TaskJobConfigurator(run_spec)
@@ -21,7 +23,7 @@ async def test_ssh_key_single_node(self):
2123
assert len(job_specs) == 1
2224
assert job_specs[0].ssh_key is None
2325

24-
async def test_ssh_key_multi_node(self):
26+
async def test_multi_node(self):
2527
configuration = TaskConfiguration(nodes=2, image="debian")
2628
run_spec = get_run_spec(run_name="run", repo_id="id", configuration=configuration)
2729
configurator = TaskJobConfigurator(run_spec)
@@ -33,3 +35,74 @@ async def test_ssh_key_multi_node(self):
3335
assert len(job_specs) == 2
3436
assert job_specs[0].ssh_key == JobSSHKey(private="private1", public="public1")
3537
assert job_specs[1].ssh_key == JobSSHKey(private="private1", public="public1")
38+
39+
40+
@pytest.mark.asyncio
41+
@pytest.mark.usefixtures("image_config_mock")
42+
class TestCommands:
43+
@pytest.mark.parametrize(
44+
["commands", "expected_commands"],
45+
[
46+
pytest.param([], ["/entrypoint.sh", "-v"], id="no-commands"),
47+
pytest.param(["-x", "-u"], ["/entrypoint.sh", "-v", "-x", "-u"], id="with-commands"),
48+
],
49+
)
50+
async def test_with_entrypoint(self, commands: list[str], expected_commands: list[str]):
51+
configuration = TaskConfiguration(
52+
image="debian",
53+
entrypoint="/entrypoint.sh -v",
54+
commands=commands,
55+
)
56+
run_spec = get_run_spec(run_name="run", repo_id="id", configuration=configuration)
57+
configurator = TaskJobConfigurator(run_spec)
58+
59+
job_specs = await configurator.get_job_specs(replica_num=0)
60+
61+
assert job_specs[0].commands == expected_commands
62+
63+
@pytest.mark.parametrize(
64+
["shell", "expected_shell"],
65+
[
66+
pytest.param(None, "/bin/sh", id="default-shell"),
67+
pytest.param("sh", "/bin/sh", id="sh"),
68+
pytest.param("bash", "/bin/bash", id="bash"),
69+
pytest.param("/usr/bin/zsh", "/usr/bin/zsh", id="custom-shell"),
70+
],
71+
)
72+
async def test_with_commands_and_image(self, shell: Optional[str], expected_shell: str):
73+
configuration = TaskConfiguration(image="debian", commands=["sleep inf"], shell=shell)
74+
run_spec = get_run_spec(run_name="run", repo_id="id", configuration=configuration)
75+
configurator = TaskJobConfigurator(run_spec)
76+
77+
job_specs = await configurator.get_job_specs(replica_num=0)
78+
79+
assert job_specs[0].commands == [expected_shell, "-i", "-c", "sleep inf"]
80+
81+
@pytest.mark.parametrize(
82+
["shell", "expected_shell"],
83+
[
84+
pytest.param(None, "/bin/bash", id="default-shell"),
85+
pytest.param("sh", "/bin/sh", id="sh"),
86+
pytest.param("bash", "/bin/bash", id="bash"),
87+
pytest.param("/usr/bin/zsh", "/usr/bin/zsh", id="custom-shell"),
88+
],
89+
)
90+
async def test_with_commands_no_image(self, shell: Optional[str], expected_shell: str):
91+
configuration = TaskConfiguration(commands=["sleep inf"], shell=shell)
92+
run_spec = get_run_spec(run_name="run", repo_id="id", configuration=configuration)
93+
configurator = TaskJobConfigurator(run_spec)
94+
95+
job_specs = await configurator.get_job_specs(replica_num=0)
96+
97+
assert job_specs[0].commands == [expected_shell, "-i", "-c", "sleep inf"]
98+
99+
async def test_no_commands(self, image_config_mock: ImageConfig):
100+
image_config_mock.entrypoint = ["/entrypoint.sh"]
101+
image_config_mock.cmd = ["-f", "-x"]
102+
configuration = TaskConfiguration(image="debian")
103+
run_spec = get_run_spec(run_name="run", repo_id="id", configuration=configuration)
104+
configurator = TaskJobConfigurator(run_spec)
105+
106+
job_specs = await configurator.get_job_specs(replica_num=0)
107+
108+
assert job_specs[0].commands == ["/entrypoint.sh", "-f", "-x"]

0 commit comments

Comments
 (0)