|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
| 5 | +import asyncio |
5 | 6 | import os |
6 | 7 | from unittest.mock import Mock |
7 | 8 |
|
@@ -345,6 +346,41 @@ async def test_run_with_document_content_none_async(self, monkeypatch: pytest.Mo |
345 | 346 | # Ensure no attempt was made to call the LLM |
346 | 347 | mock_chat_generator.run_async.assert_not_called() |
347 | 348 |
|
| 349 | + @pytest.mark.asyncio |
| 350 | + async def test_run_async_respects_max_workers(self, monkeypatch: pytest.MonkeyPatch) -> None: |
| 351 | + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") |
| 352 | + |
| 353 | + max_workers = 2 |
| 354 | + in_flight = 0 |
| 355 | + peak_in_flight = 0 |
| 356 | + |
| 357 | + mock_chat_generator = Mock(spec=OpenAIChatGenerator) |
| 358 | + |
| 359 | + async def fake_run_async(messages, **kwargs): |
| 360 | + nonlocal in_flight, peak_in_flight |
| 361 | + in_flight += 1 |
| 362 | + peak_in_flight = max(peak_in_flight, in_flight) |
| 363 | + try: |
| 364 | + await asyncio.sleep(0.01) |
| 365 | + return {"replies": [ChatMessage.from_assistant('{"entities": []}')]} |
| 366 | + finally: |
| 367 | + in_flight -= 1 |
| 368 | + |
| 369 | + mock_chat_generator.run_async = fake_run_async |
| 370 | + |
| 371 | + extractor = LLMMetadataExtractor( |
| 372 | + prompt="prompt {{document.content}}", |
| 373 | + chat_generator=mock_chat_generator, |
| 374 | + expected_keys=["entities"], |
| 375 | + max_workers=max_workers, |
| 376 | + ) |
| 377 | + |
| 378 | + docs = [Document(content=f"doc {i}") for i in range(10)] |
| 379 | + result = await extractor.run_async(documents=docs) |
| 380 | + |
| 381 | + assert len(result["documents"]) == 10 |
| 382 | + assert peak_in_flight <= max_workers |
| 383 | + |
348 | 384 | @pytest.mark.integration |
349 | 385 | @pytest.mark.skipif( |
350 | 386 | not os.environ.get("OPENAI_API_KEY", None), |
|
0 commit comments