|
1 | | -"""Module for the LangchainSummarizer class.""" |
| 1 | +"""Module for the ConfluenceExtractor class.""" |
2 | 2 |
|
3 | 3 | import logging |
4 | | -import traceback |
5 | | -from typing import Optional |
6 | | - |
7 | | -from langchain.text_splitter import RecursiveCharacterTextSplitter |
8 | | -from langchain_core.documents import Document |
9 | | -from langchain_core.runnables import Runnable, RunnableConfig, ensure_config |
10 | | - |
11 | | -from admin_api_lib.summarizer.summarizer import ( |
12 | | - Summarizer, |
13 | | - SummarizerInput, |
14 | | - SummarizerOutput, |
| 4 | +from langchain_community.document_loaders import ConfluenceLoader |
| 5 | + |
| 6 | +from extractor_api_lib.impl.types.extractor_types import ExtractorTypes |
| 7 | +from extractor_api_lib.models.dataclasses.internal_information_piece import InternalInformationPiece |
| 8 | +from extractor_api_lib.models.extraction_parameters import ExtractionParameters |
| 9 | +from extractor_api_lib.extractors.information_extractor import InformationExtractor |
| 10 | +from extractor_api_lib.impl.mapper.confluence_langchain_document2information_piece import ( |
| 11 | + ConfluenceLangchainDocument2InformationPiece, |
15 | 12 | ) |
16 | | -from rag_core_lib.impl.langfuse_manager.langfuse_manager import LangfuseManager |
17 | | -from rag_core_lib.impl.utils.async_threadsafe_semaphore import AsyncThreadsafeSemaphore |
18 | 13 |
|
19 | 14 | logger = logging.getLogger(__name__) |
20 | 15 |
|
21 | 16 |
|
22 | | -class LangchainSummarizer(Summarizer): |
23 | | - """Is responsible for summarizing input data. |
24 | | -
|
25 | | - LangchainSummarizer is responsible for summarizing input data using the LangfuseManager, |
26 | | - RecursiveCharacterTextSplitter, and AsyncThreadsafeSemaphore. It handles chunking of the input |
27 | | - document and retries the summarization process if an error occurs. |
28 | | - """ |
class ConfluenceExtractor(InformationExtractor):
    """Implementation of the InformationExtractor interface for confluence."""

    def __init__(
        self,
        mapper: ConfluenceLangchainDocument2InformationPiece,
    ):
        """
        Initialize the ConfluenceExtractor.

        Parameters
        ----------
        mapper : ConfluenceLangchainDocument2InformationPiece
            An instance of ConfluenceLangchainDocument2InformationPiece used for mapping langchain documents
            to information pieces.
        """
        self._mapper = mapper

    @property
    def extractor_type(self) -> ExtractorTypes:
        """Return the extractor type handled by this implementation (ExtractorTypes.CONFLUENCE)."""
        return ExtractorTypes.CONFLUENCE

    async def aextract_content(
        self,
        extraction_parameters: ExtractionParameters,
    ) -> list[InternalInformationPiece]:
        """
        Asynchronously extracts information pieces from Confluence.

        Parameters
        ----------
        extraction_parameters : ExtractionParameters
            The parameters required to connect to and extract data from Confluence.

        Returns
        -------
        list[InternalInformationPiece]
            A list of information pieces extracted from Confluence.
        """
        # Convert the list of key/value pairs into kwargs for ConfluenceLoader.
        # Digit-only values become ints (e.g. page limits); everything else stays a string.
        confluence_loader_parameters = {
            x.key: int(x.value) if x.value.isdigit() else x.value for x in extraction_parameters.kwargs
        }
        # "max_pages" must be a positive int after conversion; discard it when missing, zero,
        # or a non-numeric string so ConfluenceLoader falls back to its default.
        max_pages = confluence_loader_parameters.get("max_pages")
        if not max_pages or isinstance(max_pages, str):
            logger.warning(
                "max_pages parameter is not set or invalid discarding it. ConfluenceLoader will use default value."
            )
            # pop with a default: the key may be absent entirely, which previously raised KeyError
            confluence_loader_parameters.pop("max_pages", None)
        # Drop the document_name parameter as it is not used by the ConfluenceLoader
        confluence_loader_parameters.pop("document_name", None)
        document_loader = ConfluenceLoader(**confluence_loader_parameters)
        documents = document_loader.load()
        return [self._mapper.map_document2informationpiece(x, extraction_parameters.document_name) for x in documents]
0 commit comments