Skip to content

Commit 1f85ab9

Browse files
authored
test: use freezegun instead of sleeps in time-sensitive tests (#957)
test: centralize freezegun mock time helpers and helper dependencies to testing package
1 parent 13d4f8d commit 1f85ab9

11 files changed

Lines changed: 168 additions & 52 deletions

File tree

diracx-db/tests/jobs/test_sandbox_metadata.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import secrets
5-
from datetime import datetime
4+
from datetime import datetime, timedelta
65

76
import pytest
87
import sqlalchemy
@@ -11,12 +10,14 @@
1110
from diracx.core.models import SandboxInfo, UserInfo
1211
from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
1312
from diracx.db.sql.sandbox_metadata.schema import SandBoxes, SBEntityMapping
13+
from diracx.testing.time import install_sqlite_time_mock
1414

1515

1616
@pytest.fixture
1717
async def sandbox_metadata_db(tmp_path):
1818
sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:")
1919
async with sandbox_metadata_db.engine_context():
20+
install_sqlite_time_mock(sandbox_metadata_db.engine)
2021
async with sandbox_metadata_db.engine.begin() as conn:
2122
await conn.run_sync(sandbox_metadata_db.metadata.create_all)
2223
yield sandbox_metadata_db
@@ -39,7 +40,7 @@ def test_get_pfn(sandbox_metadata_db: SandboxMetadataDB):
3940
)
4041

4142

42-
async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
43+
async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB, frozen_time):
4344
# TODO: DAL tests should be very simple, such complex tests should be handled in diracx-routers
4445
user_info = UserInfo(
4546
sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
@@ -64,7 +65,7 @@ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
6465
db_contents = await _dump_db(sandbox_metadata_db)
6566
owner_id1, last_access_time1 = db_contents[pfn1]
6667

67-
await asyncio.sleep(1) # The timestamp only has second precision
68+
frozen_time.tick(delta=timedelta(seconds=1))
6869
async with sandbox_metadata_db:
6970
with pytest.raises(SandboxAlreadyInsertedError):
7071
await sandbox_metadata_db.insert_sandbox(owner_id, "SandboxSE", pfn1, 100)
@@ -81,7 +82,7 @@ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
8182
assert not await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE")
8283

8384
# Inserting again should update the last access time
84-
await asyncio.sleep(1) # The timestamp only has second precision
85+
frozen_time.tick(delta=timedelta(seconds=1))
8586
last_access_time3 = (await _dump_db(sandbox_metadata_db))[pfn1][1]
8687
assert last_access_time2 == last_access_time3
8788
async with sandbox_metadata_db:

diracx-db/tests/test_freeze_time.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sqlalchemy
99
from sqlalchemy.ext.asyncio import create_async_engine
1010

11-
from diracx.testing.time import julian_date, mock_sqlite_time
11+
from diracx.testing.time import install_sqlite_time_mock, julian_date
1212

1313
RE_SQLITE_TIME = re.compile(r"(\d{4})-(\d{2})-(\d{2})(?: (\d{2}):(\d{2}):(\d{2}))?")
1414

@@ -40,7 +40,7 @@ async def test_freeze_sqlite_datetime(with_mock):
4040
"""Test the SQLite DATETIME() function with freezegun."""
4141
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True, echo=True)
4242
if with_mock:
43-
sqlalchemy.event.listen(engine.sync_engine, "connect", mock_sqlite_time)
43+
install_sqlite_time_mock(engine)
4444

4545
async with engine.begin() as conn:
4646
# DATETIME()
@@ -83,7 +83,7 @@ async def test_freeze_sqlite_julianday(with_mock):
8383
"""Test the SQLite JULIANDAY() function with freezegun."""
8484
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True, echo=True)
8585
if with_mock:
86-
sqlalchemy.event.listen(engine.sync_engine, "connect", mock_sqlite_time)
86+
install_sqlite_time_mock(engine)
8787

8888
async with engine.begin() as conn:
8989
# JULIANDAY()

diracx-logic/tests/jobs/test_sandboxes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import httpx2
1111
import pytest
1212
import signurlarity.exceptions
13-
import sqlalchemy
1413

1514
from diracx.core.exceptions import SandboxNotFoundError
1615
from diracx.core.models import ChecksumAlgorithm, SandboxFormat, SandboxInfo, UserInfo
@@ -21,7 +20,7 @@
2120
get_sandbox_file,
2221
initiate_sandbox_upload,
2322
)
24-
from diracx.testing.time import mock_sqlite_time
23+
from diracx.testing.time import install_sqlite_time_mock
2524

