Skip to content

Commit 2831c81

Browse files
authored
Optimize Python DB tests (#3755)
* Drop db parametrization for test_returns_40x_if_not_authenticated tests * Replace BaseModel.metadata.drop_all with _truncate_postgres_db on postgres * Run BaseModel.metadata.create_all once on Postgres * Drop fixtures tests * Comment on db fixtures * Drop more redundant test_db parametrization * Drop skipped test
1 parent 65fec5f commit 2831c81

File tree

14 files changed

+90
-147
lines changed

14 files changed

+90
-147
lines changed

src/dstack/_internal/server/testing/conf.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88
from dstack._internal.server.db import Database, override_db
99
from dstack._internal.server.models import BaseModel
1010

11+
# Remember initialized URLs to create metadata once per session.
12+
_initialized_postgres_db_urls = set()
13+
1114

1215
@pytest.fixture(scope="session")
1316
def postgres_container():
1417
with PostgresContainer("postgres:16-alpine", driver="asyncpg") as postgres:
1518
yield postgres.get_connection_url()
1619

1720

21+
# test_db is function-scoped since making it session-scoped did not bring much benefit.
1822
@pytest_asyncio.fixture
1923
async def test_db(request):
2024
db_type = getattr(request, "param", "sqlite")
@@ -37,12 +41,17 @@ async def test_db(request):
3741
raise ValueError(f"Unknown db_type {db_type}")
3842
db = Database(db_url, engine=engine)
3943
override_db(db)
40-
async with db.engine.begin() as conn:
41-
await conn.run_sync(BaseModel.metadata.drop_all)
42-
await conn.run_sync(BaseModel.metadata.create_all)
44+
if db_type == "sqlite":
45+
async with db.engine.begin() as conn:
46+
await conn.run_sync(BaseModel.metadata.create_all)
47+
# Relying on function-scoped engine for a clean DB
48+
else:
49+
if db_url not in _initialized_postgres_db_urls:
50+
async with db.engine.begin() as conn:
51+
await conn.run_sync(BaseModel.metadata.create_all)
52+
_initialized_postgres_db_urls.add(db_url)
53+
await _truncate_postgres_db(db)
4354
yield db
44-
async with db.engine.begin() as conn:
45-
await conn.run_sync(BaseModel.metadata.drop_all)
4655
await db.engine.dispose()
4756

4857

@@ -51,3 +60,15 @@ async def session(test_db):
5160
db = test_db
5261
async with db.get_session() as session:
5362
yield session
63+
64+
65+
async def _truncate_postgres_db(db: Database):
66+
preparer = db.engine.sync_engine.dialect.identifier_preparer
67+
table_names = ", ".join(
68+
preparer.format_table(table) for table in BaseModel.metadata.sorted_tables
69+
)
70+
if not table_names:
71+
return
72+
truncate_statement = f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"
73+
async with db.engine.begin() as conn:
74+
await conn.exec_driver_sql(truncate_statement)

src/tests/_internal/server/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from dstack._internal.server.services import logs as logs_services
1111
from dstack._internal.server.services.docker import ImageConfig, ImageConfigObject
1212
from dstack._internal.server.services.logs.filelog import FileLogStorage
13-
from dstack._internal.server.testing.conf import postgres_container, session, test_db # noqa: F401
13+
from dstack._internal.server.testing.conf import ( # noqa: F401
14+
postgres_container,
15+
session,
16+
test_db,
17+
)
1418

1519

1620
@pytest.fixture

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,7 @@
5959

6060
class TestListFleets:
6161
@pytest.mark.asyncio
62-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
63-
async def test_returns_40x_if_not_authenticated(
64-
self, test_db, session: AsyncSession, client: AsyncClient
65-
):
62+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
6663
response = await client.post("/api/fleets/list")
6764
assert response.status_code in [401, 403]
6865

@@ -365,10 +362,7 @@ async def test_returns_fleet_once_if_imported_twice(
365362

366363
class TestListProjectFleets:
367364
@pytest.mark.asyncio
368-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
369-
async def test_returns_40x_if_not_authenticated(
370-
self, test_db, session: AsyncSession, client: AsyncClient
371-
):
365+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
372366
response = await client.post("/api/project/main/fleets/list")
373367
assert response.status_code in [401, 403]
374368

@@ -555,10 +549,7 @@ async def test_returns_fleet_once_if_imported_twice(
555549

556550
class TestGetFleet:
557551
@pytest.mark.asyncio
558-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
559-
async def test_returns_40x_if_not_authenticated(
560-
self, test_db, session: AsyncSession, client: AsyncClient
561-
):
552+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
562553
response = await client.post("/api/project/main/fleets/get")
563554
assert response.status_code in [401, 403]
564555

@@ -913,10 +904,7 @@ async def test_patches_profile_fleets_for_old_clients(
913904

914905
class TestApplyFleetPlan:
915906
@pytest.mark.asyncio
916-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
917-
async def test_returns_40x_if_not_authenticated(
918-
self, test_db, session: AsyncSession, client: AsyncClient
919-
):
907+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
920908
response = await client.post("/api/project/main/fleets/apply")
921909
assert response.status_code in [401, 403]
922910

@@ -1547,10 +1535,7 @@ async def test_importer_member_cannot_apply_plan_on_imported_fleet(
15471535

15481536
class TestDeleteFleets:
15491537
@pytest.mark.asyncio
1550-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1551-
async def test_returns_40x_if_not_authenticated(
1552-
self, test_db, session: AsyncSession, client: AsyncClient
1553-
):
1538+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
15541539
response = await client.post("/api/project/main/fleets/delete")
15551540
assert response.status_code in [401, 403]
15561541

@@ -1727,10 +1712,7 @@ async def test_importer_member_cannot_delete_imported_fleet(
17271712

17281713
class TestDeleteFleetInstances:
17291714
@pytest.mark.asyncio
1730-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1731-
async def test_returns_40x_if_not_authenticated(
1732-
self, test_db, session: AsyncSession, client: AsyncClient
1733-
):
1715+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
17341716
response = await client.post("/api/project/main/fleets/delete_instances")
17351717
assert response.status_code in [401, 403]
17361718

@@ -1973,10 +1955,7 @@ async def test_importer_member_cannot_delete_imported_fleet_instances(
19731955

19741956
class TestGetPlan:
19751957
@pytest.mark.asyncio
1976-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1977-
async def test_returns_40x_if_not_authenticated(
1978-
self, test_db, session: AsyncSession, client: AsyncClient
1979-
):
1958+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
19801959
response = await client.post("/api/project/main/fleets/get_plan")
19811960
assert response.status_code in [401, 403]
19821961

src/tests/_internal/server/routers/test_gateways.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222

2323
class TestListAndGetGateways:
2424
@pytest.mark.asyncio
25-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
26-
async def test_returns_40x_if_not_authenticated(
27-
self, test_db, session: AsyncSession, client: AsyncClient
28-
):
25+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
2926
response = await client.post("/api/project/main/gateways/list")
3027
assert response.status_code in [401, 403]
3128

src/tests/_internal/server/routers/test_projects.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929

3030
class TestListProjects:
3131
@pytest.mark.asyncio
32-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
33-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
32+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
3433
response = await client.post("/api/projects/list")
3534
assert response.status_code in [401, 403]
3635

@@ -385,10 +384,7 @@ async def test_returns_total_count(self, test_db, session: AsyncSession, client:
385384

386385
class TestListOnlyNoFleets:
387386
@pytest.mark.asyncio
388-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
389-
async def test_list_only_no_fleets_returns_40x_if_not_authenticated(
390-
self, test_db, client: AsyncClient
391-
):
387+
async def test_list_only_no_fleets_returns_40x_if_not_authenticated(self, client: AsyncClient):
392388
response = await client.post("/api/projects/list_only_no_fleets")
393389
assert response.status_code in [401, 403]
394390

@@ -926,8 +922,7 @@ async def test_only_no_fleets_admin_requires_membership(
926922

927923
class TestCreateProject:
928924
@pytest.mark.asyncio
929-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
930-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
925+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
931926
response = await client.post("/api/projects/create")
932927
assert response.status_code in [401, 403]
933928

@@ -1162,8 +1157,7 @@ async def test_creates_private_project_explicitly(
11621157

11631158
class TestDeleteProject:
11641159
@pytest.mark.asyncio
1165-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1166-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
1160+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
11671161
response = await client.post("/api/projects/delete")
11681162
assert response.status_code in [401, 403]
11691163

@@ -1375,8 +1369,7 @@ async def test_errors_if_project_has_active_volumes(
13751369

13761370
class TestGetProject:
13771371
@pytest.mark.asyncio
1378-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1379-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
1372+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
13801373
response = await client.post("/api/projects/test_project/get")
13811374
assert response.status_code in [401, 403]
13821375

@@ -1607,8 +1600,7 @@ async def test_member_can_access_both_public_and_private_projects(
16071600

16081601
class TestSetProjectMembers:
16091602
@pytest.mark.asyncio
1610-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1611-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
1603+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
16121604
response = await client.post("/api/projects/test_project/get")
16131605
assert response.status_code in [401, 403]
16141606

@@ -1971,8 +1963,7 @@ async def test_cannot_add_same_user_twice(
19711963

19721964
class TestUpdateProjectVisibility:
19731965
@pytest.mark.asyncio
1974-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1975-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
1966+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
19761967
response = await client.post("/api/projects/test/update")
19771968
assert response.status_code in [401, 403]
19781969

src/tests/_internal/server/routers/test_public_keys.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818

1919

2020
@pytest.mark.asyncio
21-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
22-
@pytest.mark.usefixtures("test_db")
2321
class TestListUserPublicKeys:
2422
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
2523
response = await client.post("/api/users/public_keys/list")
2624
assert response.status_code in [401, 403]
2725

26+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
27+
@pytest.mark.usefixtures("test_db")
2828
async def test_lists_own_public_keys(self, session: AsyncSession, client: AsyncClient):
2929
user = await create_user(session=session)
3030
key = await create_user_public_key(
@@ -50,6 +50,8 @@ async def test_lists_own_public_keys(self, session: AsyncSession, client: AsyncC
5050
}
5151
]
5252

53+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
54+
@pytest.mark.usefixtures("test_db")
5355
async def test_does_not_list_other_users_keys(
5456
self, session: AsyncSession, client: AsyncClient
5557
):
@@ -63,6 +65,8 @@ async def test_does_not_list_other_users_keys(
6365
assert response.status_code == 200
6466
assert response.json() == []
6567

68+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
69+
@pytest.mark.usefixtures("test_db")
6670
async def test_returns_keys_in_reverse_chronological_order(
6771
self, session: AsyncSession, client: AsyncClient
6872
):
@@ -93,8 +97,6 @@ async def test_returns_keys_in_reverse_chronological_order(
9397

9498

9599
@pytest.mark.asyncio
96-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
97-
@pytest.mark.usefixtures("test_db")
98100
class TestAddUserPublicKey:
99101
PUBLIC_KEY_NO_COMMENT = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA"
100102
PUBLIC_KEY = f"{PUBLIC_KEY_NO_COMMENT} test@example.com"
@@ -114,6 +116,8 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
114116
)
115117
assert response.status_code in [401, 403]
116118

119+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
120+
@pytest.mark.usefixtures("test_db")
117121
@freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc))
118122
async def test_adds_valid_public_key(
119123
self,
@@ -137,6 +141,8 @@ async def test_adds_valid_public_key(
137141
}
138142
validate_openssh_public_key_mock.assert_awaited_once_with(self.PUBLIC_KEY)
139143

144+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
145+
@pytest.mark.usefixtures("test_db")
140146
@pytest.mark.usefixtures("validate_openssh_public_key_mock")
141147
async def test_adds_key_with_custom_name(self, session: AsyncSession, client: AsyncClient):
142148
user = await create_user(session=session)
@@ -148,6 +154,8 @@ async def test_adds_key_with_custom_name(self, session: AsyncSession, client: As
148154
assert response.status_code == 200
149155
assert response.json()["name"] == "my-laptop"
150156

157+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
158+
@pytest.mark.usefixtures("test_db")
151159
@pytest.mark.usefixtures("validate_openssh_public_key_mock")
152160
async def test_uses_md5_as_name_when_no_comment_and_no_name(
153161
self, session: AsyncSession, client: AsyncClient
@@ -161,6 +169,8 @@ async def test_uses_md5_as_name_when_no_comment_and_no_name(
161169
assert response.status_code == 200
162170
assert response.json()["name"] == "744e414c6ac55e3f15c1dd48229cbe74"
163171

172+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
173+
@pytest.mark.usefixtures("test_db")
164174
@pytest.mark.parametrize(
165175
"key",
166176
[
@@ -180,6 +190,8 @@ async def test_returns_400_for_invalid_key(
180190
assert response.status_code == 400
181191
assert "Invalid public key" in response.json()["detail"][0]["msg"]
182192

193+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
194+
@pytest.mark.usefixtures("test_db")
183195
async def test_returns_400_for_unsupported_key(
184196
self, session: AsyncSession, client: AsyncClient
185197
):
@@ -192,6 +204,8 @@ async def test_returns_400_for_unsupported_key(
192204
assert response.status_code == 400
193205
assert response.json()["detail"][0]["msg"] == "Unsupported key type: ssh-dss"
194206

207+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
208+
@pytest.mark.usefixtures("test_db")
195209
@pytest.mark.usefixtures("validate_openssh_public_key_mock")
196210
async def test_returns_400_resource_exists_for_duplicate_key(
197211
self, session: AsyncSession, client: AsyncClient
@@ -214,8 +228,6 @@ async def test_returns_400_resource_exists_for_duplicate_key(
214228

215229

216230
@pytest.mark.asyncio
217-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
218-
@pytest.mark.usefixtures("test_db")
219231
class TestDeleteUserPublicKeys:
220232
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
221233
response = await client.post(
@@ -224,6 +236,8 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
224236
)
225237
assert response.status_code in [401, 403]
226238

239+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
240+
@pytest.mark.usefixtures("test_db")
227241
async def test_deletes_public_key(self, session: AsyncSession, client: AsyncClient):
228242
user = await create_user(session=session)
229243
key = await create_user_public_key(session=session, user=user)
@@ -241,6 +255,8 @@ async def test_deletes_public_key(self, session: AsyncSession, client: AsyncClie
241255
)
242256
assert res.scalars().all() == [other_key]
243257

258+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
259+
@pytest.mark.usefixtures("test_db")
244260
async def test_silently_ignores_nonexistent_ids(
245261
self, session: AsyncSession, client: AsyncClient
246262
):
@@ -252,6 +268,8 @@ async def test_silently_ignores_nonexistent_ids(
252268
)
253269
assert response.status_code == 200
254270

271+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
272+
@pytest.mark.usefixtures("test_db")
255273
async def test_does_not_delete_other_users_keys(
256274
self, session: AsyncSession, client: AsyncClient
257275
):

src/tests/_internal/server/routers/test_runs.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,7 @@ def get_service_run_spec(
642642

643643
class TestListRuns:
644644
@pytest.mark.asyncio
645-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
646-
async def test_returns_40x_if_not_authenticated(
647-
self, test_db, session: AsyncSession, client: AsyncClient
648-
):
645+
async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
649646
response = await client.post("/api/runs/list")
650647
assert response.status_code in [401, 403]
651648

src/tests/_internal/server/routers/test_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class TestGetInfo:
1010
@pytest.mark.asyncio
11-
async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClient):
11+
async def test_returns_server_info(self, test_db, client: AsyncClient):
1212
with patch.object(settings, "DSTACK_VERSION", "0.18.10"):
1313
response = await client.post("/api/server/get_info")
1414
assert response.status_code == 200

0 commit comments

Comments
 (0)