Skip to content

Commit 8a05d0a

Browse files
authored
Add a new opt-in job network mode (#3043)
A new `DSTACK_SERVER_JOB_NETWORK_MODE` environment variable with three modes: HOST_FOR_MULTINODE_ONLY (1) — the new opt-in mode: "bridge" by default, unless it's a multinode run; HOST_WHEN_POSSIBLE (2) — the current default: "host" by default, unless the job occupies only a part of the instance; FORCED_BRIDGE (3) — same as the legacy DSTACK_FORCE_BRIDGE_NETWORK=true: always "bridge", even for multinode runs. To opt in: export DSTACK_SERVER_JOB_NETWORK_MODE=1. `DSTACK_FORCE_BRIDGE_NETWORK` is deprecated but still supported, with a migration warning.
1 parent a6b25bb commit 8a05d0a

File tree

6 files changed

+357
-46
lines changed

6 files changed

+357
-46
lines changed

src/dstack/_internal/server/app.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ async def lifespan(app: FastAPI):
160160
logger.info("Background processing is disabled")
161161
PROBES_SCHEDULER.start()
162162
dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
163+
logger.info(
164+
"Job network mode: %s (%d)",
165+
settings.JOB_NETWORK_MODE.name,
166+
settings.JOB_NETWORK_MODE.value,
167+
)
163168
logger.info(f"The admin token is {admin.token.get_plaintext_or_error()}", {"show_path": False})
164169
logger.info(
165170
f"The dstack server {dstack_version} is running at {SERVER_URL}",

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@
8888
)
8989
from dstack._internal.server.utils import sentry_utils
9090
from dstack._internal.utils import common as common_utils
91-
from dstack._internal.utils import env as env_utils
9291
from dstack._internal.utils.logging import get_logger
9392

9493
logger = get_logger(__name__)
@@ -189,6 +188,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
189188
run_spec = run.run_spec
190189
profile = run_spec.merged_profile
191190
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
191+
multinode = job.job_spec.jobs_per_replica > 1
192192

193193
# Master job chooses fleet for the run.
194194
# Due to two-step processing, it's saved to job_model.fleet.
@@ -311,6 +311,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
311311
session=session,
312312
instances_with_offers=fleet_instances_with_offers,
313313
job_model=job_model,
314+
multinode=multinode,
314315
)
315316
job_model.fleet = fleet_model
316317
job_model.instance_assigned = True
@@ -387,7 +388,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
387388
offer=offer,
388389
instance_num=instance_num,
389390
)
390-
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
391+
job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json()
391392
# Both this task and process_fleets can add instances to fleets.
392393
# TODO: Ensure this does not violate nodes.max when it's enforced.
393394
instance.fleet_id = fleet_model.id
@@ -616,6 +617,7 @@ async def _assign_job_to_fleet_instance(
616617
session: AsyncSession,
617618
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]],
618619
job_model: JobModel,
620+
multinode: bool,
619621
) -> Optional[InstanceModel]:
620622
if len(instances_with_offers) == 0:
621623
return None
@@ -645,7 +647,7 @@ async def _assign_job_to_fleet_instance(
645647
job_model.instance = instance
646648
job_model.used_instance_id = instance.id
647649
job_model.job_provisioning_data = instance.job_provisioning_data
648-
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
650+
job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json()
649651
return instance
650652

651653

@@ -852,12 +854,17 @@ def _create_instance_model_for_job(
852854
return instance
853855

854856

855-
def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
857+
def _prepare_job_runtime_data(
858+
offer: InstanceOfferWithAvailability, multinode: bool
859+
) -> JobRuntimeData:
856860
if offer.blocks == offer.total_blocks:
857-
if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
861+
if settings.JOB_NETWORK_MODE == settings.JobNetworkMode.FORCED_BRIDGE:
858862
network_mode = NetworkMode.BRIDGE
859-
else:
863+
elif settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_WHEN_POSSIBLE:
860864
network_mode = NetworkMode.HOST
865+
else:
866+
assert settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_FOR_MULTINODE_ONLY
867+
network_mode = NetworkMode.HOST if multinode else NetworkMode.BRIDGE
861868
return JobRuntimeData(
862869
network_mode=network_mode,
863870
offer=offer,

src/dstack/_internal/server/settings.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44

55
import os
66
import warnings
7+
from enum import Enum
78
from pathlib import Path
89

10+
from dstack._internal.utils.env import environ
11+
from dstack._internal.utils.logging import get_logger
12+
13+
logger = get_logger(__name__)
14+
915
DSTACK_DIR_PATH = Path("~/.dstack/").expanduser()
1016

1117
SERVER_DIR_PATH = Path(os.getenv("DSTACK_SERVER_DIR", DSTACK_DIR_PATH / "server"))
@@ -136,3 +142,43 @@
136142
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
137143
SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE") is not None
138144
ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS") is not None
145+
146+
147+
class JobNetworkMode(Enum):
148+
# "host" for multinode runs only, "bridge" otherwise. Opt-in new default
149+
HOST_FOR_MULTINODE_ONLY = 1
150+
# "bridge" if the job occupies only a part of the instance, "host" otherwise. Current default
151+
HOST_WHEN_POSSIBLE = 2
152+
# Always "bridge", even for multinode runs. Same as legacy DSTACK_FORCE_BRIDGE_NETWORK=true
153+
FORCED_BRIDGE = 3
154+
155+
156+
def _get_job_network_mode() -> JobNetworkMode:
157+
# Current default
158+
mode = JobNetworkMode.HOST_WHEN_POSSIBLE
159+
bridge_var = "DSTACK_FORCE_BRIDGE_NETWORK"
160+
force_bridge = environ.get_bool(bridge_var)
161+
mode_var = "DSTACK_SERVER_JOB_NETWORK_MODE"
162+
mode_from_env = environ.get_enum(mode_var, JobNetworkMode, value_type=int)
163+
if mode_from_env is not None:
164+
if force_bridge is not None:
165+
logger.warning(
166+
f"{bridge_var} is deprecated since 0.19.27 and ignored when {mode_var} is set"
167+
)
168+
return mode_from_env
169+
if force_bridge is not None:
170+
if force_bridge:
171+
mode = JobNetworkMode.FORCED_BRIDGE
172+
logger.warning(
173+
(
174+
f"{bridge_var} is deprecated since 0.19.27."
175+
f" Set {mode_var} to {mode.value} and remove {bridge_var}"
176+
)
177+
)
178+
else:
179+
logger.warning(f"{bridge_var} is deprecated since 0.19.27. Remove {bridge_var}")
180+
return mode
181+
182+
183+
JOB_NETWORK_MODE = _get_job_network_mode()
184+
del _get_job_network_mode

src/dstack/_internal/utils/env.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,88 @@
11
import os
2+
from collections.abc import Mapping
3+
from enum import Enum
4+
from typing import Optional, TypeVar, Union, overload
25

6+
_Value = Union[str, int]
7+
_T = TypeVar("_T", bound=Enum)
38

4-
def get_bool(name: str, default: bool = False) -> bool:
5-
try:
6-
value = os.environ[name]
7-
except KeyError:
8-
return default
9-
value = value.lower()
10-
if value in ["0", "false", "off"]:
11-
return False
12-
if value in ["1", "true", "on"]:
13-
return True
14-
raise ValueError(f"Invalid bool value: {name}={value}")
9+
10+
class Environ:
11+
def __init__(self, environ: Mapping[str, str]):
12+
self._environ = environ
13+
14+
@overload
15+
def get_bool(self, name: str, *, default: None = None) -> Optional[bool]: ...
16+
17+
@overload
18+
def get_bool(self, name: str, *, default: bool) -> bool: ...
19+
20+
def get_bool(self, name: str, *, default: Optional[bool] = None) -> Optional[bool]:
21+
try:
22+
raw_value = self._environ[name]
23+
except KeyError:
24+
return default
25+
value = raw_value.lower()
26+
if value in ["0", "false", "off"]:
27+
return False
28+
if value in ["1", "true", "on"]:
29+
return True
30+
raise ValueError(f"Invalid bool value: {name}={raw_value}")
31+
32+
@overload
33+
def get_int(self, name: str, *, default: None = None) -> Optional[int]: ...
34+
35+
@overload
36+
def get_int(self, name: str, *, default: int) -> int: ...
37+
38+
def get_int(self, name: str, *, default: Optional[int] = None) -> Optional[int]:
39+
try:
40+
raw_value = self._environ[name]
41+
except KeyError:
42+
return default
43+
try:
44+
return int(raw_value)
45+
except ValueError as e:
46+
raise ValueError(f"Invalid int value: {e}: {name}={raw_value}") from e
47+
48+
@overload
49+
def get_enum(
50+
self,
51+
name: str,
52+
enum_cls: type[_T],
53+
*,
54+
value_type: Optional[type[_Value]] = None,
55+
default: None = None,
56+
) -> Optional[_T]: ...
57+
58+
@overload
59+
def get_enum(
60+
self,
61+
name: str,
62+
enum_cls: type[_T],
63+
*,
64+
value_type: Optional[type[_Value]] = None,
65+
default: _T,
66+
) -> _T: ...
67+
68+
def get_enum(
69+
self,
70+
name: str,
71+
enum_cls: type[_T],
72+
*,
73+
value_type: Optional[type[_Value]] = None,
74+
default: Optional[_T] = None,
75+
) -> Optional[_T]:
76+
try:
77+
raw_value = self._environ[name]
78+
except KeyError:
79+
return default
80+
try:
81+
if value_type is not None:
82+
raw_value = value_type(raw_value)
83+
return enum_cls(raw_value)
84+
except (ValueError, TypeError) as e:
85+
raise ValueError(f"Invalid {enum_cls.__name__} value: {e}: {name}={raw_value}") from e
86+
87+
88+
environ = Environ(os.environ)

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.orm import joinedload
88

99
from dstack._internal.core.models.backends.base import BackendType
10+
from dstack._internal.core.models.common import NetworkMode
1011
from dstack._internal.core.models.configurations import TaskConfiguration
1112
from dstack._internal.core.models.fleets import FleetNodesSpec
1213
from dstack._internal.core.models.health import HealthStatus
@@ -25,8 +26,12 @@
2526
VolumeMountPoint,
2627
VolumeStatus,
2728
)
28-
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
29+
from dstack._internal.server.background.tasks.process_submitted_jobs import (
30+
_prepare_job_runtime_data,
31+
process_submitted_jobs,
32+
)
2933
from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel
34+
from dstack._internal.server.settings import JobNetworkMode
3035
from dstack._internal.server.testing.common import (
3136
ComputeMockSpec,
3237
create_fleet,
@@ -1004,3 +1009,102 @@ async def test_picks_high_priority_jobs_first(self, test_db, session: AsyncSessi
10041009
await process_submitted_jobs()
10051010
await session.refresh(job2)
10061011
assert job2.status == JobStatus.PROVISIONING
1012+
1013+
1014+
@pytest.mark.parametrize(
1015+
["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"],
1016+
[
1017+
pytest.param(
1018+
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
1019+
2,
1020+
False,
1021+
NetworkMode.BRIDGE,
1022+
True,
1023+
id="host-for-multinode-only--half-of-instance",
1024+
),
1025+
pytest.param(
1026+
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
1027+
4,
1028+
False,
1029+
NetworkMode.BRIDGE,
1030+
False,
1031+
id="host-for-multinode-only--entire-instance",
1032+
),
1033+
pytest.param(
1034+
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
1035+
4,
1036+
True,
1037+
NetworkMode.HOST,
1038+
False,
1039+
id="host-for-multinode-only--entire-instance--multinode",
1040+
),
1041+
pytest.param(
1042+
JobNetworkMode.HOST_WHEN_POSSIBLE,
1043+
2,
1044+
False,
1045+
NetworkMode.BRIDGE,
1046+
True,
1047+
id="host-when-possible--half-of-instance",
1048+
),
1049+
pytest.param(
1050+
JobNetworkMode.HOST_WHEN_POSSIBLE,
1051+
4,
1052+
False,
1053+
NetworkMode.HOST,
1054+
False,
1055+
id="host-when-possible--entire-instance",
1056+
),
1057+
pytest.param(
1058+
JobNetworkMode.HOST_WHEN_POSSIBLE,
1059+
4,
1060+
True,
1061+
NetworkMode.HOST,
1062+
False,
1063+
id="host-when-possible--entire-instance--multinode",
1064+
),
1065+
pytest.param(
1066+
JobNetworkMode.FORCED_BRIDGE,
1067+
2,
1068+
False,
1069+
NetworkMode.BRIDGE,
1070+
True,
1071+
id="forced-bridge--half-of-instance",
1072+
),
1073+
pytest.param(
1074+
JobNetworkMode.FORCED_BRIDGE,
1075+
4,
1076+
False,
1077+
NetworkMode.BRIDGE,
1078+
False,
1079+
id="forced-bridge--entire-instance",
1080+
),
1081+
pytest.param(
1082+
JobNetworkMode.FORCED_BRIDGE,
1083+
4,
1084+
True,
1085+
NetworkMode.BRIDGE,
1086+
False,
1087+
id="forced-bridge--entire-instance--multinode",
1088+
),
1089+
],
1090+
)
1091+
def test_prepare_job_runtime_data(
1092+
monkeypatch: pytest.MonkeyPatch,
1093+
job_network_mode: JobNetworkMode,
1094+
blocks: int,
1095+
multinode: bool,
1096+
network_mode: NetworkMode,
1097+
constraints_are_set: bool,
1098+
):
1099+
monkeypatch.setattr("dstack._internal.server.settings.JOB_NETWORK_MODE", job_network_mode)
1100+
offer = get_instance_offer_with_availability(blocks=blocks, total_blocks=4)
1101+
jrd = _prepare_job_runtime_data(offer=offer, multinode=multinode)
1102+
assert jrd.network_mode == network_mode
1103+
if constraints_are_set:
1104+
assert jrd.gpu is not None
1105+
assert jrd.cpu is not None
1106+
assert jrd.memory is not None
1107+
else:
1108+
assert jrd.gpu is None
1109+
assert jrd.cpu is None
1110+
assert jrd.memory is None

0 commit comments

Comments
 (0)