diff --git a/haystack/components/extractors/llm_metadata_extractor.py b/haystack/components/extractors/llm_metadata_extractor.py index 3eb0231d07..df8c0a71fa 100644 --- a/haystack/components/extractors/llm_metadata_extractor.py +++ b/haystack/components/extractors/llm_metadata_extractor.py @@ -439,10 +439,14 @@ async def run_async(self, documents: list[Document], page_range: list[str | int] # Create ChatMessage prompts for each document all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range) - # Run the LLM on each prompt + # Run the LLM on each prompt, bounding concurrency per task so max_workers is enforced. sem = Semaphore(max(1, self.max_workers)) - async with sem: - results = await gather(*[self._run_async(prompt) for prompt in all_prompts]) + + async def _bounded_run(prompt: ChatMessage | None) -> dict[str, Any]: + async with sem: + return await self._run_async(prompt) + + results = await gather(*[_bounded_run(prompt) for prompt in all_prompts]) successful_documents, failed_documents = self._process_results(documents, results) diff --git a/releasenotes/notes/fix-llm-metadata-extractor-async-semaphore-ba6053152b0ecaac.yaml b/releasenotes/notes/fix-llm-metadata-extractor-async-semaphore-ba6053152b0ecaac.yaml new file mode 100644 index 0000000000..30dfcf1ce1 --- /dev/null +++ b/releasenotes/notes/fix-llm-metadata-extractor-async-semaphore-ba6053152b0ecaac.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + Fixed a bug in ``LLMMetadataExtractor.run_async`` where the ``asyncio.Semaphore`` + intended to bound concurrent LLM calls to ``max_workers`` was acquired once + around the outer ``gather(...)`` call instead of inside each task. As a result, + ``max_workers`` had no effect in ``run_async`` and all LLM requests for a batch + were issued simultaneously. The semaphore is now acquired per task, so + ``max_workers`` correctly caps in-flight requests. diff --git a/test/components/extractors/test_llm_metadata_extractor.py b/test/components/extractors/test_llm_metadata_extractor.py index 0b3b5659fa..391f5e9a1f 100644 --- a/test/components/extractors/test_llm_metadata_extractor.py +++ b/test/components/extractors/test_llm_metadata_extractor.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import asyncio import os from unittest.mock import Mock @@ -345,6 +346,41 @@ async def test_run_with_document_content_none_async(self, monkeypatch: pytest.Mo # Ensure no attempt was made to call the LLM mock_chat_generator.run_async.assert_not_called() + @pytest.mark.asyncio + async def test_run_async_respects_max_workers(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + + max_workers = 2 + in_flight = 0 + peak_in_flight = 0 + + mock_chat_generator = Mock(spec=OpenAIChatGenerator) + + async def fake_run_async(messages, **kwargs): + nonlocal in_flight, peak_in_flight + in_flight += 1 + peak_in_flight = max(peak_in_flight, in_flight) + try: + await asyncio.sleep(0.01) + return {"replies": [ChatMessage.from_assistant('{"entities": []}')]} + finally: + in_flight -= 1 + + mock_chat_generator.run_async = fake_run_async + + extractor = LLMMetadataExtractor( + prompt="prompt {{document.content}}", + chat_generator=mock_chat_generator, + expected_keys=["entities"], + max_workers=max_workers, + ) + + docs = [Document(content=f"doc {i}") for i in range(10)] + result = await extractor.run_async(documents=docs) + + assert len(result["documents"]) == 10 + assert peak_in_flight <= max_workers + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None),