|
1 | 1 | import json |
2 | 2 | from datetime import datetime, timezone |
3 | | -from typing import Optional, Union |
| 3 | +from typing import Literal, Optional, Union |
4 | 4 | from unittest.mock import Mock, patch |
5 | 5 | from uuid import uuid4 |
6 | 6 |
|
|
16 | 16 | FleetConfiguration, |
17 | 17 | FleetStatus, |
18 | 18 | InstanceGroupPlacement, |
| 19 | + SSHHostParams, |
19 | 20 | SSHParams, |
20 | 21 | ) |
21 | 22 | from dstack._internal.core.models.instances import ( |
@@ -1178,6 +1179,56 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A |
1178 | 1179 | instance = res.unique().scalar_one() |
1179 | 1180 | assert instance.remote_connection_info is not None |
1180 | 1181 |
|
| 1182 | + @pytest.mark.parametrize( |
| 1183 | + ["top_level_blocks", "host_blocks", "host_type", "expected_blocks"], |
| 1184 | + [ |
| 1185 | + pytest.param(None, None, str, 1, id="global-default-string"), |
| 1186 | + pytest.param(None, None, SSHHostParams, 1, id="global-default-object"), |
| 1187 | + pytest.param(4, None, str, 4, id="top-level-int-string"), |
| 1188 | + pytest.param(4, None, SSHHostParams, 4, id="top-level-int-object"), |
| 1189 | + pytest.param("auto", None, str, None, id="top-level-auto-string"), |
| 1190 | + pytest.param("auto", None, SSHHostParams, None, id="top-level-auto-object"), |
| 1191 | + pytest.param("auto", 4, SSHHostParams, 4, id="host-level-int"), |
| 1192 | + pytest.param(4, "auto", SSHHostParams, None, id="host-level-auto"), |
| 1193 | + ], |
| 1194 | + ) |
| 1195 | + @pytest.mark.asyncio |
| 1196 | + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 1197 | + async def test_creates_ssh_fleet_with_blocks( |
| 1198 | + self, |
| 1199 | + test_db, |
| 1200 | + session: AsyncSession, |
| 1201 | + client: AsyncClient, |
| 1202 | + top_level_blocks: Optional[Union[int, Literal["auto"]]], |
| 1203 | + host_blocks: Optional[Union[int, Literal["auto"]]], |
| 1204 | + host_type: Union[type[str], type[SSHHostParams]], |
| 1205 | + expected_blocks: Optional[int], |
| 1206 | + ): |
| 1207 | + user = await create_user(session, global_role=GlobalRole.USER) |
| 1208 | + project = await create_project(session) |
| 1209 | + await add_project_member( |
| 1210 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 1211 | + ) |
| 1212 | + if host_type is str: |
| 1213 | + host = "1.1.1.1" |
| 1214 | + elif host_blocks is None: |
| 1215 | + host = SSHHostParams(hostname="1.1.1.1") |
| 1216 | + else: |
| 1217 | + host = SSHHostParams(hostname="1.1.1.1", blocks=host_blocks) |
| 1218 | + conf = get_ssh_fleet_configuration(blocks=top_level_blocks, hosts=[host]) |
| 1219 | + spec = get_fleet_spec(conf=conf) |
| 1220 | + response = await client.post( |
| 1221 | + f"/api/project/{project.name}/fleets/apply", |
| 1222 | + headers=get_auth_headers(user.token), |
| 1223 | + json={"plan": {"spec": spec.dict()}, "force": False}, |
| 1224 | + ) |
| 1225 | + assert response.status_code == 200, response.json() |
| 1226 | + res = await session.execute(select(FleetModel)) |
| 1227 | + assert len(res.scalars().all()) == 1 |
| 1228 | + res = await session.execute(select(InstanceModel)) |
| 1229 | + instance = res.scalar_one() |
| 1230 | + assert instance.total_blocks == expected_blocks |
| 1231 | + |
1181 | 1232 | @pytest.mark.asyncio |
1182 | 1233 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
1183 | 1234 | @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), real_asyncio=True) |
|
0 commit comments