|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +from concurrent.futures import ThreadPoolExecutor |
| 6 | +from dataclasses import replace |
| 7 | +from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
| 8 | + |
| 9 | +from haystack import Document, component, default_from_dict, default_to_dict, logging |
| 10 | +from haystack.components.generators.chat.types import ChatGenerator |
| 11 | +from haystack.core.serialization import component_to_dict |
| 12 | +from haystack.dataclasses import TextContent |
| 13 | +from haystack.utils import deserialize_chatgenerator_inplace |
| 14 | +from jinja2 import meta |
| 15 | +from jinja2.sandbox import SandboxedEnvironment |
| 16 | + |
| 17 | +from haystack_experimental.components.converters.image.document_to_image import DocumentToImageContent |
| 18 | +from haystack_experimental.dataclasses.chat_message import ChatMessage |
| 19 | + |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Default value of the `prompt` init parameter of LLMDocumentContentExtractor.
# It is plain instruction text and contains no Jinja variables (the extractor rejects prompts that do).
# The image itself is not part of this string; it is attached as a separate content part of the same
# user message, right after this text.
DEFAULT_PROMPT_TEMPLATE = """
You are part of an information extraction pipeline that extracts the content of image-based documents.

Extract the content from the provided image.
You need to extract the content exactly.
Format everything as markdown.
Make sure to retain the reading order of the document.

**Visual Elements**
Do not extract figures, drawings, maps, graphs or any other visual elements.
Instead, add a caption that describes briefly what you see in the visual element.
You must describe each visual element.
If you only see a visual element without other content, you must describe this visual element.
Enclose each image caption with [img-caption][/img-caption]

**Tables**
Make sure to format the table in markdown.
Add a short caption below the table that describes the table's content.
Enclose each table caption with [table-caption][/table-caption].
The caption must be placed below the extracted table.

**Forms**
Reproduce checkbox selections with markdown.

Go ahead and extract!

Document:"""
| 50 | + |
| 51 | + |
@component
class LLMDocumentContentExtractor:
    """
    Extracts textual content from image-based documents using a vision-enabled LLM (Large Language Model).

    This component converts each input document into an image using the DocumentToImageContent component,
    uses a prompt to instruct the LLM on how to extract content, and uses a ChatGenerator to extract structured
    textual content based on the provided prompt.

    The prompt must not contain variables; it should only include instructions for the LLM. Image data and the prompt
    are passed together to the LLM as a chat message.

    Documents for which the LLM fails to extract content are returned in a separate `failed_documents` list. These
    failed documents will have a `content_extraction_error` entry in their metadata. This metadata can be used for
    debugging or for reprocessing the documents later.

    ### Usage example
    ```python
    from haystack import Document
    from haystack_experimental.components.generators.chat import OpenAIChatGenerator
    from haystack_experimental.components.extractors import LLMDocumentContentExtractor
    chat_generator = OpenAIChatGenerator()
    extractor = LLMDocumentContentExtractor(chat_generator=chat_generator)
    documents = [
        Document(content="", meta={"file_path": "image.jpg"}),
        Document(content="", meta={"file_path": "document.pdf", "page_number": 1}),
    ]
    updated_documents = extractor.run(documents=documents)["documents"]
    print(updated_documents)
    # [Document(content='Extracted text from image.jpg',
    #           meta={'file_path': 'image.jpg'}),
    #  ...]
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        prompt: str = DEFAULT_PROMPT_TEMPLATE,
        file_path_meta_field: str = "file_path",
        root_path: Optional[str] = None,
        detail: Optional[Literal["auto", "high", "low"]] = None,
        size: Optional[Tuple[int, int]] = None,
        raise_on_failure: bool = False,
        max_workers: int = 3,
    ):
        """
        Initialize the LLMDocumentContentExtractor component.

        :param chat_generator: A ChatGenerator instance representing the LLM used to extract text. This generator must
            support vision-based input and return a plain text response.
        :param prompt: Instructional text provided to the LLM. It must not contain Jinja variables.
            The prompt should only contain instructions on how to extract the content of the image-based document.
        :param file_path_meta_field: The metadata field in the Document that contains the file path to the image or PDF.
        :param root_path: The root directory path where document files are located. If provided, file paths in
            document metadata will be resolved relative to this path. If None, file paths are treated as absolute paths.
        :param detail: Optional detail level of the image (only supported by OpenAI). Can be "auto", "high", or "low".
            This will be passed to chat_generator when processing the images.
        :param size: If provided, resizes the image to fit within the specified dimensions (width, height) while
            maintaining aspect ratio. This reduces file size, memory usage, and processing time, which is beneficial
            when working with models that have resolution constraints or when transmitting images to remote services.
        :param raise_on_failure: If True, exceptions from the LLM are raised, and a response without any text reply
            raises a ValueError. If False, failed documents are logged and returned.
        :param max_workers: Maximum number of threads used to parallelize LLM calls across documents using a
            ThreadPoolExecutor.
        :raises ValueError: If the prompt contains Jinja variables.
        """
        # Ensure the prompt does not contain any variables before any state is set up.
        undeclared = meta.find_undeclared_variables(SandboxedEnvironment().parse(prompt))
        if undeclared:
            # The variables come from a set; sort them so the error message is deterministic.
            variables = sorted(undeclared)
            raise ValueError(
                f"The prompt must not have any variables only instructions on how to extract the content of the "
                f"image-based document. Found {','.join(variables)} in the prompt."
            )

        self._chat_generator = chat_generator
        self.prompt = prompt
        self.file_path_meta_field = file_path_meta_field
        self.root_path = root_path or ""
        self.detail = detail
        self.size = size
        self.raise_on_failure = raise_on_failure
        self.max_workers = max_workers
        self._document_to_image_content = DocumentToImageContent(
            file_path_meta_field=file_path_meta_field,
            root_path=root_path,
            detail=detail,
            size=size,
        )
        self._is_warmed_up = False

    def warm_up(self):
        """
        Warm up the ChatGenerator if it has a warm_up method.

        Subsequent calls are no-ops.
        """
        if self._is_warmed_up:
            return
        if hasattr(self._chat_generator, "warm_up"):
            self._chat_generator.warm_up()
        self._is_warmed_up = True

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"),
            prompt=self.prompt,
            file_path_meta_field=self.file_path_meta_field,
            root_path=self.root_path,
            detail=self.detail,
            size=self.size,
            raise_on_failure=self.raise_on_failure,
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMDocumentContentExtractor":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary with serialized data.
        :returns:
            An instance of the component.
        """
        # The chat generator is stored as a nested component dict and must be rebuilt first.
        deserialize_chatgenerator_inplace(data["init_parameters"], key="chat_generator")
        return default_from_dict(cls, data)

    def _run_on_thread(self, message: Optional[ChatMessage]) -> Dict[str, Any]:
        """
        Execute the LLM inference in a separate thread for each document.

        :param message: A ChatMessage containing the prompt and image content for the LLM, or None if the
            document could not be converted to an image.
        :returns:
            The LLM response if successful, or a dictionary with an "error" key on failure.
        :raises Exception: Re-raises the chat generator's exception when `raise_on_failure` is True.
        """
        # No message means there is nothing to send to the LLM for this document.
        if message is None:
            return {"error": "Document has no content, skipping LLM call."}

        try:
            return self._chat_generator.run(messages=[message])
        except Exception as e:
            if self.raise_on_failure:
                # Bare raise keeps the original traceback intact.
                raise
            logger.error(
                "LLM {class_name} execution failed. Skipping content extraction. Failed with exception '{error}'.",
                class_name=self._chat_generator.__class__.__name__,
                error=e,
            )
            return {"error": "LLM failed with exception: " + str(e)}

    @component.output_types(documents=List[Document], failed_documents=List[Document])
    def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
        """
        Run content extraction on a list of image-based documents using a vision-capable LLM.

        Each document is passed to the LLM along with a predefined prompt. The response is used to update the document's
        content. If the extraction fails, the document is returned in the `failed_documents` list with metadata
        describing the failure.

        :param documents: A list of image-based documents to process. Each must have a valid file path in its metadata.
        :returns:
            A dictionary with:
            - "documents": Successfully processed documents, updated with extracted content.
            - "failed_documents": Documents that failed processing, annotated with failure metadata.
        :raises ValueError: If `raise_on_failure` is True and the LLM response contains no text reply.
        """
        if not documents:
            return {"documents": [], "failed_documents": []}

        # Create one ChatMessage per document. A None image content means the document could not be converted
        # to an image; that is already logged by DocumentToImageContent, so it is only marked here with None.
        image_contents = self._document_to_image_content.run(documents=documents)["image_contents"]
        all_messages: List[Optional[ChatMessage]] = [
            None
            if image_content is None
            else ChatMessage.from_user(content_parts=[TextContent(text=self.prompt), image_content])
            for image_content in image_contents
        ]

        # Run the LLM on each message. The results are materialized while the executor is still open so that
        # the lazy iterator returned by `map` is not consumed after shutdown.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = list(executor.map(self._run_on_thread, all_messages))

        successful_documents: List[Document] = []
        failed_documents: List[Document] = []
        for document, result in zip(documents, results):
            error = result.get("error")

            if error is None:
                # A response without replies, or whose first reply has no text, carries no extracted content.
                replies = result.get("replies") or []
                extracted_content = replies[0].text if replies else None
                if extracted_content is None:
                    error = "LLM returned no text reply."
                    if self.raise_on_failure:
                        raise ValueError(error)
                    logger.warning(
                        "LLM {class_name} returned no text reply. Skipping content extraction.",
                        class_name=self._chat_generator.__class__.__name__,
                    )

            if error is not None:
                failed_meta = {**document.meta, "content_extraction_error": error}
                failed_documents.append(replace(document, meta=failed_meta))
                continue

            # Remove content_extraction_error if present from previous runs
            new_meta = {**document.meta}
            new_meta.pop("content_extraction_error", None)
            successful_documents.append(replace(document, content=extracted_content, meta=new_meta))

        return {"documents": successful_documents, "failed_documents": failed_documents}
0 commit comments