|
1 | | -"""Module for the DefaultConfluenceExtractor class.""" |
| 1 | +"""Module for the LangchainSummarizer class.""" |
2 | 2 |
|
3 | 3 | import logging |
4 | | -from langchain_community.document_loaders import ConfluenceLoader |
5 | | - |
6 | | -from extractor_api_lib.impl.types.extractor_types import ExtractorTypes |
7 | | -from extractor_api_lib.models.dataclasses.internal_information_piece import InternalInformationPiece |
8 | | -from extractor_api_lib.models.extraction_parameters import ExtractionParameters |
9 | | -from extractor_api_lib.extractors.information_extractor import InformationExtractor |
10 | | -from extractor_api_lib.impl.mapper.confluence_langchain_document2information_piece import ( |
11 | | - ConfluenceLangchainDocument2InformationPiece, |
| 4 | +import traceback |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +from langchain.text_splitter import RecursiveCharacterTextSplitter |
| 8 | +from langchain_core.documents import Document |
| 9 | +from langchain_core.runnables import Runnable, RunnableConfig, ensure_config |
| 10 | + |
| 11 | +from admin_api_lib.summarizer.summarizer import ( |
| 12 | + Summarizer, |
| 13 | + SummarizerInput, |
| 14 | + SummarizerOutput, |
12 | 15 | ) |
| 16 | +from rag_core_lib.impl.langfuse_manager.langfuse_manager import LangfuseManager |
| 17 | +from rag_core_lib.impl.utils.async_threadsafe_semaphore import AsyncThreadsafeSemaphore |
13 | 18 |
|
14 | 19 | logger = logging.getLogger(__name__) |
15 | 20 |
|
16 | 21 |
|
class LangchainSummarizer(Summarizer):
    """Summarize input text using a Langfuse-managed prompt/LLM chain.

    The input is split into chunks with a RecursiveCharacterTextSplitter and
    each chunk is summarized through the Langfuse-managed chain, guarded by an
    AsyncThreadsafeSemaphore. When more than one chunk is produced, the partial
    summaries are joined and recursively re-summarized until a single summary
    remains. Failed chunk summarizations are retried until the configured
    number of tries is exhausted.
    """

    def __init__(
        self,
        langfuse_manager: LangfuseManager,
        chunker: RecursiveCharacterTextSplitter,
        semaphore: AsyncThreadsafeSemaphore,
    ):
        """
        Initialize the LangchainSummarizer.

        Parameters
        ----------
        langfuse_manager : LangfuseManager
            Provides the managed base prompt and base LLM used to build the chain.
        chunker : RecursiveCharacterTextSplitter
            Splits the input text into chunks small enough to summarize.
        semaphore : AsyncThreadsafeSemaphore
            Limits the number of concurrent LLM calls.
        """
        self._chunker = chunker
        self._langfuse_manager = langfuse_manager
        self._semaphore = semaphore

    async def ainvoke(self, query: SummarizerInput, config: Optional[RunnableConfig] = None) -> SummarizerOutput:
        """
        Asynchronously invoke the summarization process on the given query.

        Parameters
        ----------
        query : SummarizerInput
            The input data to be summarized. Must be non-empty.
        config : Optional[RunnableConfig], optional
            Configuration options for the summarization process, by default None.
            ``config["configurable"]["tries_remaining"]`` (default 3) bounds the
            number of retries across recursive re-summarization passes.

        Returns
        -------
        SummarizerOutput
            The summarized output.

        Raises
        ------
        ValueError
            If the query is empty.
        Exception
            If the summary creation fails after the allowed number of tries.

        Notes
        -----
        The input is chunked; a chunk whose summarization raises is retried once
        with a decremented retry budget, and a multi-chunk result is recursively
        summarized again until a single summary remains.
        """
        if not query:
            raise ValueError("Query is empty: %s" % query)
        config = ensure_config(config)
        tries_remaining = config.get("configurable", {}).get("tries_remaining", 3)
        logger.debug("Tries remaining %d", tries_remaining)

        if tries_remaining < 0:
            raise Exception("Summary creation failed.")

        chunks = self._chunker.split_documents([Document(page_content=query)])

        partial_summaries = []
        for chunk in chunks:
            async with self._semaphore:
                try:
                    partial_summaries.append(await self._summarize_chunk(chunk.page_content, config))
                except Exception as e:
                    logger.error("Error in summarizing langchain doc: %s %s", e, traceback.format_exc())
                    # Decrement the retry budget where it is read back from
                    # (config["configurable"], not the top-level config), then
                    # retry this chunk once; a second failure propagates.
                    config.setdefault("configurable", {})["tries_remaining"] = tries_remaining - 1
                    partial_summaries.append(await self._summarize_chunk(chunk.page_content, config))

        if len(partial_summaries) == 1:
            return partial_summaries[0]

        combined = " ".join(partial_summaries)
        logger.debug(
            "Reduced number of chars from %d to %d",
            sum(len(chunk.page_content) for chunk in chunks),
            len(combined),
        )
        # More than one chunk: recursively condense the joined partial summaries.
        return await self.ainvoke(combined, config)

    async def _summarize_chunk(self, text: str, config: RunnableConfig) -> str:
        """Run the summarization chain on one chunk and return its text content."""
        result = await self._create_chain().ainvoke({"text": text}, config)
        # The LLM may return an AIMessage; extract plain string content.
        return result.content if hasattr(result, "content") else str(result)

    def _create_chain(self) -> Runnable:
        """Compose the Langfuse-managed base prompt with the managed base LLM."""
        return self._langfuse_manager.get_base_prompt(self.__class__.__name__) | self._langfuse_manager.get_base_llm(
            self.__class__.__name__
        )
0 commit comments