|
4 | 4 |
|
5 | 5 | import copy |
6 | 6 | import json |
| 7 | +from asyncio import Semaphore, gather |
| 8 | +from collections.abc import Iterable |
7 | 9 | from concurrent.futures import ThreadPoolExecutor |
8 | 10 | from dataclasses import replace |
9 | 11 | from typing import Any |
@@ -173,6 +175,8 @@ def __init__( |
173 | 175 | :param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or |
174 | 176 | validation of the JSON output. |
175 | 177 | :param max_workers: The maximum number of workers to use in the thread pool executor. |
|  178 | +            This parameter is used to limit the maximum number of requests that should be allowed to run concurrently
| 179 | + when using the `run_async` method. |
176 | 180 | """ |
177 | 181 | self.prompt = prompt |
178 | 182 | ast = SandboxedEnvironment().parse(prompt) |
@@ -293,6 +297,52 @@ def _run_on_thread(self, prompt: ChatMessage | None) -> dict[str, Any]: |
293 | 297 | result = {"error": "LLM failed with exception: " + str(e)} |
294 | 298 | return result |
295 | 299 |
|
| 300 | + async def _run_async(self, prompt: ChatMessage | None) -> dict[str, Any]: |
| 301 | + # If prompt is None, return an error dictionary |
| 302 | + if prompt is None: |
| 303 | + return {"error": "Document has no content, skipping LLM call."} |
| 304 | + |
| 305 | + try: |
| 306 | + result = await self._chat_generator.run_async(messages=[prompt]) # type: ignore[attr-defined] |
| 307 | + except Exception as e: |
| 308 | + if self.raise_on_failure: |
| 309 | + raise e |
| 310 | + logger.exception( |
| 311 | + "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.", |
| 312 | + class_name=self._chat_generator.__class__.__name__, |
| 313 | + error=e, |
| 314 | + ) |
| 315 | + result = {"error": "LLM failed with exception: " + str(e)} |
| 316 | + return result |
| 317 | + |
| 318 | + def _process_results( |
| 319 | + self, documents: list[Document], results: Iterable[dict[str, Any]] |
| 320 | + ) -> tuple[list[Document], list[Document]]: |
| 321 | + successful_documents = [] |
| 322 | + failed_documents = [] |
| 323 | + for document, result in zip(documents, results, strict=True): |
| 324 | + new_meta = {**document.meta} |
| 325 | + if "error" in result: |
| 326 | + new_meta["metadata_extraction_error"] = result["error"] |
| 327 | + new_meta["metadata_extraction_response"] = None |
| 328 | + failed_documents.append(replace(document, meta=new_meta)) |
| 329 | + continue |
| 330 | + |
| 331 | + parsed_metadata = self._extract_metadata(result["replies"][0].text) |
| 332 | + if "error" in parsed_metadata: |
| 333 | + new_meta["metadata_extraction_error"] = parsed_metadata["error"] |
| 334 | + new_meta["metadata_extraction_response"] = result["replies"][0] |
| 335 | + failed_documents.append(replace(document, meta=new_meta)) |
| 336 | + continue |
| 337 | + |
| 338 | + for key in parsed_metadata: |
| 339 | + new_meta[key] = parsed_metadata[key] |
| 340 | + # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs |
| 341 | + new_meta.pop("metadata_extraction_error", None) |
| 342 | + new_meta.pop("metadata_extraction_response", None) |
| 343 | + successful_documents.append(replace(document, meta=new_meta)) |
| 344 | + return successful_documents, failed_documents |
| 345 | + |
296 | 346 | @component.output_types(documents=list[Document], failed_documents=list[Document]) |
297 | 347 | def run(self, documents: list[Document], page_range: list[str | int] | None = None) -> dict[str, Any]: |
298 | 348 | """ |
@@ -336,28 +386,64 @@ def run(self, documents: list[Document], page_range: list[str | int] | None = No |
336 | 386 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: |
337 | 387 | results = executor.map(self._run_on_thread, all_prompts) |
338 | 388 |
|
339 | | - successful_documents = [] |
340 | | - failed_documents = [] |
341 | | - for document, result in zip(documents, results, strict=True): |
342 | | - new_meta = {**document.meta} |
343 | | - if "error" in result: |
344 | | - new_meta["metadata_extraction_error"] = result["error"] |
345 | | - new_meta["metadata_extraction_response"] = None |
346 | | - failed_documents.append(replace(document, meta=new_meta)) |
347 | | - continue |
| 389 | + successful_documents, failed_documents = self._process_results(documents, results) |
348 | 390 |
|
349 | | - parsed_metadata = self._extract_metadata(result["replies"][0].text) |
350 | | - if "error" in parsed_metadata: |
351 | | - new_meta["metadata_extraction_error"] = parsed_metadata["error"] |
352 | | - new_meta["metadata_extraction_response"] = result["replies"][0] |
353 | | - failed_documents.append(replace(document, meta=new_meta)) |
354 | | - continue |
| 391 | + return {"documents": successful_documents, "failed_documents": failed_documents} |
355 | 392 |
|
356 | | - for key in parsed_metadata: |
357 | | - new_meta[key] = parsed_metadata[key] |
358 | | - # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs |
359 | | - new_meta.pop("metadata_extraction_error", None) |
360 | | - new_meta.pop("metadata_extraction_response", None) |
361 | | - successful_documents.append(replace(document, meta=new_meta)) |
| 393 | + @component.output_types(documents=list[Document], failed_documents=list[Document]) |
| 394 | + async def run_async(self, documents: list[Document], page_range: list[str | int] | None = None) -> dict[str, Any]: |
| 395 | + """ |
| 396 | + Asynchronously extract metadata from documents using a Large Language Model. |
| 397 | +
|
| 398 | + If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component |
| 399 | + will split the documents into pages and extract metadata from the specified range of pages. The metadata will be |
| 400 | + extracted from the entire document if `page_range` is not provided. |
| 401 | +
|
| 402 | + The original documents will be returned updated with the extracted metadata. |
| 403 | +
|
| 404 | + This is the asynchronous version of the `run` method. It has the same parameters |
| 405 | + and return values but can be used with `await` in an async code. |
| 406 | +
|
| 407 | + :param documents: List of documents to extract metadata from. |
| 408 | + :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract |
| 409 | + metadata from the first and third pages of each document. It also accepts printable range |
| 410 | + strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, |
| 411 | + 11, 12. |
| 412 | + If None, metadata will be extracted from the entire document for each document in the |
| 413 | + documents list. |
| 414 | + :returns: |
| 415 | + A dictionary with the keys: |
| 416 | + - "documents": A list of documents that were successfully updated with the extracted metadata. |
| 417 | + - "failed_documents": A list of documents that failed to extract metadata. These documents will have |
| 418 | + "metadata_extraction_error" and "metadata_extraction_response" in their metadata. These documents can be |
| 419 | + re-run with the extractor to extract metadata. |
| 420 | + """ |
| 421 | + if not hasattr(self._chat_generator, "run_async"): |
| 422 | + logger.warning( |
| 423 | + "{chat_generator_type} does not implement method 'run_async'. Falling back to 'run'.", |
| 424 | + chat_generator_type=type(self._chat_generator).__name__, |
| 425 | + ) |
| 426 | + return self.run(documents, page_range) |
| 427 | + |
| 428 | + if len(documents) == 0: |
| 429 | + logger.warning("No documents provided. Skipping metadata extraction.") |
| 430 | + return {"documents": [], "failed_documents": []} |
| 431 | + |
| 432 | + if not self._is_warmed_up: |
| 433 | + self.warm_up() |
| 434 | + |
| 435 | + expanded_range = self.expanded_range |
| 436 | + if page_range: |
| 437 | + expanded_range = expand_page_range(page_range) |
| 438 | + |
| 439 | + # Create ChatMessage prompts for each document |
| 440 | + all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range) |
| 441 | + |
| 442 | + # Run the LLM on each prompt |
| 443 | + sem = Semaphore(max(1, self.max_workers)) |
| 444 | + async with sem: |
| 445 | + results = await gather(*[self._run_async(prompt) for prompt in all_prompts]) |
| 446 | + |
| 447 | + successful_documents, failed_documents = self._process_results(documents, results) |
362 | 448 |
|
363 | 449 | return {"documents": successful_documents, "failed_documents": failed_documents} |
0 commit comments