|
27 | 27 | from tests.unit.utils.auth_helpers import mock_authorization_resolvers |
28 | 28 |
|
29 | 29 | MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") |
| 30 | + |
| 31 | + |
| 32 | +@pytest.fixture(autouse=True) |
| 33 | +def _reset_feedback_config(): |
| 34 | + """Save and restore feedback configuration so tests don't leak state.""" |
| 35 | + original_enabled = configuration.user_data_collection_configuration.feedback_enabled |
| 36 | + original_storage = configuration.user_data_collection_configuration.feedback_storage |
| 37 | + yield |
| 38 | + configuration.user_data_collection_configuration.feedback_enabled = original_enabled |
| 39 | + configuration.user_data_collection_configuration.feedback_storage = original_storage |
30 | 40 | VALID_BASE = { |
31 | 41 | "conversation_id": "12345678-abcd-0000-0123-456789abcdef", |
32 | 42 | "user_question": "What is Kubernetes?", |
@@ -383,24 +393,26 @@ def test_update_feedback_status_concurrent(mocker: MockerFixture) -> None: |
383 | 393 | configuration.user_data_collection_configuration.feedback_enabled = True |
384 | 394 |
|
385 | 395 | auth: AuthTuple = ("test_user_id", "test_user", True, "test_token") |
386 | | - results: list[Any] = [None, None, None] |
387 | | - errors: list[Exception | None] = [None, None, None] |
| 396 | + thread_args = [(0, False), (1, True), (2, False)] |
| 397 | + results: list[Any] = [None] * len(thread_args) |
| 398 | + errors: list[Exception | None] = [None] * len(thread_args) |
| 399 | + barrier = threading.Barrier(len(thread_args) + 1) |
388 | 400 |
|
389 | 401 | def worker(index: int, desired_status: bool) -> None: |
390 | 402 | """Thread worker that calls update_feedback_status.""" |
391 | 403 | req = FeedbackStatusUpdateRequest(status=desired_status) |
392 | 404 | try: |
| 405 | + barrier.wait() |
393 | 406 | results[index] = asyncio.run(update_feedback_status(req, auth=auth)) |
394 | 407 | except Exception as exc: # pylint: disable=broad-exception-caught |
395 | 408 | errors[index] = exc |
396 | 409 |
|
397 | 410 | threads = [ |
398 | | - threading.Thread(target=worker, args=(0, False)), |
399 | | - threading.Thread(target=worker, args=(1, True)), |
400 | | - threading.Thread(target=worker, args=(2, False)), |
| 411 | + threading.Thread(target=worker, args=args) for args in thread_args |
401 | 412 | ] |
402 | 413 | for t in threads: |
403 | 414 | t.start() |
| 415 | + barrier.wait() |
404 | 416 | for t in threads: |
405 | 417 | t.join() |
406 | 418 |
|
|
0 commit comments