Skip to content

Commit 04a0778

Browse files
authored
feat: Add run_async method for LLMMetadataExtractor (#10984)
* feat: Add run_async method for LLMMetadataExtractor * Address review comments * fix: missing await
1 parent 48f84db commit 04a0778

3 files changed

Lines changed: 260 additions & 61 deletions

File tree

haystack/components/extractors/llm_metadata_extractor.py

Lines changed: 107 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import copy
66
import json
7+
from asyncio import Semaphore, gather
8+
from collections.abc import Iterable
79
from concurrent.futures import ThreadPoolExecutor
810
from dataclasses import replace
911
from typing import Any
@@ -173,6 +175,8 @@ def __init__(
173175
:param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
174176
validation of the JSON output.
175177
:param max_workers: The maximum number of workers to use in the thread pool executor.
178+
This parameter is used to limit the maximum number of requests that should be allowed to run concurrently
179+
when using the `run_async` method.
176180
"""
177181
self.prompt = prompt
178182
ast = SandboxedEnvironment().parse(prompt)
@@ -293,6 +297,52 @@ def _run_on_thread(self, prompt: ChatMessage | None) -> dict[str, Any]:
293297
result = {"error": "LLM failed with exception: " + str(e)}
294298
return result
295299

300+
async def _run_async(self, prompt: ChatMessage | None) -> dict[str, Any]:
    """
    Invoke the chat generator asynchronously for a single prompt.

    :param prompt: The ChatMessage to send to the LLM, or None when the source
        document had no content (in which case no call is made).
    :returns: The generator's result dict on success, or a dict with an
        "error" key when the prompt is missing or the LLM call fails and
        `raise_on_failure` is False.
    :raises Exception: Re-raises the generator's exception when
        `raise_on_failure` is True.
    """
    # A missing prompt means the document had no content; report it as an
    # error result instead of calling the LLM.
    if prompt is None:
        return {"error": "Document has no content, skipping LLM call."}

    try:
        return await self._chat_generator.run_async(messages=[prompt])  # type: ignore[attr-defined]
    except Exception as exc:
        if self.raise_on_failure:
            raise exc
        # Best-effort mode: log the failure and return an error marker so the
        # document can be routed to `failed_documents`.
        logger.exception(
            "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
            class_name=self._chat_generator.__class__.__name__,
            error=exc,
        )
        return {"error": "LLM failed with exception: " + str(exc)}
317+
318+
def _process_results(
    self, documents: list[Document], results: Iterable[dict[str, Any]]
) -> tuple[list[Document], list[Document]]:
    """
    Pair each document with its LLM result and split into successes and failures.

    :param documents: The documents metadata was extracted for, in the same
        order as `results`.
    :param results: One result dict per document; either a generator reply
        (with a "replies" key) or an error marker (with an "error" key).
    :returns: A tuple `(successful_documents, failed_documents)`. Failed
        documents carry "metadata_extraction_error" and
        "metadata_extraction_response" in their metadata; successful ones have
        those keys removed and the parsed metadata merged in.
    """
    succeeded: list[Document] = []
    failed: list[Document] = []
    # strict=True guards against a results/documents length mismatch.
    for doc, outcome in zip(documents, results, strict=True):
        meta = dict(doc.meta)

        # The LLM call itself failed (or the document had no content).
        if "error" in outcome:
            meta["metadata_extraction_error"] = outcome["error"]
            meta["metadata_extraction_response"] = None
            failed.append(replace(doc, meta=meta))
            continue

        reply = outcome["replies"][0]
        parsed = self._extract_metadata(reply.text)
        # The LLM answered, but its output could not be parsed as metadata.
        if "error" in parsed:
            meta["metadata_extraction_error"] = parsed["error"]
            meta["metadata_extraction_response"] = reply
            failed.append(replace(doc, meta=meta))
            continue

        meta.update(parsed)
        # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
        meta.pop("metadata_extraction_error", None)
        meta.pop("metadata_extraction_response", None)
        succeeded.append(replace(doc, meta=meta))
    return succeeded, failed
345+
296346
@component.output_types(documents=list[Document], failed_documents=list[Document])
297347
def run(self, documents: list[Document], page_range: list[str | int] | None = None) -> dict[str, Any]:
298348
"""
@@ -336,28 +386,64 @@ def run(self, documents: list[Document], page_range: list[str | int] | None = No
336386
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
337387
results = executor.map(self._run_on_thread, all_prompts)
338388

339-
successful_documents = []
340-
failed_documents = []
341-
for document, result in zip(documents, results, strict=True):
342-
new_meta = {**document.meta}
343-
if "error" in result:
344-
new_meta["metadata_extraction_error"] = result["error"]
345-
new_meta["metadata_extraction_response"] = None
346-
failed_documents.append(replace(document, meta=new_meta))
347-
continue
389+
successful_documents, failed_documents = self._process_results(documents, results)
348390

349-
parsed_metadata = self._extract_metadata(result["replies"][0].text)
350-
if "error" in parsed_metadata:
351-
new_meta["metadata_extraction_error"] = parsed_metadata["error"]
352-
new_meta["metadata_extraction_response"] = result["replies"][0]
353-
failed_documents.append(replace(document, meta=new_meta))
354-
continue
391+
return {"documents": successful_documents, "failed_documents": failed_documents}
355392

356-
for key in parsed_metadata:
357-
new_meta[key] = parsed_metadata[key]
358-
# Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
359-
new_meta.pop("metadata_extraction_error", None)
360-
new_meta.pop("metadata_extraction_response", None)
361-
successful_documents.append(replace(document, meta=new_meta))
393+
@component.output_types(documents=list[Document], failed_documents=list[Document])
async def run_async(self, documents: list[Document], page_range: list[str | int] | None = None) -> dict[str, Any]:
    """
    Asynchronously extract metadata from documents using a Large Language Model.

    If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component
    will split the documents into pages and extract metadata from the specified range of pages. The metadata will be
    extracted from the entire document if `page_range` is not provided.

    The original documents will be returned updated with the extracted metadata.

    This is the asynchronous version of the `run` method. It has the same parameters
    and return values but can be used with `await` in async code. At most `max_workers`
    LLM requests are in flight concurrently.

    :param documents: List of documents to extract metadata from.
    :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
                       metadata from the first and third pages of each document. It also accepts printable range
                       strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
                       11, 12.
                       If None, metadata will be extracted from the entire document for each document in the
                       documents list.
    :returns:
        A dictionary with the keys:
        - "documents": A list of documents that were successfully updated with the extracted metadata.
        - "failed_documents": A list of documents that failed to extract metadata. These documents will have
          "metadata_extraction_error" and "metadata_extraction_response" in their metadata. These documents can be
          re-run with the extractor to extract metadata.
    """
    # Generators without async support fall back to the synchronous path.
    if not hasattr(self._chat_generator, "run_async"):
        logger.warning(
            "{chat_generator_type} does not implement method 'run_async'. Falling back to 'run'.",
            chat_generator_type=type(self._chat_generator).__name__,
        )
        return self.run(documents, page_range)

    if len(documents) == 0:
        logger.warning("No documents provided. Skipping metadata extraction.")
        return {"documents": [], "failed_documents": []}

    if not self._is_warmed_up:
        self.warm_up()

    expanded_range = self.expanded_range
    if page_range:
        expanded_range = expand_page_range(page_range)

    # Create ChatMessage prompts for each document
    all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)

    # Bug fix: the semaphore must be acquired per request, inside each
    # coroutine. Acquiring it once around the whole `gather` (as before) let
    # every request run at the same time, so `max_workers` had no effect.
    semaphore = Semaphore(max(1, self.max_workers))

    async def _bounded_run(prompt: ChatMessage | None) -> dict[str, Any]:
        # At most `max_workers` LLM calls execute concurrently.
        async with semaphore:
            return await self._run_async(prompt)

    results = await gather(*(_bounded_run(prompt) for prompt in all_prompts))

    successful_documents, failed_documents = self._process_results(documents, results)

    return {"documents": successful_documents, "failed_documents": failed_documents}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
enhancements:
3+
- |
4+
Add ``run_async`` method to ``LLMMetadataExtractor``. ``ChatGenerator`` requests now run concurrently using the existing ``max_workers`` init parameter.

0 commit comments

Comments
 (0)