Skip to content

Commit f531048

Browse files
committed
Add a new opt-in job network mode
A new `DSTACK_SERVER_JOB_NETWORK_MODE` with three modes: - HOST_FOR_MULTINODE_ONLY (1) The new opt-in mode. "bridge" by default, unless it's a multinode run - HOST_WHEN_POSSIBLE (2) The current default. "host" by default, unless the job occupies only a part of the instance. - FORCED_BRIDGE (3) Same as legacy DSTACK_FORCE_BRIDGE_NETWORK=true Always "bridge", even for multinode runs. To opt-in: export DSTACK_SERVER_JOB_NETWORK_MODE=1 `DSTACK_FORCE_BRIDGE_NETWORK` is deprecated but supported with a migration warning
1 parent c7ccdb9 commit f531048

6 files changed

Lines changed: 357 additions & 46 deletions

File tree

src/dstack/_internal/server/app.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ async def lifespan(app: FastAPI):
160160
logger.info("Background processing is disabled")
161161
PROBES_SCHEDULER.start()
162162
dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
163+
logger.info(
164+
"Job network mode: %s (%d)",
165+
settings.JOB_NETWORK_MODE.name,
166+
settings.JOB_NETWORK_MODE.value,
167+
)
163168
logger.info(f"The admin token is {admin.token.get_plaintext_or_error()}", {"show_path": False})
164169
logger.info(
165170
f"The dstack server {dstack_version} is running at {SERVER_URL}",

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@
8585
)
8686
from dstack._internal.server.utils import sentry_utils
8787
from dstack._internal.utils import common as common_utils
88-
from dstack._internal.utils import env as env_utils
8988
from dstack._internal.utils.logging import get_logger
9089

9190
logger = get_logger(__name__)
@@ -186,6 +185,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
186185
run_spec = run.run_spec
187186
profile = run_spec.merged_profile
188187
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
188+
multinode = job.job_spec.jobs_per_replica > 1
189189

