Skip to content

Commit ea85c59

Browse files
committed
Support nodes in-place update for cloud fleets
1 parent da026f9 commit ea85c59

4 files changed

Lines changed: 228 additions & 14 deletions

File tree

src/dstack/_internal/server/background/pipeline_tasks/fleets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
workers_num: int = 10,
6464
queue_lower_limit_factor: float = 0.5,
6565
queue_upper_limit_factor: float = 2.0,
66-
min_processing_interval: timedelta = timedelta(seconds=30),
66+
min_processing_interval: timedelta = timedelta(seconds=15),
6767
lock_timeout: timedelta = timedelta(seconds=20),
6868
heartbeat_trigger: timedelta = timedelta(seconds=10),
6969
*,

src/dstack/_internal/server/services/fleets.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,20 +1225,25 @@ def _check_can_update_fleet_spec(current: FleetSpec, new: FleetSpec, diff: Model
12251225
_check_can_update_fleet_configuration(current.configuration, new.configuration)
12261226

12271227

1228-
@_check_can_update("ssh_config")
1229-
def _check_can_update_fleet_configuration(
1230-
current: FleetConfiguration, new: FleetConfiguration, diff: ModelDiff
1231-
):
1228+
def _check_can_update_fleet_configuration(current: FleetConfiguration, new: FleetConfiguration):
1229+
diff = diff_models(current, new)
1230+
current_ssh_config = current.ssh_config
1231+
new_ssh_config = new.ssh_config
1232+
if current_ssh_config is None:
1233+
if new_ssh_config is not None:
1234+
raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update")
1235+
# TODO: Support best-effort `nodes.target` apply semantics:
1236+
# create missing instances and terminate extra idle instances.
1237+
# Current in-place update only persists `target`; FleetPipeline reconciles `min`/`max`.
1238+
_check_can_update_inner(current, new, ("nodes",))
1239+
return
1240+
1241+
if new_ssh_config is None:
1242+
raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update")
1243+
1244+
_check_can_update_inner(current, new, ("ssh_config",))
12321245
if "ssh_config" in diff:
1233-
current_ssh_config = current.ssh_config
1234-
new_ssh_config = new.ssh_config
1235-
if current_ssh_config is None:
1236-
if new_ssh_config is not None:
1237-
raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update")
1238-
elif new_ssh_config is None:
1239-
raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update")
1240-
else:
1241-
_check_can_update_ssh_config(current_ssh_config, new_ssh_config)
1246+
_check_can_update_ssh_config(current_ssh_config, new_ssh_config)
12421247

12431248

12441249
@_check_can_update("hosts")

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,108 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A
14361436
assert instance.status == InstanceStatus.PENDING
14371437
assert instance.remote_connection_info is not None
14381438

1439+
@pytest.mark.asyncio
1440+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1441+
async def test_updates_cloud_fleet_nodes_in_place_when_fleet_in_use(
1442+
self, test_db, session: AsyncSession, client: AsyncClient
1443+
):
1444+
user = await create_user(session, global_role=GlobalRole.USER)
1445+
project = await create_project(session)
1446+
await add_project_member(
1447+
session=session, project=project, user=user, project_role=ProjectRole.USER
1448+
)
1449+
current_spec = get_fleet_spec(
1450+
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=2))
1451+
)
1452+
fleet = await create_fleet(session=session, project=project, spec=current_spec)
1453+
repo = await create_repo(session=session, project_id=project.id)
1454+
run = await create_run(session=session, project=project, repo=repo, user=user, fleet=fleet)
1455+
job = await create_job(session=session, run=run, fleet=fleet)
1456+
instance = await create_instance(
1457+
session=session,
1458+
project=project,
1459+
fleet=fleet,
1460+
job=job,
1461+
status=InstanceStatus.BUSY,
1462+
instance_num=0,
1463+
)
1464+
spec = current_spec.copy(deep=True)
1465+
spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=3)
1466+
1467+
response = await client.post(
1468+
f"/api/project/{project.name}/fleets/apply",
1469+
headers=get_auth_headers(user.token),
1470+
json={
1471+
"plan": {
1472+
"spec": spec.dict(),
1473+
"current_resource": _fleet_model_to_json_dict(fleet),
1474+
},
1475+
"force": False,
1476+
},
1477+
)
1478+
1479+
response_json = response.json()
1480+
assert response.status_code == 200, response_json
1481+
assert response_json["id"] == str(fleet.id)
1482+
assert response_json["spec"]["configuration"]["nodes"] == {"min": 1, "max": 3}
1483+
1484+
await session.refresh(fleet)
1485+
await session.refresh(instance)
1486+
assert json.loads(fleet.spec)["configuration"]["nodes"] == {"min": 1, "max": 3}
1487+
assert instance.status == InstanceStatus.BUSY
1488+
1489+
@pytest.mark.asyncio
1490+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1491+
async def test_updates_cloud_fleet_nodes_target_without_changing_instance_count(
1492+
self, test_db, session: AsyncSession, client: AsyncClient
1493+
):
1494+
user = await create_user(session, global_role=GlobalRole.USER)
1495+
project = await create_project(session)
1496+
await add_project_member(
1497+
session=session, project=project, user=user, project_role=ProjectRole.USER
1498+
)
1499+
current_spec = get_fleet_spec(
1500+
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1))
1501+
)
1502+
fleet = await create_fleet(session=session, project=project, spec=current_spec)
1503+
spec = current_spec.copy(deep=True)
1504+
spec.configuration.nodes = FleetNodesSpec(min=0, target=1, max=1)
1505+
1506+
response = await client.post(
1507+
f"/api/project/{project.name}/fleets/apply",
1508+
headers=get_auth_headers(user.token),
1509+
json={
1510+
"plan": {
1511+
"spec": spec.dict(),
1512+
"current_resource": _fleet_model_to_json_dict(fleet),
1513+
},
1514+
"force": False,
1515+
},
1516+
)
1517+
1518+
response_json = response.json()
1519+
assert response.status_code == 200, response_json
1520+
assert response_json["id"] == str(fleet.id)
1521+
assert response_json["spec"]["configuration"]["nodes"] == {
1522+
"min": 0,
1523+
"target": 1,
1524+
"max": 1,
1525+
}
1526+
1527+
await session.refresh(fleet)
1528+
assert json.loads(fleet.spec)["configuration"]["nodes"] == {
1529+
"min": 0,
1530+
"target": 1,
1531+
"max": 1,
1532+
}
1533+
res = await session.execute(
1534+
select(InstanceModel).where(
1535+
InstanceModel.fleet_id == fleet.id,
1536+
InstanceModel.deleted == False,
1537+
)
1538+
)
1539+
assert list(res.scalars().all()) == []
1540+
14391541
@pytest.mark.asyncio
14401542
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
14411543
@freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc))
@@ -2118,6 +2220,62 @@ async def test_returns_update_plan_for_existing_fleet(
21182220
"action": "update",
21192221
}
21202222

