Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,20 @@ async def get_plan(

offers = []
if effective_spec.configuration.ssh_config is None:
offers_with_backends = await get_create_instance_offers(
requirements = get_fleet_requirements(effective_spec)
nodes = effective_spec.configuration.nodes
include_only_create_instance_supported_backends = True
if nodes is not None:
include_only_create_instance_supported_backends = nodes.target != 0
offers_with_backends = await get_fleet_offers(
project=project,
profile=effective_spec.merged_profile,
requirements=get_fleet_requirements(effective_spec),
requirements=requirements,
fleet_spec=effective_spec,
blocks=effective_spec.configuration.blocks,
include_only_create_instance_supported_backends=(
include_only_create_instance_supported_backends
),
)
offers = [offer for _, offer in offers_with_backends]

Expand Down Expand Up @@ -480,6 +488,49 @@ async def get_create_instance_offers(
master_job_provisioning_data: Optional[JobProvisioningData] = None,
infer_master_job_provisioning_data_from_fleet_instances: bool = True,
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
"""
Comment thread
peterschmidt85 marked this conversation as resolved.
Outdated
Return fleet offers restricted to backends that support `create_instance`.

This method is for create-instance provisioning semantics
(typically VM-based backends, not container-only backends).
"""
return await get_fleet_offers(
project=project,
profile=profile,
requirements=requirements,
placement_group=placement_group,
fleet_spec=fleet_spec,
fleet_model=fleet_model,
blocks=blocks,
exclude_not_available=exclude_not_available,
master_job_provisioning_data=master_job_provisioning_data,
infer_master_job_provisioning_data_from_fleet_instances=(
infer_master_job_provisioning_data_from_fleet_instances
),
include_only_create_instance_supported_backends=True,
)


async def get_fleet_offers(
project: ProjectModel,
profile: Profile,
requirements: Requirements,
placement_group: Optional[PlacementGroup] = None,
fleet_spec: Optional[FleetSpec] = None,
fleet_model: Optional[FleetModel] = None,
blocks: Union[int, Literal["auto"]] = 1,
exclude_not_available: bool = False,
master_job_provisioning_data: Optional[JobProvisioningData] = None,
infer_master_job_provisioning_data_from_fleet_instances: bool = True,
include_only_create_instance_supported_backends: bool = True,
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
"""
Return offers for fleet planning and provisioning.

By default, restricts to backends that support `create_instance`.
Set `include_only_create_instance_supported_backends=False` to include
all matching backends.
"""
multinode = False
if fleet_spec is not None:
multinode = fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER
Expand Down Expand Up @@ -508,11 +559,12 @@ async def get_create_instance_offers(
placement_group=placement_group,
blocks=blocks,
)
offers = [
(backend, offer)
for backend, offer in offers
if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
]
if include_only_create_instance_supported_backends:
offers = [
(backend, offer)
for backend, offer in offers
if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
]
return offers


Expand Down
73 changes: 73 additions & 0 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.fleets import (
FleetConfiguration,
FleetNodesSpec,
FleetStatus,
InstanceGroupPlacement,
SSHHostParams,
Expand Down Expand Up @@ -2028,6 +2029,78 @@ async def test_returns_create_plan_for_new_fleet(
"action": "create",
}

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_offers_for_elastic_container_backend_fleet(
    self, test_db, session: AsyncSession, client: AsyncClient
):
    """
    `get_plan` returns the mocked RunPod offer when the fleet's
    `nodes` spec has `target=0`.
    """
    # Arrange: a project whose only (mocked) backend is RunPod with one offer.
    owner = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=owner)
    await add_project_member(
        session=session, project=project, user=owner, project_role=ProjectRole.USER
    )
    runpod_offer = get_instance_offer_with_availability(
        backend=BackendType.RUNPOD,
        region="US-OR-1",
        price=0.7185,
    )
    nodes = FleetNodesSpec(min=0, target=0, max=1)
    fleet_spec = get_fleet_spec(conf=get_fleet_configuration(nodes=nodes))

    # Act: request the plan while the project's backend list is patched.
    with patch("dstack._internal.server.services.backends.get_project_backends") as get_backends:
        runpod_backend = Mock()
        runpod_backend.TYPE = BackendType.RUNPOD
        get_offers = runpod_backend.compute.return_value.get_offers
        get_offers.return_value = [runpod_offer]
        get_backends.return_value = [runpod_backend]
        response = await client.post(
            f"/api/project/{project.name}/fleets/get_plan",
            headers=get_auth_headers(owner.token),
            json={"spec": fleet_spec.dict()},
        )
        get_offers.assert_called_once()

    # Assert: the single mocked offer is returned in the plan.
    response_json = response.json()
    assert response.status_code == 200, response_json
    expected_offers = [json.loads(runpod_offer.json())]
    assert response_json["offers"] == expected_offers
    assert response_json["total_offers"] == 1
    assert response_json["max_offer_price"] == runpod_offer.price

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_no_offers_for_non_elastic_container_backend_fleet(
    self, test_db, session: AsyncSession, client: AsyncClient
):
    """
    `get_plan` returns no offers for the mocked RunPod backend when the
    fleet's `nodes` spec has `target=1`, even though the backend is queried.
    """
    # Arrange: a project whose only (mocked) backend is RunPod with one offer.
    owner = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=owner)
    await add_project_member(
        session=session, project=project, user=owner, project_role=ProjectRole.USER
    )
    runpod_offer = get_instance_offer_with_availability(
        backend=BackendType.RUNPOD,
        region="US-OR-1",
        price=0.7185,
    )
    nodes = FleetNodesSpec(min=0, target=1, max=1)
    fleet_spec = get_fleet_spec(conf=get_fleet_configuration(nodes=nodes))

    # Act: request the plan while the project's backend list is patched.
    with patch("dstack._internal.server.services.backends.get_project_backends") as get_backends:
        runpod_backend = Mock()
        runpod_backend.TYPE = BackendType.RUNPOD
        get_offers = runpod_backend.compute.return_value.get_offers
        get_offers.return_value = [runpod_offer]
        get_backends.return_value = [runpod_backend]
        response = await client.post(
            f"/api/project/{project.name}/fleets/get_plan",
            headers=get_auth_headers(owner.token),
            json={"spec": fleet_spec.dict()},
        )
        # The backend is still asked for offers; they are dropped afterwards.
        get_offers.assert_called_once()

    # Assert: the plan carries no offers and no maximum price.
    response_json = response.json()
    assert response.status_code == 200, response_json
    assert response_json["offers"] == []
    assert response_json["total_offers"] == 0
    assert response_json["max_offer_price"] is None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_update_plan_for_existing_fleet(
Expand Down
Loading