Skip to content

Commit 50b2141

Browse files
etairl and bogdankostic authored
fix: enforce max_workers in LLMMetadataExtractor.run_async (#11248)
Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent 78f954b commit 50b2141

3 files changed

Lines changed: 52 additions & 3 deletions

File tree

haystack/components/extractors/llm_metadata_extractor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,14 @@ async def run_async(self, documents: list[Document], page_range: list[str | int]
439439
# Create ChatMessage prompts for each document
440440
all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)
441441

442-
# Run the LLM on each prompt
442+
# Run the LLM on each prompt, bounding concurrency per task so max_workers is enforced.
443443
sem = Semaphore(max(1, self.max_workers))
444-
async with sem:
445-
results = await gather(*[self._run_async(prompt) for prompt in all_prompts])
444+
445+
async def _bounded_run(prompt: ChatMessage | None) -> dict[str, Any]:
446+
async with sem:
447+
return await self._run_async(prompt)
448+
449+
results = await gather(*[_bounded_run(prompt) for prompt in all_prompts])
446450

447451
successful_documents, failed_documents = self._process_results(documents, results)
448452

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a bug in ``LLMMetadataExtractor.run_async`` where the ``asyncio.Semaphore``
5+
intended to bound concurrent LLM calls to ``max_workers`` was acquired once
6+
around the outer ``gather(...)`` call instead of inside each task. As a result,
7+
``max_workers`` had no effect in ``run_async`` and all LLM requests for a batch
8+
were issued simultaneously. The semaphore is now acquired per task, so
9+
``max_workers`` correctly caps in-flight requests.

test/components/extractors/test_llm_metadata_extractor.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import asyncio
56
import os
67
from unittest.mock import Mock
78

@@ -345,6 +346,41 @@ async def test_run_with_document_content_none_async(self, monkeypatch: pytest.Mo
345346
# Ensure no attempt was made to call the LLM
346347
mock_chat_generator.run_async.assert_not_called()
347348

349+
@pytest.mark.asyncio
350+
async def test_run_async_respects_max_workers(self, monkeypatch: pytest.MonkeyPatch) -> None:
351+
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
352+
353+
max_workers = 2
354+
in_flight = 0
355+
peak_in_flight = 0
356+
357+
mock_chat_generator = Mock(spec=OpenAIChatGenerator)
358+
359+
async def fake_run_async(messages, **kwargs):
360+
nonlocal in_flight, peak_in_flight
361+
in_flight += 1
362+
peak_in_flight = max(peak_in_flight, in_flight)
363+
try:
364+
await asyncio.sleep(0.01)
365+
return {"replies": [ChatMessage.from_assistant('{"entities": []}')]}
366+
finally:
367+
in_flight -= 1
368+
369+
mock_chat_generator.run_async = fake_run_async
370+
371+
extractor = LLMMetadataExtractor(
372+
prompt="prompt {{document.content}}",
373+
chat_generator=mock_chat_generator,
374+
expected_keys=["entities"],
375+
max_workers=max_workers,
376+
)
377+
378+
docs = [Document(content=f"doc {i}") for i in range(10)]
379+
result = await extractor.run_async(documents=docs)
380+
381+
assert len(result["documents"]) == 10
382+
assert peak_in_flight <= max_workers
383+
348384
@pytest.mark.integration
349385
@pytest.mark.skipif(
350386
not os.environ.get("OPENAI_API_KEY", None),

0 commit comments

Comments
 (0)