Skip to content

Commit 5e8bf1d

Browse files
romanlutzCopilot
andauthored
MAINT: Fix flaky sleeps and MagicMock misuse in unit tests (#1874)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent b3b018f commit 5e8bf1d

10 files changed

Lines changed: 55 additions & 35 deletions

File tree

tests/unit/auth/test_copilot_authenticator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,9 @@ async def test_get_token_serializes_concurrent_requests(self, mock_env_vars, moc
491491
async def mock_fetch():
492492
nonlocal fetch_call_count
493493
fetch_call_count += 1
494-
await asyncio.sleep(0.01) # minimal delay to test concurrency
494+
# Yield once so concurrent callers contend for the lock; the lock
495+
# guarantees serialization regardless of real-time delays.
496+
await asyncio.sleep(0)
495497
return f"token.{fetch_call_count}"
496498

497499
def mock_load_side_effect():

tests/unit/cli/test_pyrit_shell.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ def test_run_async_raises_timeout_error(self):
191191
try:
192192

193193
async def hangs():
194-
await asyncio.sleep(10)
194+
# Block on an Event that's never set so the coroutine truly
195+
# cannot complete on its own; the timeout under test must cut it off.
196+
await asyncio.Event().wait()
195197

196198
with pytest.raises(TimeoutError, match="did not complete"):
197199
s._run_async(hangs(), timeout=0.05)

tests/unit/exceptions/test_retry_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,12 @@ def test_contextvar_isolation_across_tasks(self) -> None:
9595
async def task_a() -> None:
9696
c = RetryCollector()
9797
set_retry_collector(c)
98-
await asyncio.sleep(0.01)
98+
await asyncio.sleep(0)
9999
results["a_has_collector"] = get_retry_collector() is c
100100
clear_retry_collector()
101101

102102
async def task_b() -> None:
103-
await asyncio.sleep(0.005)
103+
await asyncio.sleep(0)
104104
results["b_sees_none"] = get_retry_collector() is None
105105

106106
async def run() -> None:

tests/unit/executor/attack/core/test_attack_executor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ async def mock_execute(*, context):
259259
nonlocal concurrent_count, max_concurrent
260260
concurrent_count += 1
261261
max_concurrent = max(max_concurrent, concurrent_count)
262-
await asyncio.sleep(0.05)
262+
# Yield so other tasks bounded by the semaphore can also enter.
263+
await asyncio.sleep(0)
263264
concurrent_count -= 1
264265
return create_attack_result(context.params.objective)
265266

@@ -282,7 +283,8 @@ async def test_single_concurrency_serializes_execution(self):
282283
async def mock_execute(*, context):
283284
objective = context.params.objective
284285
execution_order.append(f"start_{objective}")
285-
await asyncio.sleep(0.01)
286+
# Yield once so another task could interleave if max_concurrency > 1.
287+
await asyncio.sleep(0)
286288
execution_order.append(f"end_{objective}")
287289
return create_attack_result(objective)
288290

@@ -455,9 +457,9 @@ async def test_attribution_parallel_safe_with_high_concurrency(self):
455457
async def out_of_order(context):
456458
attr = context._attribution
457459
assert attr is not None
458-
# Reverse-delay tasks so completion order is inverse of input order.
459-
i = int(context.params.objective.split("-")[1])
460-
await asyncio.sleep(0.005 * (10 - i))
460+
# Yield so all tasks run concurrently under the high-concurrency executor;
461+
# the assertion verifies attribution is per-task regardless of order.
462+
await asyncio.sleep(0)
461463
seen[context.params.objective] = attr
462464
return create_attack_result(context.params.objective)
463465

tests/unit/models/test_message.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,23 @@ def test_duplicate_message_preserves_original_prompt_id(self, message: Message)
152152

153153
def test_duplicate_message_creates_new_timestamp(self, message: Message) -> None:
154154
"""Test that duplicate_message creates new timestamps."""
155-
import time
155+
from datetime import timedelta, timezone
156+
from unittest.mock import patch
156157

157158
original_timestamps = [piece.timestamp for piece in message.message_pieces]
159+
fake_now = max(original_timestamps) + timedelta(seconds=1)
158160

159-
time.sleep(0.01) # Small delay to ensure different timestamp
160-
duplicated = message.duplicate_message()
161+
with patch("pyrit.models.messages.message.datetime") as mock_datetime:
162+
mock_datetime.now.return_value = fake_now
163+
duplicated = message.duplicate_message()
161164

162165
for dup_piece in duplicated.message_pieces:
163-
# Verify timestamp is newer than all original timestamps
166+
# Every duplicated piece shares the new timestamp produced by duplicate_message.
167+
assert dup_piece.timestamp == fake_now
168+
# And it is strictly newer than every original timestamp.
164169
for orig_ts in original_timestamps:
165-
assert dup_piece.timestamp >= orig_ts
170+
assert dup_piece.timestamp > orig_ts
171+
mock_datetime.now.assert_called_once_with(tz=timezone.utc)
166172

167173
def test_duplicate_message_is_deep_copy(self, message: Message) -> None:
168174
"""Test that duplicate_message creates a deep copy (modifications don't affect original)."""

tests/unit/models/test_message_piece.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
import os
55
import tempfile
6-
import time
76
import uuid
87
import warnings
98
from collections.abc import MutableSequence
109
from datetime import datetime, timedelta, timezone
10+
from unittest.mock import patch
1111

1212
import pytest
1313
from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations
@@ -41,14 +41,16 @@ def test_id_set():
4141

4242

4343
def test_datetime_set():
44-
now = datetime.now(tz=timezone.utc)
45-
time.sleep(0.1)
46-
entry = MessagePiece(
47-
role="user",
48-
original_value="Hello",
49-
converted_value="Hello",
50-
)
51-
assert entry.timestamp > now
44+
fake_now = datetime(2099, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
45+
with patch("pyrit.models.messages.message_piece.datetime") as mock_datetime:
46+
mock_datetime.now.return_value = fake_now
47+
entry = MessagePiece(
48+
role="user",
49+
original_value="Hello",
50+
converted_value="Hello",
51+
)
52+
assert entry.timestamp == fake_now
53+
mock_datetime.now.assert_called_once_with(tz=timezone.utc)
5254

5355

5456
def test_converters_serialize():

tests/unit/output/test_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async def test_output_scenario_async_forwards_sort_groups_by_success_rate(mock_c
117117

118118
async def test_output_scenario_async_unsupported_format():
119119
with pytest.raises(ValueError, match="Unsupported format"):
120-
await output_scenario_async(MagicMock(), format="markdown")
120+
await output_scenario_async(AsyncMock(), format="markdown")
121121

122122

123123
# --- output_scorer_async tests ---
@@ -150,7 +150,7 @@ async def test_output_scorer_async_with_harm_category(mock_cls):
150150

151151
async def test_output_scorer_async_unsupported_format():
152152
with pytest.raises(ValueError, match="Unsupported format"):
153-
await output_scorer_async(scorer_identifier=MagicMock(), format="markdown")
153+
await output_scorer_async(scorer_identifier=AsyncMock(), format="markdown")
154154

155155

156156
# --- output_conversation_async tests ---
@@ -185,7 +185,7 @@ async def test_output_conversation_async_with_scores(mock_cls):
185185

186186
async def test_output_conversation_async_unsupported_format():
187187
with pytest.raises(ValueError, match="Unsupported format"):
188-
await output_conversation_async([MagicMock()], format="markdown")
188+
await output_conversation_async([AsyncMock()], format="markdown")
189189

190190

191191
# --- output_score_async tests ---
@@ -208,4 +208,4 @@ async def test_output_score_async_pretty_default(mock_cls):
208208

209209
async def test_output_score_async_unsupported_format():
210210
with pytest.raises(ValueError, match="Unsupported format"):
211-
await output_score_async([MagicMock()], format="markdown")
211+
await output_score_async([AsyncMock()], format="markdown")

tests/unit/prompt_target/test_discover_target_capabilities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,9 @@ async def test_timeout_returns_false_after_retries(self) -> None:
755755
target = MockPromptTarget()
756756

757757
async def _hang(**_kwargs: object) -> list[Message]:
758-
await asyncio.sleep(10)
758+
# Block on an Event that's never set so the probe truly cannot
759+
# complete on its own; per_probe_timeout_s must cut it off.
760+
await asyncio.Event().wait()
759761
return _ok_response()
760762

761763
target._send_prompt_to_target_async = AsyncMock(side_effect=_hang) # type: ignore[method-assign]

tests/unit/scenario/core/test_scenario.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,8 @@ async def run_async(*, executor, **kwargs):
12181218
async with lock:
12191219
in_flight[0] += 1
12201220
peak[0] = max(peak[0], in_flight[0])
1221-
await asyncio.sleep(0.02)
1221+
# Yield so other tasks contending for the semaphore can enter.
1222+
await asyncio.sleep(0)
12221223
async with lock:
12231224
in_flight[0] -= 1
12241225
_stamp_scenario_linkage(
@@ -1313,7 +1314,9 @@ async def test_failure_lets_inflight_siblings_finish_but_skips_queued(
13131314

13141315
async def ok_run(idx, name):
13151316
started_calls.append(name)
1316-
await asyncio.sleep(0.05)
1317+
# Wait for the bad task to fail before this one completes, so the
1318+
# failure is observed mid-flight (no wall-clock dependency).
1319+
await bad_started.wait()
13171320
completed_calls.append(name)
13181321
_stamp_scenario_linkage(
13191322
attack_results=[sample_attack_results[idx]],
@@ -1367,7 +1370,8 @@ async def test_multiple_inflight_failures_are_grouped_into_exception_group(
13671370
# observed (no queueing) and every failure should propagate.
13681371
def make_fail_run(name: str):
13691372
async def _run(*args, **kwargs):
1370-
await asyncio.sleep(0.01)
1373+
# Yield so all three workers are in-flight before any fails.
1374+
await asyncio.sleep(0)
13711375
raise RuntimeError(f"{name} boom")
13721376

13731377
return AsyncMock(side_effect=_run)

tests/unit/score/test_scorer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -998,15 +998,15 @@ async def test_score_response_async_concurrent_execution():
998998

999999
async def mock_aux_score_async(message: Message, **kwargs) -> list[Score]:
10001000
call_order.append("aux_start")
1001-
# Simulate some async work
1002-
await asyncio.sleep(0.01)
1001+
# Yield so the other scorer can interleave (proves concurrent execution).
1002+
await asyncio.sleep(0)
10031003
call_order.append("aux_end")
10041004
return [MagicMock(spec=Score)]
10051005

10061006
async def mock_obj_score_async(message: Message, **kwargs) -> list[Score]:
10071007
call_order.append("obj_start")
1008-
# Simulate some async work
1009-
await asyncio.sleep(0.01)
1008+
# Yield so the other scorer can interleave (proves concurrent execution).
1009+
await asyncio.sleep(0)
10101010
call_order.append("obj_end")
10111011
score = MagicMock(spec=Score)
10121012
score.get_value.return_value = True

0 commit comments

Comments
 (0)