Skip to content

Commit da026f9

Browse files
peterschmidt85 and Andrey Cheptsov authored
Show offers in dstack apply for elastic container fleets (#3754)
* Show offers for elastic container fleets

  Use run-capable offer lookup for cloud fleets with nodes.min=0 and nodes.target=0, while keeping create-instance filtering for non-elastic fleets.

  Adds router tests for elastic container backend offers and preserves no-offers behavior for non-elastic container fleets.

* Refactor fleet offer semantics for create-instance filtering

* Drop get_create_instance_offers wrapper

---------

Co-authored-by: Andrey Cheptsov <andrey.cheptsov@github.com>
1 parent 73e51a5 commit da026f9

File tree

3 files changed: +102 −11 lines changed

src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
)
3737
from dstack._internal.server.db import get_session_ctx
3838
from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel
39-
from dstack._internal.server.services.fleets import get_create_instance_offers, is_cloud_cluster
39+
from dstack._internal.server.services.fleets import get_fleet_offers, is_cloud_cluster
4040
from dstack._internal.server.services.instances import (
4141
get_instance_configuration,
4242
get_instance_profile,
@@ -101,7 +101,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
101101
)
102102
master_job_provisioning_data = cluster_context.master_job_provisioning_data
103103

104-
offers = await get_create_instance_offers(
104+
offers = await get_fleet_offers(
105105
project=instance_model.project,
106106
profile=profile,
107107
requirements=requirements,
@@ -111,6 +111,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
111111
exclude_not_available=True,
112112
master_job_provisioning_data=master_job_provisioning_data,
113113
infer_master_job_provisioning_data_from_fleet_instances=False,
114+
include_only_create_instance_supported_backends=True,
114115
)
115116

116117
# Limit number of offers tried to prevent long-running processing in case all offers fail.
@@ -120,7 +121,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
120121
compute = backend.compute()
121122
assert isinstance(compute, ComputeWithCreateInstanceSupport)
122123
if master_job_provisioning_data is not None:
123-
# `get_create_instance_offers()` already restricts backend and region from the master.
124+
# `get_fleet_offers()` already restricts backend and region from the master.
124125
# Availability zone still has to be narrowed per offer.
125126
instance_offer = get_instance_offer_with_restricted_az(
126127
instance_offer=instance_offer,

src/dstack/_internal/server/services/fleets.py

Lines changed: 25 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -442,12 +442,20 @@ async def get_plan(
442442

443443
offers = []
444444
if effective_spec.configuration.ssh_config is None:
445-
offers_with_backends = await get_create_instance_offers(
445+
requirements = get_fleet_requirements(effective_spec)
446+
nodes = effective_spec.configuration.nodes
447+
include_only_create_instance_supported_backends = True
448+
if nodes is not None:
449+
include_only_create_instance_supported_backends = nodes.target != 0
450+
offers_with_backends = await get_fleet_offers(
446451
project=project,
447452
profile=effective_spec.merged_profile,
448-
requirements=get_fleet_requirements(effective_spec),
453+
requirements=requirements,
449454
fleet_spec=effective_spec,
450455
blocks=effective_spec.configuration.blocks,
456+
include_only_create_instance_supported_backends=(
457+
include_only_create_instance_supported_backends
458+
),
451459
)
452460
offers = [offer for _, offer in offers_with_backends]
453461

@@ -468,7 +476,7 @@ async def get_plan(
468476
return plan
469477

470478

471-
async def get_create_instance_offers(
479+
async def get_fleet_offers(
472480
project: ProjectModel,
473481
profile: Profile,
474482
requirements: Requirements,
@@ -479,7 +487,15 @@ async def get_create_instance_offers(
479487
exclude_not_available: bool = False,
480488
master_job_provisioning_data: Optional[JobProvisioningData] = None,
481489
infer_master_job_provisioning_data_from_fleet_instances: bool = True,
490+
include_only_create_instance_supported_backends: bool = True,
482491
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
492+
"""
493+
Return offers for fleet planning and provisioning.
494+
495+
By default, restricts to backends that support `create_instance`.
496+
Set `include_only_create_instance_supported_backends=False` to include
497+
all matching backends.
498+
"""
483499
multinode = False
484500
if fleet_spec is not None:
485501
multinode = fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER
@@ -508,11 +524,12 @@ async def get_create_instance_offers(
508524
placement_group=placement_group,
509525
blocks=blocks,
510526
)
511-
offers = [
512-
(backend, offer)
513-
for backend, offer in offers
514-
if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
515-
]
527+
if include_only_create_instance_supported_backends:
528+
offers = [
529+
(backend, offer)
530+
for backend, offer in offers
531+
if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
532+
]
516533
return offers
517534

518535

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from dstack._internal.core.models.common import EntityReference
1515
from dstack._internal.core.models.fleets import (
1616
FleetConfiguration,
17+
FleetNodesSpec,
1718
FleetStatus,
1819
InstanceGroupPlacement,
1920
SSHHostParams,
@@ -2007,6 +2008,78 @@ async def test_returns_create_plan_for_new_fleet(
20072008
"action": "create",
20082009
}
20092010

2011+
@pytest.mark.asyncio
2012+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
2013+
async def test_returns_offers_for_elastic_container_backend_fleet(
2014+
self, test_db, session: AsyncSession, client: AsyncClient
2015+
):
2016+
user = await create_user(session=session, global_role=GlobalRole.USER)
2017+
project = await create_project(session=session, owner=user)
2018+
await add_project_member(
2019+
session=session, project=project, user=user, project_role=ProjectRole.USER
2020+
)
2021+
offer = get_instance_offer_with_availability(
2022+
backend=BackendType.RUNPOD,
2023+
region="US-OR-1",
2024+
price=0.7185,
2025+
)
2026+
spec = get_fleet_spec(
2027+
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1))
2028+
)
2029+
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
2030+
backend_mock = Mock()
2031+
m.return_value = [backend_mock]
2032+
backend_mock.TYPE = BackendType.RUNPOD
2033+
backend_mock.compute.return_value.get_offers.return_value = [offer]
2034+
response = await client.post(
2035+
f"/api/project/{project.name}/fleets/get_plan",
2036+
headers=get_auth_headers(user.token),
2037+
json={"spec": spec.dict()},
2038+
)
2039+
backend_mock.compute.return_value.get_offers.assert_called_once()
2040+
2041+
response_json = response.json()
2042+
assert response.status_code == 200, response_json
2043+
assert response_json["offers"] == [json.loads(offer.json())]
2044+
assert response_json["total_offers"] == 1
2045+
assert response_json["max_offer_price"] == offer.price
2046+
2047+
@pytest.mark.asyncio
2048+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
2049+
async def test_returns_no_offers_for_non_elastic_container_backend_fleet(
2050+
self, test_db, session: AsyncSession, client: AsyncClient
2051+
):
2052+
user = await create_user(session=session, global_role=GlobalRole.USER)
2053+
project = await create_project(session=session, owner=user)
2054+
await add_project_member(
2055+
session=session, project=project, user=user, project_role=ProjectRole.USER
2056+
)
2057+
offer = get_instance_offer_with_availability(
2058+
backend=BackendType.RUNPOD,
2059+
region="US-OR-1",
2060+
price=0.7185,
2061+
)
2062+
spec = get_fleet_spec(
2063+
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=1, max=1))
2064+
)
2065+
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
2066+
backend_mock = Mock()
2067+
m.return_value = [backend_mock]
2068+
backend_mock.TYPE = BackendType.RUNPOD
2069+
backend_mock.compute.return_value.get_offers.return_value = [offer]
2070+
response = await client.post(
2071+
f"/api/project/{project.name}/fleets/get_plan",
2072+
headers=get_auth_headers(user.token),
2073+
json={"spec": spec.dict()},
2074+
)
2075+
backend_mock.compute.return_value.get_offers.assert_called_once()
2076+
2077+
response_json = response.json()
2078+
assert response.status_code == 200, response_json
2079+
assert response_json["offers"] == []
2080+
assert response_json["total_offers"] == 0
2081+
assert response_json["max_offer_price"] is None
2082+
20102083
@pytest.mark.asyncio
20112084
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
20122085
async def test_returns_update_plan_for_existing_fleet(

0 commit comments

Comments
 (0)