Skip to content

Commit 75e5e25

Browse files
committed
refactor: streamline LangchainSummarizer and implement DefaultConfluenceExtractor with improved extraction logic
1 parent 4e2c334 commit 75e5e25

2 files changed

Lines changed: 53 additions & 93 deletions

File tree

libs/admin-api-lib/src/admin_api_lib/impl/summarizer/langchain_summarizer.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
"""Module for the LangchainSummarizer class."""
22

3-
import asyncio
43
import logging
54
import traceback
65
from typing import Optional
@@ -28,8 +27,6 @@ class LangchainSummarizer(Summarizer):
2827
document and retries the summarization process if an error occurs.
2928
"""
3029

31-
RETRY_WAIT_TIME = 10
32-
3330
def __init__(
3431
self,
3532
langfuse_manager: LangfuseManager,
@@ -87,12 +84,7 @@ async def ainvoke(self, query: SummarizerInput, config: Optional[RunnableConfig]
8784
except Exception as e:
8885
logger.error("Error in summarizing langchain doc: %s %s", e, traceback.format_exc())
8986
config["tries_remaining"] = tries_remaining - 1
90-
if "rate limit" in str(e).lower() or "ratelimit" in str(e).lower():
91-
logger.warning(
92-
"Rate limit encountered, waiting %d seconds before retry...", self.RETRY_WAIT_TIME
93-
)
94-
await asyncio.sleep(self.RETRY_WAIT_TIME)
95-
result = await self.ainvoke(query, config)
87+
result = await self._create_chain().ainvoke({"text": langchain_document.page_content}, config)
9688
# Extract content from AIMessage if it's not already a string
9789
content = result.content if hasattr(result, "content") else str(result)
9890
outputs.append(content)
Lines changed: 52 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -1,104 +1,72 @@
1-
"""Module for the LangchainSummarizer class."""
1+
"""Module for the DefaultConfluenceExtractor class."""
22

33
import logging
4-
import traceback
5-
from typing import Optional
6-
7-
from langchain.text_splitter import RecursiveCharacterTextSplitter
8-
from langchain_core.documents import Document
9-
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
10-
11-
from admin_api_lib.summarizer.summarizer import (
12-
Summarizer,
13-
SummarizerInput,
14-
SummarizerOutput,
4+
from langchain_community.document_loaders import ConfluenceLoader
5+
6+
from extractor_api_lib.impl.types.extractor_types import ExtractorTypes
7+
from extractor_api_lib.models.dataclasses.internal_information_piece import InternalInformationPiece
8+
from extractor_api_lib.models.extraction_parameters import ExtractionParameters
9+
from extractor_api_lib.extractors.information_extractor import InformationExtractor
10+
from extractor_api_lib.impl.mapper.confluence_langchain_document2information_piece import (
11+
ConfluenceLangchainDocument2InformationPiece,
1512
)
16-
from rag_core_lib.impl.langfuse_manager.langfuse_manager import LangfuseManager
17-
from rag_core_lib.impl.utils.async_threadsafe_semaphore import AsyncThreadsafeSemaphore
1813

1914
logger = logging.getLogger(__name__)
2015

2116

22-
class LangchainSummarizer(Summarizer):
23-
"""Is responsible for summarizing input data.
24-
25-
LangchainSummarizer is responsible for summarizing input data using the LangfuseManager,
26-
RecursiveCharacterTextSplitter, and AsyncThreadsafeSemaphore. It handles chunking of the input
27-
document and retries the summarization process if an error occurs.
28-
"""
17+
class ConfluenceExtractor(InformationExtractor):
18+
"""Implementation of the InformationExtractor interface for confluence."""
2919

3020
def __init__(
3121
self,
32-
langfuse_manager: LangfuseManager,
33-
chunker: RecursiveCharacterTextSplitter,
34-
semaphore: AsyncThreadsafeSemaphore,
22+
mapper: ConfluenceLangchainDocument2InformationPiece,
3523
):
36-
self._chunker = chunker
37-
self._langfuse_manager = langfuse_manager
38-
self._semaphore = semaphore
39-
40-
async def ainvoke(self, query: SummarizerInput, config: Optional[RunnableConfig] = None) -> SummarizerOutput:
4124
"""
42-
Asynchronously invokes the summarization process on the given query.
25+
Initialize the ConfluenceExtractor.
4326
4427
Parameters
4528
----------
46-
query : SummarizerInput
47-
The input data to be summarized.
48-
config : Optional[RunnableConfig], optional
49-
Configuration options for the summarization process, by default None.
50-
51-
Returns
52-
-------
53-
SummarizerOutput
54-
The summarized output.
55-
56-
Raises
57-
------
58-
Exception
59-
If the summary creation fails after the allowed number of tries.
60-
61-
Notes
62-
-----
63-
This method handles chunking of the input document and retries the summarization
64-
process if an error occurs, up to the number of tries specified in the config.
29+
mapper : ConfluenceLangchainDocument2InformationPiece
30+
An instance of ConfluenceLangchainDocument2InformationPiece used for mapping langchain documents
31+
to information pieces.
6532
"""
66-
assert query, "Query is empty: %s" % query # noqa S101
67-
config = ensure_config(config)
68-
tries_remaining = config.get("configurable", {}).get("tries_remaining", 3)
69-
logger.debug("Tries remaining %d" % tries_remaining)
33+
self._mapper = mapper
7034

71-
if tries_remaining < 0:
72-
raise Exception("Summary creation failed.")
73-
document = Document(page_content=query)
74-
langchain_documents = self._chunker.split_documents([document])
35+
@property
36+
def extractor_type(self) -> ExtractorTypes:
37+
return ExtractorTypes.CONFLUENCE
7538

76-
outputs = []
77-
for langchain_document in langchain_documents:
78-
async with self._semaphore:
79-
try:
80-
result = await self._create_chain().ainvoke({"text": langchain_document.page_content}, config)
81-
# Extract content from AIMessage if it's not already a string
82-
content = result.content if hasattr(result, "content") else str(result)
83-
outputs.append(content)
84-
except Exception as e:
85-
logger.error("Error in summarizing langchain doc: %s %s", e, traceback.format_exc())
86-
config["tries_remaining"] = tries_remaining - 1
87-
result = await self._create_chain().ainvoke({"text": langchain_document.page_content}, config)
88-
# Extract content from AIMessage if it's not already a string
89-
content = result.content if hasattr(result, "content") else str(result)
90-
outputs.append(content)
39+
async def aextract_content(
40+
self,
41+
extraction_parameters: ExtractionParameters,
42+
) -> list[InternalInformationPiece]:
43+
"""
44+
Asynchronously extracts information pieces from Confluence.
9145
92-
if len(outputs) == 1:
93-
return outputs[0]
94-
summary = " ".join(outputs)
95-
logger.debug(
96-
"Reduced number of chars from %d to %d"
97-
% (len("".join([x.page_content for x in langchain_documents])), len(summary))
98-
)
99-
return await self.ainvoke(summary, config)
46+
Parameters
47+
----------
48+
extraction_parameters : ExtractionParameters
49+
The parameters required to connect to and extract data from Confluence.
10050
101-
def _create_chain(self) -> Runnable:
102-
return self._langfuse_manager.get_base_prompt(self.__class__.__name__) | self._langfuse_manager.get_base_llm(
103-
self.__class__.__name__
104-
)
51+
Returns
52+
-------
53+
list[InternalInformationPiece]
54+
A list of information pieces extracted from Confluence.
55+
"""
56+
# Convert list of key value pairs to dict
57+
confluence_loader_parameters = {
58+
x.key: int(x.value) if x.value.isdigit() else x.value for x in extraction_parameters.kwargs
59+
}
60+
if not confluence_loader_parameters.get("max_pages") or isinstance(
61+
confluence_loader_parameters.get("max_pages"), str
62+
):
63+
logging.warning(
64+
"max_pages parameter is not set or invalid discarding it. ConfluenceLoader will use default value."
65+
)
66+
confluence_loader_parameters.pop("max_pages")
67+
# Drop the document_name parameter as it is not used by the ConfluenceLoader
68+
if "document_name" in confluence_loader_parameters:
69+
confluence_loader_parameters.pop("document_name", None)
70+
document_loader = ConfluenceLoader(**confluence_loader_parameters)
71+
documents = document_loader.load()
72+
return [self._mapper.map_document2informationpiece(x, extraction_parameters.document_name) for x in documents]

0 commit comments

Comments (0)