Skip to content

Commit 6f1f1f4

Browse files
authored
Respect top-level blocks in SSH fleet configuration (#3700)
Fixes: #3278
1 parent ae3b6be commit 6f1f1f4

File tree

4 files changed

+66
-7
lines changed

4 files changed

+66
-7
lines changed

src/dstack/_internal/core/models/fleets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,17 @@ class SSHHostParams(CoreModel):
7878
ssh_key: Optional[SSHKey] = None
7979

8080
blocks: Annotated[
81-
Union[Literal["auto"], int],
81+
Optional[Union[Literal["auto"], int]],
8282
Field(
8383
description=(
8484
"The amount of blocks to split the instance into, a number or `auto`."
8585
" `auto` means as many as possible."
8686
" The number of GPUs and CPUs must be divisible by the number of blocks."
87-
" Defaults to `1`, i.e. do not split"
87+
" Defaults to the top-level `blocks` value."
8888
),
8989
ge=1,
9090
),
91-
] = 1
91+
] = None
9292

9393
@validator("internal_ip")
9494
def validate_internal_ip(cls, value):

src/dstack/_internal/server/services/fleets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ async def create_fleet_ssh_instance_model(
674674
spec: FleetSpec,
675675
ssh_params: SSHParams,
676676
env: Env,
677+
blocks: Union[int, Literal["auto"]],
677678
instance_num: int,
678679
host: Union[SSHHostParams, str],
679680
) -> InstanceModel:
@@ -684,15 +685,15 @@ async def create_fleet_ssh_instance_model(
684685
port = ssh_params.port
685686
proxy_jump = ssh_params.proxy_jump
686687
internal_ip = None
687-
blocks = 1
688688
else:
689689
hostname = host.hostname
690690
ssh_user = host.user or ssh_params.user
691691
ssh_key = host.ssh_key or ssh_params.ssh_key
692692
port = host.port or ssh_params.port
693693
proxy_jump = host.proxy_jump or ssh_params.proxy_jump
694694
internal_ip = host.internal_ip
695-
blocks = host.blocks
695+
if host.blocks is not None:
696+
blocks = host.blocks
696697

697698
if ssh_user is None or ssh_key is None:
698699
# This should not be reachable, since it is already checked by fleet spec validation
@@ -1042,6 +1043,7 @@ async def _create_fleet(
10421043
spec=spec,
10431044
ssh_params=spec.configuration.ssh_config,
10441045
env=spec.configuration.env,
1046+
blocks=spec.configuration.blocks,
10451047
instance_num=i,
10461048
host=host,
10471049
)
@@ -1152,6 +1154,7 @@ async def _update_fleet(
11521154
spec=spec,
11531155
ssh_params=spec.configuration.ssh_config,
11541156
env=spec.configuration.env,
1157+
blocks=spec.configuration.blocks,
11551158
instance_num=instance_num,
11561159
host=host,
11571160
)

src/dstack/_internal/server/testing/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Callable
44
from contextlib import contextmanager
55
from datetime import datetime, timezone
6-
from typing import Dict, List, Literal, Optional, Union
6+
from typing import Any, Dict, List, Literal, Optional, Union
77
from uuid import UUID
88

99
import gpuhunt
@@ -703,6 +703,7 @@ def get_ssh_fleet_configuration(
703703
hosts: Optional[list[Union[SSHHostParams, str]]] = None,
704704
network: Optional[str] = None,
705705
placement: Optional[InstanceGroupPlacement] = None,
706+
blocks: Optional[Union[int, Literal["auto"]]] = None,
706707
) -> FleetConfiguration:
707708
if ssh_key is None:
708709
ssh_key = SSHKey(public="", private=get_private_key_string())
@@ -714,10 +715,14 @@ def get_ssh_fleet_configuration(
714715
hosts=hosts,
715716
network=network,
716717
)
718+
optional_properties: dict[str, Any] = {}
719+
if blocks is not None:
720+
optional_properties["blocks"] = blocks
717721
return FleetConfiguration(
718722
name=name,
719723
ssh_config=ssh_config,
720724
placement=placement,
725+
**optional_properties,
721726
)
722727

723728

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from datetime import datetime, timezone
3-
from typing import Optional, Union
3+
from typing import Literal, Optional, Union
44
from unittest.mock import Mock, patch
55
from uuid import uuid4
66

@@ -16,6 +16,7 @@
1616
FleetConfiguration,
1717
FleetStatus,
1818
InstanceGroupPlacement,
19+
SSHHostParams,
1920
SSHParams,
2021
)
2122
from dstack._internal.core.models.instances import (
@@ -1178,6 +1179,56 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
11781179
instance = res.unique().scalar_one()
11791180
assert instance.remote_connection_info is not None
11801181

1182+
@pytest.mark.parametrize(
1183+
["top_level_blocks", "host_blocks", "host_type", "expected_blocks"],
1184+
[
1185+
pytest.param(None, None, str, 1, id="global-default-string"),
1186+
pytest.param(None, None, SSHHostParams, 1, id="global-default-object"),
1187+
pytest.param(4, None, str, 4, id="top-level-int-string"),
1188+
pytest.param(4, None, SSHHostParams, 4, id="top-level-int-object"),
1189+
pytest.param("auto", None, str, None, id="top-level-auto-string"),
1190+
pytest.param("auto", None, SSHHostParams, None, id="top-level-auto-object"),
1191+
pytest.param("auto", 4, SSHHostParams, 4, id="host-level-int"),
1192+
pytest.param(4, "auto", SSHHostParams, None, id="host-level-auto"),
1193+
],
1194+
)
1195+
@pytest.mark.asyncio
1196+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1197+
async def test_creates_ssh_fleet_with_blocks(
1198+
self,
1199+
test_db,
1200+
session: AsyncSession,
1201+
client: AsyncClient,
1202+
top_level_blocks: Optional[Union[int, Literal["auto"]]],
1203+
host_blocks: Optional[Union[int, Literal["auto"]]],
1204+
host_type: Union[type[str], type[SSHHostParams]],
1205+
expected_blocks: Optional[int],
1206+
):
1207+
user = await create_user(session, global_role=GlobalRole.USER)
1208+
project = await create_project(session)
1209+
await add_project_member(
1210+
session=session, project=project, user=user, project_role=ProjectRole.USER
1211+
)
1212+
if host_type is str:
1213+
host = "1.1.1.1"
1214+
elif host_blocks is None:
1215+
host = SSHHostParams(hostname="1.1.1.1")
1216+
else:
1217+
host = SSHHostParams(hostname="1.1.1.1", blocks=host_blocks)
1218+
conf = get_ssh_fleet_configuration(blocks=top_level_blocks, hosts=[host])
1219+
spec = get_fleet_spec(conf=conf)
1220+
response = await client.post(
1221+
f"/api/project/{project.name}/fleets/apply",
1222+
headers=get_auth_headers(user.token),
1223+
json={"plan": {"spec": spec.dict()}, "force": False},
1224+
)
1225+
assert response.status_code == 200, response.json()
1226+
res = await session.execute(select(FleetModel))
1227+
assert len(res.scalars().all()) == 1
1228+
res = await session.execute(select(InstanceModel))
1229+
instance = res.scalar_one()
1230+
assert instance.total_blocks == expected_blocks
1231+
11811232
@pytest.mark.asyncio
11821233
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
11831234
@freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), real_asyncio=True)

0 commit comments

Comments
 (0)