2223+
@pytest.mark.asyncio
2224+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
2225+
async def test_returns_update_plan_for_existing_cloud_fleet_nodes_update(
2226+
self, test_db, session: AsyncSession, client: AsyncClient
2227+
):
2228+
user = await create_user(session=session, global_role=GlobalRole.USER)
2229+
project = await create_project(session=session, owner=user)
2230+
await add_project_member(
2231+
session=session, project=project, user=user, project_role=ProjectRole.USER
2232+
)
2233+
current_spec = get_fleet_spec(
2234+
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1))
2235+
)
2236+
spec = current_spec.copy(deep=True)
2237+
spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
2238+
fleet = await create_fleet(session=session, project=project, spec=current_spec)
2239+
2240+
response = await client.post(
2241+
f"/api/project/{project.name}/fleets/get_plan",
2242+
headers=get_auth_headers(user.token),
2243+
json={"spec": spec.dict()},
2244+
)
2245+
2246+
response_json = response.json()
2247+
assert response.status_code == 200, response_json
2248+
assert response_json["current_resource"]["id"] == str(fleet.id)
2249+
assert response_json["action"] == "update"
2250+
2251+
@pytest.mark.asyncio
2252+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
2253+
async def test_returns_create_plan_for_existing_cloud_fleet_blocks_update(
2254+
self, test_db, session: AsyncSession, client: AsyncClient
2255+
):
2256+
user = await create_user(session=session, global_role=GlobalRole.USER)
2257+
project = await create_project(session=session, owner=user)
2258+
await add_project_member(
2259+
session=session, project=project, user=user, project_role=ProjectRole.USER
2260+
)
2261+
current_spec = get_fleet_spec(
2262+
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1))
2263+
)
2264+
spec = current_spec.copy(deep=True)
2265+
spec.configuration.blocks = 2
2266+
fleet = await create_fleet(session=session, project=project, spec=current_spec)
2267+
2268+
response = await client.post(
2269+
f"/api/project/{project.name}/fleets/get_plan",
2270+
headers=get_auth_headers(user.token),
2271+
json={"spec": spec.dict()},
2272+
)
2273+
2274+
response_json = response.json()
2275+
assert response.status_code == 200, response_json
2276+
assert response_json["current_resource"]["id"] == str(fleet.id)
2277+
assert response_json["action"] == "create"
2278+
21212279
@pytest.mark.asyncio
21222280
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
21232281
async def test_returns_create_plan_for_existing_fleet(

src/tests/_internal/server/services/test_fleets.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dstack._internal.core.backends.base.backend import Backend
88
from dstack._internal.core.errors import ServerClientError
99
from dstack._internal.core.models.backends.base import BackendType
10+
from dstack._internal.core.models.common import ApplyAction
1011
from dstack._internal.core.models.fleets import (
1112
FleetConfiguration,
1213
FleetNodesSpec,
@@ -203,3 +204,53 @@ async def test_returns_none_without_current_master_instance(
203204
)
204205

205206
assert master_provisioning_data is None
207+
208+
209+
class TestGetPlanCloudFleetUpdate:
210+
@pytest.fixture
211+
def get_project_backends_mock(self, monkeypatch: pytest.MonkeyPatch) -> list[Backend]:
212+
mock = Mock(spec_set=get_project_backends, return_value=[])
213+
monkeypatch.setattr("dstack._internal.server.services.backends.get_project_backends", mock)
214+
return mock
215+
216+
@pytest.mark.asyncio
217+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
218+
@pytest.mark.usefixtures("test_db", "get_project_backends_mock")
219+
async def test_ok_nodes_update(self, session: AsyncSession):
220+
user = await create_user(session=session)
221+
project = await create_project(session=session, owner=user)
222+
current_spec = get_fleet_spec(
223+
conf=FleetConfiguration(
224+
name="my-fleet",
225+
nodes=FleetNodesSpec(min=0, target=0, max=1),
226+
)
227+
)
228+
await create_fleet(session=session, project=project, spec=current_spec)
229+
new_spec = current_spec.copy(deep=True)
230+
new_spec.configuration.nodes = FleetNodesSpec(min=0, target=1, max=1)
231+
232+
plan = await get_plan(session=session, project=project, user=user, spec=new_spec)
233+
234+
assert plan.current_resource is not None
235+
assert plan.action == ApplyAction.UPDATE
236+
237+
@pytest.mark.asyncio
238+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
239+
@pytest.mark.usefixtures("test_db", "get_project_backends_mock")
240+
async def test_placement_update_requires_recreate(self, session: AsyncSession):
241+
user = await create_user(session=session)
242+
project = await create_project(session=session, owner=user)
243+
current_spec = get_fleet_spec(
244+
conf=FleetConfiguration(
245+
name="my-fleet",
246+
nodes=FleetNodesSpec(min=0, target=0, max=1),
247+
)
248+
)
249+
await create_fleet(session=session, project=project, spec=current_spec)
250+
new_spec = current_spec.copy(deep=True)
251+
new_spec.configuration.placement = InstanceGroupPlacement.CLUSTER
252+
253+
plan = await get_plan(session=session, project=project, user=user, spec=new_spec)
254+
255+
assert plan.current_resource is not None
256+
assert plan.action == ApplyAction.CREATE

0 commit comments

Comments
 (0)