Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 107 additions & 21 deletions haystack/components/extractors/llm_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import copy
import json
from asyncio import Semaphore, gather
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import replace
from typing import Any
Expand Down Expand Up @@ -173,6 +175,8 @@ def __init__(
:param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
validation of the JSON output.
:param max_workers: The maximum number of workers to use in the thread pool executor.
This parameter is used to limit the maximum number of requests that should be allowed to run concurrently
when using the `run_async` method.
"""
self.prompt = prompt
ast = SandboxedEnvironment().parse(prompt)
Expand Down Expand Up @@ -293,6 +297,52 @@ def _run_on_thread(self, prompt: ChatMessage | None) -> dict[str, Any]:
result = {"error": "LLM failed with exception: " + str(e)}
return result

async def _run_async(self, prompt: ChatMessage | None) -> dict[str, Any]:
# If prompt is None, return an error dictionary
if prompt is None:
return {"error": "Document has no content, skipping LLM call."}

try:
result = await self._chat_generator.run_async(messages=[prompt]) # type: ignore[attr-defined]
except Exception as e:
if self.raise_on_failure:
raise e
logger.exception(
"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
Comment thread
sjrl marked this conversation as resolved.
class_name=self._chat_generator.__class__.__name__,
error=e,
)
result = {"error": "LLM failed with exception: " + str(e)}
return result

def _process_results(
self, documents: list[Document], results: Iterable[dict[str, Any]]
) -> tuple[list[Document], list[Document]]:
successful_documents = []
failed_documents = []
for document, result in zip(documents, results, strict=True):
new_meta = {**document.meta}
if "error" in result:
new_meta["metadata_extraction_error"] = result["error"]
new_meta["metadata_extraction_response"] = None
failed_documents.append(replace(document, meta=new_meta))
continue

parsed_metadata = self._extract_metadata(result["replies"][0].text)
if "error" in parsed_metadata:
new_meta["metadata_extraction_error"] = parsed_metadata["error"]
new_meta["metadata_extraction_response"] = result["replies"][0]
failed_documents.append(replace(document, meta=new_meta))
continue

for key in parsed_metadata:
new_meta[key] = parsed_metadata[key]
# Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
new_meta.pop("metadata_extraction_error", None)
new_meta.pop("metadata_extraction_response", None)
successful_documents.append(replace(document, meta=new_meta))
return successful_documents, failed_documents

@component.output_types(documents=list[Document], failed_documents=list[Document])
def run(self, documents: list[Document], page_range: list[str | int] | None = None) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -336,28 +386,64 @@ def run(self, documents: list[Document], page_range: list[str | int] | None = No
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
results = executor.map(self._run_on_thread, all_prompts)

successful_documents = []
failed_documents = []
for document, result in zip(documents, results, strict=True):
new_meta = {**document.meta}
if "error" in result:
new_meta["metadata_extraction_error"] = result["error"]
new_meta["metadata_extraction_response"] = None
failed_documents.append(replace(document, meta=new_meta))
continue
successful_documents, failed_documents = self._process_results(documents, results)

parsed_metadata = self._extract_metadata(result["replies"][0].text)
if "error" in parsed_metadata:
new_meta["metadata_extraction_error"] = parsed_metadata["error"]
new_meta["metadata_extraction_response"] = result["replies"][0]
failed_documents.append(replace(document, meta=new_meta))
continue
return {"documents": successful_documents, "failed_documents": failed_documents}

for key in parsed_metadata:
new_meta[key] = parsed_metadata[key]
# Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
new_meta.pop("metadata_extraction_error", None)
new_meta.pop("metadata_extraction_response", None)
successful_documents.append(replace(document, meta=new_meta))
@component.output_types(documents=list[Document], failed_documents=list[Document])
async def run_async(self, documents: list[Document], page_range: list[str | int] | None = None) -> dict[str, Any]:
    """
    Asynchronously extract metadata from documents using a Large Language Model.

    If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component
    will split the documents into pages and extract metadata from the specified range of pages. The metadata will be
    extracted from the entire document if `page_range` is not provided.

    The original documents will be returned updated with the extracted metadata.

    This is the asynchronous version of the `run` method. It has the same parameters
    and return values but can be used with `await` in an async code.

    :param documents: List of documents to extract metadata from.
    :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
                       metadata from the first and third pages of each document. It also accepts printable range
                       strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
                       11, 12.
                       If None, metadata will be extracted from the entire document for each document in the
                       documents list.
    :returns:
        A dictionary with the keys:
        - "documents": A list of documents that were successfully updated with the extracted metadata.
        - "failed_documents": A list of documents that failed to extract metadata. These documents will have
          "metadata_extraction_error" and "metadata_extraction_response" in their metadata. These documents can be
          re-run with the extractor to extract metadata.
    """
    if not hasattr(self._chat_generator, "run_async"):
        logger.warning(
            "{chat_generator_type} does not implement method 'run_async'. Falling back to 'run'.",
            chat_generator_type=type(self._chat_generator).__name__,
        )
        return self.run(documents, page_range)

    if len(documents) == 0:
        logger.warning("No documents provided. Skipping metadata extraction.")
        return {"documents": [], "failed_documents": []}

    if not self._is_warmed_up:
        self.warm_up()

    expanded_range = self.expanded_range
    if page_range:
        expanded_range = expand_page_range(page_range)

    # Create ChatMessage prompts for each document
    all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)

    # Bound concurrency with a per-task semaphore. Acquiring the semaphore once
    # around the whole gather() would not limit anything -- every coroutine would
    # still run concurrently. Each task must hold the semaphore for the duration
    # of its own LLM call so at most `max_workers` requests are in flight at once.
    sem = Semaphore(max(1, self.max_workers))

    async def _bounded(prompt: ChatMessage | None) -> dict[str, Any]:
        async with sem:
            return await self._run_async(prompt)

    # Run the LLM on each prompt, preserving document order in the results.
    results = await gather(*(_bounded(prompt) for prompt in all_prompts))

    successful_documents, failed_documents = self._process_results(documents, results)

    return {"documents": successful_documents, "failed_documents": failed_documents}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Add ``run_async`` method to ``LLMMetadataExtractor``. ``ChatGenerator`` requests now run concurrently using the existing ``max_workers`` init parameter.
Loading
Loading