2625
FAKE_USER_INFO = UserInfo(
2726
sub="fakevo:97ae90d3-36aa-4271-becf-e61173d93fe3",
@@ -36,7 +35,7 @@ async def sandbox_metadata_db() -> AsyncGenerator[SandboxMetadataDB, None]:
3635
"""Create a fake sandbox metadata database."""
3736
db = SandboxMetadataDB(db_url="sqlite+aiosqlite:///:memory:")
3837
async with db.engine_context():
39-
sqlalchemy.event.listen(db.engine.sync_engine, "connect", mock_sqlite_time)
38+
install_sqlite_time_mock(db.engine)
4039

4140
async with db.engine.begin() as conn:
4241
await conn.run_sync(db.metadata.create_all)

diracx-logic/tests/jobs/test_status.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
from datetime import datetime, timezone
55

66
import pytest
7-
import sqlalchemy
87

98
from diracx.core.models import JobMetaData
109
from diracx.db.os.job_parameters import JobParametersDB as RealJobParametersDB
1110
from diracx.db.sql.job.db import JobDB
1211
from diracx.logic.jobs import set_job_parameters_or_attributes
1312
from diracx.testing.mock_osdb import MockOSDBMixin
14-
from diracx.testing.time import mock_sqlite_time
13+
from diracx.testing.time import install_sqlite_time_mock
1514

1615

1716
# Reuse the generic MockOSDBMixin to build a mock JobParameters DB implementation
@@ -36,7 +35,7 @@ async def job_db() -> AsyncGenerator[JobDB, None]:
3635
"""Create a fake sandbox metadata database."""
3736
db = JobDB(db_url="sqlite+aiosqlite:///:memory:")
3837
async with db.engine_context():
39-
sqlalchemy.event.listen(db.engine.sync_engine, "connect", mock_sqlite_time)
38+
install_sqlite_time_mock(db.engine)
4039

4140
async with db.engine.begin() as conn:
4241
await conn.run_sync(db.metadata.create_all)

diracx-routers/tests/auth/test_standard.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import jwt
1313
import pytest
1414
from cryptography.fernet import Fernet
15-
from freezegun import freeze_time
1615
from joserfc.jwk import RSAKey, OKPKey, KeySet
1716
from joserfc.errors import (
1817
UnsupportedKeyOperationError,
@@ -562,56 +561,60 @@ async def test_refresh_token_rotation(test_client, auth_httpx_mock: HTTPXMock):
562561

563562

564563
async def test_refresh_token_expired(
565-
test_client, test_auth_settings: AuthSettings, auth_httpx_mock: HTTPXMock
564+
test_client,
565+
test_auth_settings: AuthSettings,
566+
auth_httpx_mock: HTTPXMock,
567+
frozen_time,
566568
):
567569
"""Test the expiration date of the passed refresh token.
568570
569571
- get a refresh token
570572
- move time forward past its expiration time
571573
- ensure the expired token is rejected.
572574
"""
573-
with freeze_time(datetime.now(tz=timezone.utc)) as frozen_time:
574-
# Get refresh token
575-
refresh_token = _get_tokens(test_client)["refresh_token"]
575+
# Get refresh token
576+
refresh_token = _get_tokens(test_client)["refresh_token"]
576577

577-
frozen_time.tick(
578-
delta=timedelta(minutes=test_auth_settings.refresh_token_expire_minutes + 1)
579-
)
578+
frozen_time.tick(
579+
delta=timedelta(minutes=test_auth_settings.refresh_token_expire_minutes + 1)
580+
)
580581

581-
request_data = {
582-
"grant_type": "refresh_token",
583-
"refresh_token": refresh_token,
584-
"client_id": DIRAC_CLIENT_ID,
585-
}
582+
request_data = {
583+
"grant_type": "refresh_token",
584+
"refresh_token": refresh_token,
585+
"client_id": DIRAC_CLIENT_ID,
586+
}
586587

587-
# Try to get a new access token using the expired refresh token
588-
r = test_client.post("/api/auth/token", data=request_data)
588+
# Try to get a new access token using the expired refresh token
589+
r = test_client.post("/api/auth/token", data=request_data)
589590
data = r.json()
590591
assert r.status_code == 401, data
591592
assert data["detail"] == "expired_token: The token is expired"
592593

593594

594595
async def test_access_token_expired(
595-
test_client, test_auth_settings: AuthSettings, auth_httpx_mock: HTTPXMock
596+
test_client,
597+
test_auth_settings: AuthSettings,
598+
auth_httpx_mock: HTTPXMock,
599+
frozen_time,
596600
):
597601
"""Test the expiration date of the passed access token.
598602
599603
- get an access token
600604
- move time forward past its expiration time
601605
- ensure the expired token is rejected.
602606
"""
603-
with freeze_time(datetime.now(tz=timezone.utc)) as frozen_time:
604-
# Get access token
605-
access_token = _get_tokens(test_client)["access_token"]
607+
# Get access token
608+
access_token = _get_tokens(test_client)["access_token"]
606609

607-
frozen_time.tick(
608-
delta=timedelta(minutes=test_auth_settings.access_token_expire_minutes + 1)
609-
)
610+
frozen_time.tick(
611+
delta=timedelta(minutes=test_auth_settings.access_token_expire_minutes + 1)
612+
)
610613

611-
headers = {"Authorization": f"Bearer {access_token}"}
614+
headers = {"Authorization": f"Bearer {access_token}"}
612615

613-
# Try to get the userinfo using the expired access token
614-
r = test_client.get("/api/auth/userinfo", headers=headers)
616+
# Try to get the userinfo using the expired access token
617+
r = test_client.get("/api/auth/userinfo", headers=headers)
615618
data = r.json()
616619
assert r.status_code == 401, data
617620
assert data["detail"] == "Invalid JWT"

diracx-routers/tests/jobs/test_heartbeat_commands.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from datetime import datetime, timedelta, timezone
4-
from time import sleep
54

65
import pytest
76
from fastapi.testclient import TestClient
@@ -23,7 +22,7 @@
2322
)
2423

2524

26-
def test_heartbeat(normal_user_client: TestClient, valid_job_id: int):
25+
def test_heartbeat(frozen_time, normal_user_client: TestClient, valid_job_id: int):
2726
search_body = {
2827
"search": [{"parameter": "JobID", "operator": "eq", "value": valid_job_id}]
2928
}
@@ -60,7 +59,7 @@ def test_heartbeat(normal_user_client: TestClient, valid_job_id: int):
6059
)
6160
r.raise_for_status()
6261

