|
14 | 14 | from dstack._internal.core.models.common import EntityReference |
15 | 15 | from dstack._internal.core.models.fleets import ( |
16 | 16 | FleetConfiguration, |
| 17 | + FleetNodesSpec, |
17 | 18 | FleetStatus, |
18 | 19 | InstanceGroupPlacement, |
19 | 20 | SSHHostParams, |
@@ -2028,6 +2029,78 @@ async def test_returns_create_plan_for_new_fleet( |
2028 | 2029 | "action": "create", |
2029 | 2030 | } |
2030 | 2031 |
|
| 2032 | + @pytest.mark.asyncio |
| 2033 | + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 2034 | + async def test_returns_offers_for_elastic_container_backend_fleet( |
| 2035 | + self, test_db, session: AsyncSession, client: AsyncClient |
| 2036 | + ): |
| 2037 | + user = await create_user(session=session, global_role=GlobalRole.USER) |
| 2038 | + project = await create_project(session=session, owner=user) |
| 2039 | + await add_project_member( |
| 2040 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 2041 | + ) |
| 2042 | + offer = get_instance_offer_with_availability( |
| 2043 | + backend=BackendType.RUNPOD, |
| 2044 | + region="US-OR-1", |
| 2045 | + price=0.7185, |
| 2046 | + ) |
| 2047 | + spec = get_fleet_spec( |
| 2048 | + conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1)) |
| 2049 | + ) |
| 2050 | + with patch("dstack._internal.server.services.backends.get_project_backends") as m: |
| 2051 | + backend_mock = Mock() |
| 2052 | + m.return_value = [backend_mock] |
| 2053 | + backend_mock.TYPE = BackendType.RUNPOD |
| 2054 | + backend_mock.compute.return_value.get_offers.return_value = [offer] |
| 2055 | + response = await client.post( |
| 2056 | + f"/api/project/{project.name}/fleets/get_plan", |
| 2057 | + headers=get_auth_headers(user.token), |
| 2058 | + json={"spec": spec.dict()}, |
| 2059 | + ) |
| 2060 | + backend_mock.compute.return_value.get_offers.assert_called_once() |
| 2061 | + |
| 2062 | + response_json = response.json() |
| 2063 | + assert response.status_code == 200, response_json |
| 2064 | + assert response_json["offers"] == [json.loads(offer.json())] |
| 2065 | + assert response_json["total_offers"] == 1 |
| 2066 | + assert response_json["max_offer_price"] == offer.price |
| 2067 | + |
| 2068 | + @pytest.mark.asyncio |
| 2069 | + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 2070 | + async def test_returns_no_offers_for_non_elastic_container_backend_fleet( |
| 2071 | + self, test_db, session: AsyncSession, client: AsyncClient |
| 2072 | + ): |
| 2073 | + user = await create_user(session=session, global_role=GlobalRole.USER) |
| 2074 | + project = await create_project(session=session, owner=user) |
| 2075 | + await add_project_member( |
| 2076 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 2077 | + ) |
| 2078 | + offer = get_instance_offer_with_availability( |
| 2079 | + backend=BackendType.RUNPOD, |
| 2080 | + region="US-OR-1", |
| 2081 | + price=0.7185, |
| 2082 | + ) |
| 2083 | + spec = get_fleet_spec( |
| 2084 | + conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=1, max=1)) |
| 2085 | + ) |
| 2086 | + with patch("dstack._internal.server.services.backends.get_project_backends") as m: |
| 2087 | + backend_mock = Mock() |
| 2088 | + m.return_value = [backend_mock] |
| 2089 | + backend_mock.TYPE = BackendType.RUNPOD |
| 2090 | + backend_mock.compute.return_value.get_offers.return_value = [offer] |
| 2091 | + response = await client.post( |
| 2092 | + f"/api/project/{project.name}/fleets/get_plan", |
| 2093 | + headers=get_auth_headers(user.token), |
| 2094 | + json={"spec": spec.dict()}, |
| 2095 | + ) |
| 2096 | + backend_mock.compute.return_value.get_offers.assert_called_once() |
| 2097 | + |
| 2098 | + response_json = response.json() |
| 2099 | + assert response.status_code == 200, response_json |
| 2100 | + assert response_json["offers"] == [] |
| 2101 | + assert response_json["total_offers"] == 0 |
| 2102 | + assert response_json["max_offer_price"] is None |
| 2103 | + |
2031 | 2104 | @pytest.mark.asyncio |
2032 | 2105 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
2033 | 2106 | async def test_returns_update_plan_for_existing_fleet( |
|
0 commit comments