Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel
from dstack._internal.server.services.fleets import get_create_instance_offers, is_cloud_cluster
from dstack._internal.server.services.fleets import get_fleet_offers, is_cloud_cluster
from dstack._internal.server.services.instances import (
get_instance_configuration,
get_instance_profile,
Expand Down Expand Up @@ -101,7 +101,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
)
master_job_provisioning_data = cluster_context.master_job_provisioning_data

offers = await get_create_instance_offers(
offers = await get_fleet_offers(
project=instance_model.project,
profile=profile,
requirements=requirements,
Expand All @@ -111,6 +111,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
exclude_not_available=True,
master_job_provisioning_data=master_job_provisioning_data,
infer_master_job_provisioning_data_from_fleet_instances=False,
include_only_create_instance_supported_backends=True,
)

# Limit number of offers tried to prevent long-running processing in case all offers fail.
Expand All @@ -120,7 +121,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
compute = backend.compute()
assert isinstance(compute, ComputeWithCreateInstanceSupport)
if master_job_provisioning_data is not None:
# `get_create_instance_offers()` already restricts backend and region from the master.
# `get_fleet_offers()` already restricts backend and region from the master.
# Availability zone still has to be narrowed per offer.
instance_offer = get_instance_offer_with_restricted_az(
instance_offer=instance_offer,
Expand Down
33 changes: 25 additions & 8 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,20 @@ async def get_plan(

offers = []
if effective_spec.configuration.ssh_config is None:
offers_with_backends = await get_create_instance_offers(
requirements = get_fleet_requirements(effective_spec)
nodes = effective_spec.configuration.nodes
include_only_create_instance_supported_backends = True
if nodes is not None:
include_only_create_instance_supported_backends = nodes.target != 0
offers_with_backends = await get_fleet_offers(
project=project,
profile=effective_spec.merged_profile,
requirements=get_fleet_requirements(effective_spec),
requirements=requirements,
fleet_spec=effective_spec,
blocks=effective_spec.configuration.blocks,
include_only_create_instance_supported_backends=(
include_only_create_instance_supported_backends
),
)
offers = [offer for _, offer in offers_with_backends]

Expand All @@ -468,7 +476,7 @@ async def get_plan(
return plan


async def get_create_instance_offers(
async def get_fleet_offers(
project: ProjectModel,
profile: Profile,
requirements: Requirements,
Expand All @@ -479,7 +487,15 @@ async def get_create_instance_offers(
exclude_not_available: bool = False,
master_job_provisioning_data: Optional[JobProvisioningData] = None,
infer_master_job_provisioning_data_from_fleet_instances: bool = True,
include_only_create_instance_supported_backends: bool = True,
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
"""
Return offers for fleet planning and provisioning.

By default, restricts to backends that support `create_instance`.
Set `include_only_create_instance_supported_backends=False` to include
all matching backends.
"""
multinode = False
if fleet_spec is not None:
multinode = fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER
Expand Down Expand Up @@ -508,11 +524,12 @@ async def get_create_instance_offers(
placement_group=placement_group,
blocks=blocks,
)
offers = [
(backend, offer)
for backend, offer in offers
if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
]
if include_only_create_instance_supported_backends:
offers = [
(backend, offer)
for backend, offer in offers
if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
]
return offers


Expand Down
73 changes: 73 additions & 0 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.fleets import (
FleetConfiguration,
FleetNodesSpec,
FleetStatus,
InstanceGroupPlacement,
SSHHostParams,
Expand Down Expand Up @@ -2028,6 +2029,78 @@ async def test_returns_create_plan_for_new_fleet(
"action": "create",
}

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_offers_for_elastic_container_backend_fleet(
    self, test_db, session: AsyncSession, client: AsyncClient
):
    """Check that `get_plan` returns a RunPod offer when `nodes.target` is 0.

    The project's backends are patched to a single mocked backend of type
    RUNPOD whose compute yields exactly one offer. The plan response is
    expected to list that offer, count it, and report its price as the
    maximum offer price.
    """
    # Arrange: a project with a regular member who requests the plan.
    user = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=user)
    await add_project_member(
        session=session, project=project, user=user, project_role=ProjectRole.USER
    )
    runpod_offer = get_instance_offer_with_availability(
        backend=BackendType.RUNPOD,
        region="US-OR-1",
        price=0.7185,
    )
    # target=0 means nothing is provisioned up front.
    nodes = FleetNodesSpec(min=0, target=0, max=1)
    fleet_spec = get_fleet_spec(conf=get_fleet_configuration(nodes=nodes))

    # Act: request the plan while the only project backend is the mock.
    with patch(
        "dstack._internal.server.services.backends.get_project_backends"
    ) as get_backends_mock:
        runpod_backend = Mock()
        get_backends_mock.return_value = [runpod_backend]
        runpod_backend.TYPE = BackendType.RUNPOD
        get_offers_mock = runpod_backend.compute.return_value.get_offers
        get_offers_mock.return_value = [runpod_offer]
        response = await client.post(
            f"/api/project/{project.name}/fleets/get_plan",
            headers=get_auth_headers(user.token),
            json={"spec": fleet_spec.dict()},
        )
        get_offers_mock.assert_called_once()

    # Assert: the mocked offer is returned as-is.
    plan = response.json()
    assert response.status_code == 200, plan
    expected_offers = [json.loads(runpod_offer.json())]
    assert plan["offers"] == expected_offers
    assert plan["total_offers"] == 1
    assert plan["max_offer_price"] == runpod_offer.price

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_no_offers_for_non_elastic_container_backend_fleet(
    self, test_db, session: AsyncSession, client: AsyncClient
):
    """Check that `get_plan` returns no offers for RunPod when `nodes.target` is 1.

    The setup is the same single mocked RUNPOD backend with one offer, but
    the fleet asks for one node up front. The backend is still queried
    once, yet the plan is expected to contain no offers and no maximum
    offer price.
    """
    # Arrange: a project with a regular member who requests the plan.
    user = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=user)
    await add_project_member(
        session=session, project=project, user=user, project_role=ProjectRole.USER
    )
    runpod_offer = get_instance_offer_with_availability(
        backend=BackendType.RUNPOD,
        region="US-OR-1",
        price=0.7185,
    )
    # target=1 requests one node up front.
    nodes = FleetNodesSpec(min=0, target=1, max=1)
    fleet_spec = get_fleet_spec(conf=get_fleet_configuration(nodes=nodes))

    # Act: request the plan while the only project backend is the mock.
    with patch(
        "dstack._internal.server.services.backends.get_project_backends"
    ) as get_backends_mock:
        runpod_backend = Mock()
        get_backends_mock.return_value = [runpod_backend]
        runpod_backend.TYPE = BackendType.RUNPOD
        get_offers_mock = runpod_backend.compute.return_value.get_offers
        get_offers_mock.return_value = [runpod_offer]
        response = await client.post(
            f"/api/project/{project.name}/fleets/get_plan",
            headers=get_auth_headers(user.token),
            json={"spec": fleet_spec.dict()},
        )
        # The backend is still asked for offers; they are dropped afterwards.
        get_offers_mock.assert_called_once()

    # Assert: the plan is empty.
    plan = response.json()
    assert response.status_code == 200, plan
    assert plan["offers"] == []
    assert plan["total_offers"] == 0
    assert plan["max_offer_price"] is None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_update_plan_for_existing_fleet(
Expand Down
Loading