190190
# Master job chooses fleet for the run.
191191
# Due to two-step processing, it's saved to job_model.fleet.
@@ -308,6 +308,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
308308
session=session,
309309
instances_with_offers=fleet_instances_with_offers,
310310
job_model=job_model,
311+
multinode=multinode,
311312
)
312313
job_model.fleet = fleet_model
313314
job_model.instance_assigned = True
@@ -383,7 +384,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
383384
offer=offer,
384385
instance_num=instance_num,
385386
)
386-
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
387+
job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json()
387388
instance.fleet_id = fleet_model.id
388389
logger.info(
389390
"The job %s created the new instance %s",
@@ -610,6 +611,7 @@ async def _assign_job_to_fleet_instance(
610611
session: AsyncSession,
611612
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]],
612613
job_model: JobModel,
614+
multinode: bool,
613615
) -> Optional[InstanceModel]:
614616
if len(instances_with_offers) == 0:
615617
return None
@@ -639,7 +641,7 @@ async def _assign_job_to_fleet_instance(
639641
job_model.instance = instance
640642
job_model.used_instance_id = instance.id
641643
job_model.job_provisioning_data = instance.job_provisioning_data
642-
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
644+
job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json()
643645
return instance
644646

645647

@@ -827,12 +829,17 @@ def _create_instance_model_for_job(
827829
return instance
828830

829831

830-
def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
832+
def _prepare_job_runtime_data(
833+
offer: InstanceOfferWithAvailability, multinode: bool
834+
) -> JobRuntimeData:
831835
if offer.blocks == offer.total_blocks:
832-
if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
836+
if settings.JOB_NETWORK_MODE == settings.JobNetworkMode.FORCED_BRIDGE:
833837
network_mode = NetworkMode.BRIDGE
834-
else:
838+
elif settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_WHEN_POSSIBLE:
835839
network_mode = NetworkMode.HOST
840+
else:
841+
assert settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_FOR_MULTINODE_ONLY
842+
network_mode = NetworkMode.HOST if multinode else NetworkMode.BRIDGE
836843
return JobRuntimeData(
837844
network_mode=network_mode,
838845
offer=offer,

src/dstack/_internal/server/settings.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44

55
import os
66
import warnings
7+
from enum import Enum
78
from pathlib import Path
89

10+
from dstack._internal.utils.env import environ
11+
from dstack._internal.utils.logging import get_logger
12+
13+
logger = get_logger(__name__)
14+
915
DSTACK_DIR_PATH = Path("~/.dstack/").expanduser()
1016

1117
SERVER_DIR_PATH = Path(os.getenv("DSTACK_SERVER_DIR", DSTACK_DIR_PATH / "server"))
@@ -136,3 +142,43 @@
136142
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
137143
SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE") is not None
138144
ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS") is not None
145+
146+
147+
class JobNetworkMode(Enum):
148+
# "host" for multinode runs only, "bridge" otherwise. Opt-in new defaut
149+
HOST_FOR_MULTINODE_ONLY = 1
150+
# "bridge" if the job occupies only a part of the instance, "host" otherswise. Current default
151+
HOST_WHEN_POSSIBLE = 2
152+
# Always "bridge", even for multinode runs. Same as legacy DSTACK_FORCE_BRIDGE_NETWORK=true
153+
FORCED_BRIDGE = 3
154+
155+
156+
def _get_job_network_mode() -> JobNetworkMode:
157+
# Current default
158+
mode = JobNetworkMode.HOST_WHEN_POSSIBLE
159+
bridge_var = "DSTACK_FORCE_BRIDGE_NETWORK"
160+
force_bridge = environ.get_bool(bridge_var)
161+
mode_var = "DSTACK_SERVER_JOB_NETWORK_MODE"
162+
mode_from_env = environ.get_enum(mode_var, JobNetworkMode, value_type=int)
163+
if mode_from_env is not None:
164+
if force_bridge is not None:
165+
logger.warning(
166+
f"{bridge_var} is deprecated since 0.19.27 and ignored when {mode_var} is set"
167+
)
168+
return mode_from_env
169+
if force_bridge is not None:
170+
if force_bridge:
171+
mode = JobNetworkMode.FORCED_BRIDGE
172+
logger.warning(
173+
(
174+
f"{bridge_var} is deprecated since 0.19.27."
175+
f" Set {mode_var} to {mode.value} and remove {bridge_var}"
176+
)
177+
)
178+
else:
179+
logger.warning(f"{bridge_var} is deprecated since 0.19.27. Remove {bridge_var}")
180+
return mode
181+
182+
183+
JOB_NETWORK_MODE = _get_job_network_mode()
184+
del _get_job_network_mode

src/dstack/_internal/utils/env.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,88 @@
11
import os
2+
from collections.abc import Mapping
3+
from enum import Enum
4+
from typing import Optional, TypeVar, Union, overload
25

6+
_Value = Union[str, int]
7+
_T = TypeVar("_T", bound=Enum)
38

4-
def get_bool(name: str, default: bool = False) -> bool:
5-
try:
6-
value = os.environ[name]
7-
except KeyError:
8-
return default
9-
value = value.lower()
10-
if value in ["0", "false", "off"]:
11-
return False
12-
if value in ["1", "true", "on"]:
13-
return True
14-
raise ValueError(f"Invalid bool value: {name}={value}")
9+
10+
class Environ:
11+
def __init__(self, environ: Mapping[str, str]):
12+
self._environ = environ
13+
14+
@overload
15+
def get_bool(self, name: str, *, default: None = None) -> Optional[bool]: ...
16+
17+
@overload
18+
def get_bool(self, name: str, *, default: bool) -> bool: ...
19+
20+
def get_bool(self, name: str, *, default: Optional[bool] = None) -> Optional[bool]:
21+
try:
22+
raw_value = self._environ[name]
23+
except KeyError:
24+
return default
25+
value = raw_value.lower()
26+
if value in ["0", "false", "off"]:
27+
return False
28+
if value in ["1", "true", "on"]:
29+
return True
30+
raise ValueError(f"Invalid bool value: {name}={raw_value}")
31+
32+
@overload
33+
def get_int(self, name: str, *, default: None = None) -> Optional[int]: ...
34+
35+
@overload
36+
def get_int(self, name: str, *, default: int) -> int: ...
37+
38+
def get_int(self, name: str, *, default: Optional[int] = None) -> Optional[int]:
39+
try:
40+
raw_value = self._environ[name]
41+
except KeyError:
42+
return default
43+
try:
44+
return int(raw_value)
45+
except ValueError as e:
46+
raise ValueError(f"Invalid int value: {e}: {name}={raw_value}") from e
47+
48+
@overload
49+
def get_enum(
50+
self,
51+
name: str,
52+
enum_cls: type[_T],
53+
*,
54+
value_type: Optional[type[_Value]] = None,
55+
default: None = None,
56+
) -> Optional[_T]: ...
57+
58+
@overload
59+
def get_enum(
60+
self,
61+
name: str,
62+
enum_cls: type[_T],
63+
*,
64+
value_type: Optional[type[_Value]] = None,
65+
default: _T,
66+
) -> _T: ...
67+
68+
def get_enum(
69+
self,
70+
name: str,
71+
enum_cls: type[_T],
72+
*,
73+
value_type: Optional[type[_Value]] = None,
74+
default: Optional[_T] = None,
75+
) -> Optional[_T]:
76+
try:
77+
raw_value = self._environ[name]
78+
except KeyError:
79+
return default
80+
try:
81+
if value_type is not None:
82+
raw_value = value_type(raw_value)
83+
return enum_cls(raw_value)
84+
except (ValueError, TypeError) as e:
85+
raise ValueError(f"Invalid {enum_cls.__name__} value: {e}: {name}={raw_value}") from e
86+
87+
88+
environ = Environ(os.environ)

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.orm import joinedload
88

99
from dstack._internal.core.models.backends.base import BackendType
10+
from dstack._internal.core.models.common import NetworkMode
1011
from dstack._internal.core.models.configurations import TaskConfiguration
1112
from dstack._internal.core.models.health import HealthStatus
1213
from dstack._internal.core.models.instances import (
@@ -25,8 +26,12 @@
2526
VolumeMountPoint,
2627
VolumeStatus,
2728
)
28-
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
29+
from dstack._internal.server.background.tasks.process_submitted_jobs import (
30+
_prepare_job_runtime_data,
31+
process_submitted_jobs,
32+
)
2933
from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel
34+
from dstack._internal.server.settings import JobNetworkMode
3035
from dstack._internal.server.testing.common import (
3136
ComputeMockSpec,
3237
create_fleet,
@@ -1004,3 +1009,102 @@ async def test_picks_high_priority_jobs_first(self, test_db, session: AsyncSessi
10041009
await process_submitted_jobs()
10051010
await session.refresh(job2)
10061011
assert job2.status == JobStatus.PROVISIONING
1012+
1013+
1014+
@pytest.mark.parametrize(
1015+
["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"],
1016+
[
1017+
pytest.param(
1018+
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
1019+
2,
1020+
False,
1021+
NetworkMode.BRIDGE,
1022+
True,
1023+
id="host-for-multinode-only--half-of-instance",
1024+
),
1025+
pytest.param(
1026+
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
1027+
4,
1028+
False,
1029+
NetworkMode.BRIDGE,
1030+
False,
1031+
id="host-for-multinode-only--entire-instance",
1032+
),
1033+
pytest.param(
1034+
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
1035+
4,
1036+
True,
1037+
NetworkMode.HOST,
1038+
False,
1039+
id="host-for-multinode-only--entire-instance--multinode",
1040+
),
1041+
pytest.param(
1042+
JobNetworkMode.HOST_WHEN_POSSIBLE,
1043+
2,
1044+
False,
1045+
NetworkMode.BRIDGE,
1046+
True,
1047+
id="host-when-possible--half-of-instance",
1048+
),
1049+
pytest.param(
1050+
JobNetworkMode.HOST_WHEN_POSSIBLE,
1051+
4,
1052+
False,
1053+
NetworkMode.HOST,
1054+
False,
1055+
id="host-when-possible--entire-instance",
1056+
),
1057+
pytest.param(
1058+
JobNetworkMode.HOST_WHEN_POSSIBLE,
1059+
4,
1060+
True,
1061+
NetworkMode.HOST,
1062+
False,
1063+
id="host-when-possible--entire-instance--multinode",
1064+
),
1065+
pytest.param(
1066+
JobNetworkMode.FORCED_BRIDGE,
1067+
2,
1068+
False,
1069+
NetworkMode.BRIDGE,
1070+
True,
1071+
id="forced-bridge--half-of-instance",
1072+
),
1073+
pytest.param(
1074+
JobNetworkMode.FORCED_BRIDGE,
1075+
4,
1076+
False,
1077+
NetworkMode.BRIDGE,
1078+
False,
1079+
id="forced-bridge--entire-instance",
1080+
),
1081+
pytest.param(
1082+
JobNetworkMode.FORCED_BRIDGE,
1083+
4,
1084+
True,
1085+
NetworkMode.BRIDGE,
1086+
False,
1087+
id="forced-bridge--entire-instance--multinode",
1088+
),
1089+
],
1090+
)
1091+
def test_prepare_job_runtime_data(
1092+
monkeypatch: pytest.MonkeyPatch,
1093+
job_network_mode: JobNetworkMode,
1094+
blocks: int,
1095+
multinode: bool,
1096+
network_mode: NetworkMode,
1097+
constraints_are_set: bool,
1098+
):
1099+
monkeypatch.setattr("dstack._internal.server.settings.JOB_NETWORK_MODE", job_network_mode)
1100+
offer = get_instance_offer_with_availability(blocks=blocks, total_blocks=4)
1101+
jrd = _prepare_job_runtime_data(offer=offer, multinode=multinode)
1102+
assert jrd.network_mode == network_mode
1103+
if constraints_are_set:
1104+
assert jrd.gpu is not None
1105+
assert jrd.cpu is not None
1106+
assert jrd.memory is not None
1107+
else:
1108+
assert jrd.gpu is None
1109+
assert jrd.cpu is None
1110+
assert jrd.memory is None

0 commit comments

Comments
 (0)