Skip to content

Commit 03a26e2

Browse files
authored
Implement /api/project/{project_name}/fleets/apply (#2577)
1 parent 52d113f commit 03a26e2

File tree

9 files changed

+157
-30
lines changed

9 files changed

+157
-30
lines changed

src/dstack/_internal/cli/services/configurators/fleet.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
)
1818
from dstack._internal.cli.utils.fleet import get_fleets_table
1919
from dstack._internal.cli.utils.rich import MultiItemStatus
20-
from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError
20+
from dstack._internal.core.errors import (
21+
CLIError,
22+
ConfigurationError,
23+
ResourceNotExistsError,
24+
ServerClientError,
25+
URLNotFoundError,
26+
)
2127
from dstack._internal.core.models.configurations import ApplyConfigurationType
2228
from dstack._internal.core.models.fleets import (
2329
Fleet,
@@ -31,6 +37,7 @@
3137
from dstack._internal.utils.common import local_time
3238
from dstack._internal.utils.logging import get_logger
3339
from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str
40+
from dstack.api._public import Client
3441
from dstack.api.utils import load_profile
3542

3643
logger = get_logger(__name__)
@@ -109,11 +116,11 @@ def apply_configuration(
109116
else:
110117
time.sleep(1)
111118

112-
with console.status("Creating fleet..."):
113-
fleet = self.api.client.fleets.create(
114-
project_name=self.api.project,
115-
spec=spec,
116-
)
119+
try:
120+
with console.status("Applying plan..."):
121+
fleet = _apply_plan(self.api, plan)
122+
except ServerClientError as e:
123+
raise CLIError(e.msg)
117124
if command_args.detach:
118125
console.print("Fleet configuration submitted. Exiting...")
119126
return
@@ -350,3 +357,17 @@ def _failed_provisioning(fleet: Fleet) -> bool:
350357
if instance.status == InstanceStatus.TERMINATED:
351358
return True
352359
return False
360+
361+
362+
def _apply_plan(api: Client, plan: FleetPlan) -> Fleet:
363+
try:
364+
return api.client.fleets.apply_plan(
365+
project_name=api.project,
366+
plan=plan,
367+
)
368+
except URLNotFoundError:
369+
# TODO: Remove in 0.20
370+
return api.client.fleets.create(
371+
project_name=api.project,
372+
spec=plan.spec,
373+
)

src/dstack/_internal/core/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class ClientError(DstackError):
1818
pass
1919

2020

21+
class URLNotFoundError(ClientError):
22+
pass
23+
24+
2125
class ServerClientErrorCode(str, enum.Enum):
2226
UNSPECIFIED_ERROR = "error"
2327
RESOURCE_EXISTS = "resource_exists"

src/dstack/_internal/core/models/fleets.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,20 @@ class FleetPlan(CoreModel):
312312
project_name: str
313313
user: str
314314
spec: FleetSpec
315-
current_resource: Optional[Fleet]
315+
current_resource: Optional[Fleet] = None
316316
offers: List[InstanceOfferWithAvailability]
317317
total_offers: int
318-
max_offer_price: Optional[float]
318+
max_offer_price: Optional[float] = None
319+
320+
321+
class ApplyFleetPlanInput(CoreModel):
322+
spec: FleetSpec
323+
current_resource: Annotated[
324+
Optional[Fleet],
325+
Field(
326+
description=(
327+
"The expected current resource."
328+
" If the resource has changed, the apply fails unless `force: true`."
329+
)
330+
),
331+
] = None

src/dstack/_internal/server/routers/fleets.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dstack._internal.server.db import get_session
1010
from dstack._internal.server.models import ProjectModel, UserModel
1111
from dstack._internal.server.schemas.fleets import (
12+
ApplyFleetPlanRequest,
1213
CreateFleetRequest,
1314
DeleteFleetInstancesRequest,
1415
DeleteFleetsRequest,
@@ -107,6 +108,27 @@ async def get_plan(
107108
return plan
108109

109110

111+
@project_router.post("/apply")
112+
async def apply_plan(
113+
body: ApplyFleetPlanRequest,
114+
session: AsyncSession = Depends(get_session),
115+
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
116+
) -> Fleet:
117+
"""
118+
Creates a new fleet or updates an existing fleet.
119+
Errors if the expected current resource from the plan does not match the current resource.
120+
Use `force: true` to apply even if the current resource does not match.
121+
"""
122+
user, project = user_project
123+
return await fleets_services.apply_plan(
124+
session=session,
125+
user=user,
126+
project=project,
127+
plan=body.plan,
128+
force=body.force,
129+
)
130+
131+
110132
@project_router.post("/create")
111133
async def create_fleet(
112134
body: CreateFleetRequest,

src/dstack/_internal/server/schemas/fleets.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from datetime import datetime
2-
from typing import List, Optional
2+
from typing import Annotated, List, Optional
33
from uuid import UUID
44

55
from pydantic import Field
66

77
from dstack._internal.core.models.common import CoreModel
8-
from dstack._internal.core.models.fleets import FleetSpec
8+
from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec
99

1010

1111
class ListFleetsRequest(CoreModel):
@@ -26,6 +26,16 @@ class GetFleetPlanRequest(CoreModel):
2626
spec: FleetSpec
2727

2828

29+
class ApplyFleetPlanRequest(CoreModel):
30+
plan: ApplyFleetPlanInput
31+
force: Annotated[
32+
bool,
33+
Field(
34+
description="Use `force: true` to apply even if the expected resource does not match."
35+
),
36+
]
37+
38+
2939
class CreateFleetRequest(CoreModel):
3040
spec: FleetSpec
3141

src/dstack/_internal/server/services/fleets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from dstack._internal.core.models.envs import Env
1919
from dstack._internal.core.models.fleets import (
20+
ApplyFleetPlanInput,
2021
Fleet,
2122
FleetPlan,
2223
FleetSpec,
@@ -306,6 +307,21 @@ async def get_create_instance_offers(
306307
return offers
307308

308309

310+
async def apply_plan(
311+
session: AsyncSession,
312+
user: UserModel,
313+
project: ProjectModel,
314+
plan: ApplyFleetPlanInput,
315+
force: bool,
316+
) -> Fleet:
317+
return await create_fleet(
318+
session=session,
319+
project=project,
320+
user=user,
321+
spec=plan.spec,
322+
)
323+
324+
309325
async def create_fleet(
310326
session: AsyncSession,
311327
project: ProjectModel,

src/dstack/api/server/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import requests
77

88
from dstack import version
9-
from dstack._internal.core.errors import ClientError, ServerClientError
9+
from dstack._internal.core.errors import ClientError, ServerClientError, URLNotFoundError
1010
from dstack._internal.utils.logging import get_logger
1111
from dstack.api.server._backends import BackendsAPIClient
1212
from dstack.api.server._fleets import FleetsAPIClient
@@ -154,6 +154,8 @@ def _request(
154154
raise ClientError(
155155
f"Access to {resp.request.url} is denied. Please check your access token"
156156
)
157+
if resp.status_code == 404:
158+
raise URLNotFoundError(f"Status code 404 when requesting {resp.request.url}")
157159
if 400 <= resp.status_code < 600:
158160
raise ClientError(
159161
f"Unexpected error: status code {resp.status_code}"

src/dstack/api/server/_fleets.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, Union
22

33
from pydantic import parse_obj_as
44

5-
from dstack._internal.core.models.fleets import Fleet, FleetPlan, FleetSpec
5+
from dstack._internal.core.models.fleets import ApplyFleetPlanInput, Fleet, FleetPlan, FleetSpec
66
from dstack._internal.server.schemas.fleets import (
7+
ApplyFleetPlanRequest,
78
CreateFleetRequest,
89
DeleteFleetInstancesRequest,
910
DeleteFleetsRequest,
@@ -32,18 +33,20 @@ def get_plan(
3233
spec: FleetSpec,
3334
) -> FleetPlan:
3435
body = GetFleetPlanRequest(spec=spec)
35-
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
36+
body_json = body.json(exclude=_get_get_plan_excludes(spec))
3637
resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json)
3738
return parse_obj_as(FleetPlan.__response__, resp.json())
3839

39-
def create(
40+
def apply_plan(
4041
self,
4142
project_name: str,
42-
spec: FleetSpec,
43+
plan: Union[FleetPlan, ApplyFleetPlanInput],
44+
force: bool = False,
4345
) -> Fleet:
44-
body = CreateFleetRequest(spec=spec)
45-
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
46-
resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
46+
plan_input = ApplyFleetPlanInput.__response__.parse_obj(plan)
47+
body = ApplyFleetPlanRequest(plan=plan_input, force=force)
48+
body_json = body.json(exclude=_get_apply_plan_excludes(plan_input))
49+
resp = self._request(f"/api/project/{project_name}/fleets/apply", body=body_json)
4750
return parse_obj_as(Fleet.__response__, resp.json())
4851

4952
def delete(self, project_name: str, names: List[str]) -> None:
@@ -54,6 +57,42 @@ def delete_instances(self, project_name: str, name: str, instance_nums: List[int
5457
body = DeleteFleetInstancesRequest(name=name, instance_nums=instance_nums)
5558
self._request(f"/api/project/{project_name}/fleets/delete_instances", body=body.json())
5659

60+
# Deprecated
61+
# TODO: Remove in 0.20
62+
def create(
63+
self,
64+
project_name: str,
65+
spec: FleetSpec,
66+
) -> Fleet:
67+
body = CreateFleetRequest(spec=spec)
68+
body_json = body.json(exclude=_get_create_fleet_excludes(spec))
69+
resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
70+
return parse_obj_as(Fleet.__response__, resp.json())
71+
72+
73+
def _get_get_plan_excludes(fleet_spec: FleetSpec) -> Dict:
74+
get_plan_excludes = {}
75+
spec_excludes = _get_fleet_spec_excludes(fleet_spec)
76+
if spec_excludes:
77+
get_plan_excludes["spec"] = spec_excludes
78+
return get_plan_excludes
79+
80+
81+
def _get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict:
82+
apply_plan_excludes = {}
83+
spec_excludes = _get_fleet_spec_excludes(plan_input.spec)
84+
if spec_excludes:
85+
apply_plan_excludes["spec"] = apply_plan_excludes
86+
return {"plan": apply_plan_excludes}
87+
88+
89+
def _get_create_fleet_excludes(fleet_spec: FleetSpec) -> Dict:
90+
create_fleet_excludes = {}
91+
spec_excludes = _get_fleet_spec_excludes(fleet_spec)
92+
if spec_excludes:
93+
create_fleet_excludes["spec"] = spec_excludes
94+
return create_fleet_excludes
95+
5796

5897
def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]:
5998
"""
@@ -76,5 +115,5 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]:
76115
if profile_excludes:
77116
spec_excludes["profile"] = profile_excludes
78117
if spec_excludes:
79-
return {"spec": spec_excludes}
118+
return spec_excludes
80119
return None

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,13 @@ async def test_not_returns_by_name_if_fleet_does_not_exist(
292292
assert response.status_code == 400
293293

294294

295-
class TestCreateFleet:
295+
class TestApplyFleetPlan:
296296
@pytest.mark.asyncio
297297
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
298298
async def test_returns_40x_if_not_authenticated(
299299
self, test_db, session: AsyncSession, client: AsyncClient
300300
):
301-
response = await client.post("/api/project/main/fleets/create")
301+
response = await client.post("/api/project/main/fleets/apply")
302302
assert response.status_code == 403
303303

304304
@pytest.mark.asyncio
@@ -314,9 +314,9 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
314314
with patch("uuid.uuid4") as m:
315315
m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e")
316316
response = await client.post(
317-
f"/api/project/{project.name}/fleets/create",
317+
f"/api/project/{project.name}/fleets/apply",
318318
headers=get_auth_headers(user.token),
319-
json={"spec": spec.dict()},
319+
json={"plan": {"spec": spec.dict()}, "force": False},
320320
)
321321
assert response.status_code == 200
322322
assert response.json() == {
@@ -427,9 +427,9 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
427427
with patch("uuid.uuid4") as m:
428428
m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e")
429429
response = await client.post(
430-
f"/api/project/{project.name}/fleets/create",
430+
f"/api/project/{project.name}/fleets/apply",
431431
headers=get_auth_headers(user.token),
432-
json={"spec": spec.dict()},
432+
json={"plan": {"spec": spec.dict()}, "force": False},
433433
)
434434
assert response.status_code == 200, response.json()
435435
assert response.json() == {
@@ -559,9 +559,9 @@ async def test_errors_if_ssh_key_is_bad(
559559
)
560560
)
561561
response = await client.post(
562-
f"/api/project/{project.name}/fleets/create",
562+
f"/api/project/{project.name}/fleets/apply",
563563
headers=get_auth_headers(user.token),
564-
json={"spec": spec.dict()},
564+
json={"plan": {"spec": spec.dict()}, "force": False},
565565
)
566566
assert response.status_code == 400
567567

@@ -590,9 +590,9 @@ async def test_forbids_if_no_permission_to_manage_ssh_fleets(
590590
DefaultPermissions(allow_non_admins_manage_ssh_fleets=False)
591591
):
592592
response = await client.post(
593-
f"/api/project/{project.name}/fleets/create",
593+
f"/api/project/{project.name}/fleets/apply",
594594
headers=get_auth_headers(user.token),
595-
json={"spec": spec.dict()},
595+
json={"plan": {"spec": spec.dict()}, "force": False},
596596
)
597597
assert response.status_code in [401, 403]
598598

0 commit comments

Comments
 (0)