63-
sleep(1)
62+
frozen_time.tick(delta=timedelta(seconds=1))
6463
# Send another heartbeat and check that a Kill job command was set
6564
payload = {valid_job_id: {"Vsize": 1235}}
6665
r = normal_user_client.patch("/api/jobs/heartbeat", json=payload)
@@ -74,7 +73,7 @@ def test_heartbeat(normal_user_client: TestClient, valid_job_id: int):
7473
assert commands[0]["command"] == "Kill", (
7574
f"Wrong job command received, should be 'Kill' but got {commands[0]=}"
7675
)
77-
sleep(1)
76+
frozen_time.tick(delta=timedelta(seconds=1))
7877

7978
# Send another heartbeat and check the job commands are empty
8079
payload = {valid_job_id: {"Vsize": 1234}}
@@ -87,6 +86,7 @@ def test_heartbeat(normal_user_client: TestClient, valid_job_id: int):
8786

8887

8988
def test_multiple_jobs_receive_independent_kill_commands(
89+
frozen_time,
9090
normal_user_client: TestClient,
9191
valid_job_ids: list[int],
9292
):
@@ -105,7 +105,7 @@ def test_multiple_jobs_receive_independent_kill_commands(
105105
)
106106
r.raise_for_status()
107107

108-
sleep(1)
108+
frozen_time.tick(delta=timedelta(seconds=1))
109109

110110
r = normal_user_client.patch(
111111
"/api/jobs/heartbeat",
@@ -118,7 +118,7 @@ def test_multiple_jobs_receive_independent_kill_commands(
118118
assert {cmd["job_id"] for cmd in commands} == set(valid_job_ids)
119119
assert {cmd["command"] for cmd in commands} == {"Kill"}
120120

121-
sleep(1)
121+
frozen_time.tick(delta=timedelta(seconds=1))
122122

123123
r = normal_user_client.patch(
124124
"/api/jobs/heartbeat",
@@ -130,6 +130,7 @@ def test_multiple_jobs_receive_independent_kill_commands(
130130

131131

132132
def test_non_killed_status_does_not_create_command(
133+
frozen_time,
133134
normal_user_client: TestClient,
134135
valid_job_id: int,
135136
):
@@ -147,7 +148,7 @@ def test_non_killed_status_does_not_create_command(
147148
)
148149
r.raise_for_status()
149150

150-
sleep(1)
151+
frozen_time.tick(delta=timedelta(seconds=1))
151152

152153
r = normal_user_client.patch(
153154
"/api/jobs/heartbeat",
@@ -159,6 +160,7 @@ def test_non_killed_status_does_not_create_command(
159160

160161

161162
def test_deleted_creates_kill_command(
163+
frozen_time,
162164
normal_user_client: TestClient,
163165
valid_job_id: int,
164166
):
@@ -176,7 +178,7 @@ def test_deleted_creates_kill_command(
176178
)
177179
r.raise_for_status()
178180

179-
sleep(1)
181+
frozen_time.tick(delta=timedelta(seconds=1))
180182

181183
r = normal_user_client.patch(
182184
"/api/jobs/heartbeat",
@@ -189,7 +191,7 @@ def test_deleted_creates_kill_command(
189191
assert commands[0]["job_id"] == valid_job_id
190192
assert commands[0]["command"] == "Kill"
191193

192-
sleep(1)
194+
frozen_time.tick(delta=timedelta(seconds=1))
193195

194196
r = normal_user_client.patch(
195197
"/api/jobs/heartbeat",

diracx-testing/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ dependencies = [
2020
"pytest-xdist",
2121
"httpx2",
2222
"joserfc",
23+
"freezegun",
24+
"sqlalchemy",
2325
"uuid-utils",
2426
"pytest-github-actions-annotate-failures",
2527
]

diracx-testing/src/diracx/testing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"demo_urls",
1111
"do_device_flow_with_dex",
1212
"fernet_key",
13+
"frozen_time",
1314
"private_key",
1415
"pytest_addoption",
1516
"session_client_factory",
@@ -23,6 +24,7 @@
2324
]
2425

2526
from .entrypoints import verify_entry_points
27+
from .time import frozen_time
2628
from .utils import (
2729
ClientFactory,
2830
aio_moto,

0 commit comments

Comments
 (0)