Skip to content

Commit 4e2c334

Browse files
committed
refactor: update DependencyContainer and ConfluenceExtractor for improved summarization handling
1 parent 2341e55 commit 4e2c334

4 files changed

Lines changed: 150 additions & 117 deletions

File tree

libs/admin-api-lib/src/admin_api_lib/dependency_container.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,9 @@ class DependencyContainer(DeclarativeContainer):
138138
chunker=summary_text_splitter,
139139
semaphore=Singleton(AsyncThreadsafeSemaphore, summarizer_settings.maximum_concurrreny),
140140
)
141-
traced_summarizer = Singleton(
142-
LangfuseTracedRunnable,
143-
inner_chain=summarizer,
144-
settings=langfuse_settings,
145-
)
146141

147142
summary_enhancer = List(
148-
Singleton(PageSummaryEnhancer, traced_summarizer, chunker_settings),
143+
Singleton(PageSummaryEnhancer, summarizer, chunker_settings),
149144
)
150145
untraced_information_enhancer = Singleton(
151146
GeneralEnhancer,
Lines changed: 84 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,104 @@
1-
"""Module for the DefaultConfluenceExtractor class."""
1+
"""Module for the LangchainSummarizer class."""
22

33
import logging
4-
from langchain_community.document_loaders import ConfluenceLoader
5-
6-
from extractor_api_lib.impl.types.extractor_types import ExtractorTypes
7-
from extractor_api_lib.models.dataclasses.internal_information_piece import InternalInformationPiece
8-
from extractor_api_lib.models.extraction_parameters import ExtractionParameters
9-
from extractor_api_lib.extractors.information_extractor import InformationExtractor
10-
from extractor_api_lib.impl.mapper.confluence_langchain_document2information_piece import (
11-
ConfluenceLangchainDocument2InformationPiece,
4+
import traceback
5+
from typing import Optional
6+
7+
from langchain.text_splitter import RecursiveCharacterTextSplitter
8+
from langchain_core.documents import Document
9+
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
10+
11+
from admin_api_lib.summarizer.summarizer import (
12+
Summarizer,
13+
SummarizerInput,
14+
SummarizerOutput,
1215
)
16+
from rag_core_lib.impl.langfuse_manager.langfuse_manager import LangfuseManager
17+
from rag_core_lib.impl.utils.async_threadsafe_semaphore import AsyncThreadsafeSemaphore
1318

1419
logger = logging.getLogger(__name__)
1520

1621

17-
class ConfluenceExtractor(InformationExtractor):
18-
"""Implementation of the InformationExtractor interface for confluence."""
22+
class LangchainSummarizer(Summarizer):
23+
"""Is responsible for summarizing input data.
24+
25+
LangchainSummarizer is responsible for summarizing input data using the LangfuseManager,
26+
RecursiveCharacterTextSplitter, and AsyncThreadsafeSemaphore. It handles chunking of the input
27+
document and retries the summarization process if an error occurs.
28+
"""
1929

2030
def __init__(
2131
self,
22-
mapper: ConfluenceLangchainDocument2InformationPiece,
32+
langfuse_manager: LangfuseManager,
33+
chunker: RecursiveCharacterTextSplitter,
34+
semaphore: AsyncThreadsafeSemaphore,
2335
):
24-
"""
25-
Initialize the ConfluenceExtractor.
26-
27-
Parameters
28-
----------
29-
mapper : ConfluenceLangchainDocument2InformationPiece
30-
An instance of ConfluenceLangchainDocument2InformationPiece used for mapping langchain documents
31-
to information pieces.
32-
"""
33-
self._mapper = mapper
36+
self._chunker = chunker
37+
self._langfuse_manager = langfuse_manager
38+
self._semaphore = semaphore
3439

35-
@property
36-
def extractor_type(self) -> ExtractorTypes:
37-
return ExtractorTypes.CONFLUENCE
38-
39-
async def aextract_content(
40-
self,
41-
extraction_parameters: ExtractionParameters,
42-
) -> list[InternalInformationPiece]:
40+
async def ainvoke(self, query: SummarizerInput, config: Optional[RunnableConfig] = None) -> SummarizerOutput:
4341
"""
44-
Asynchronously extracts information pieces from Confluence.
42+
Asynchronously invokes the summarization process on the given query.
4543
4644
Parameters
4745
----------
48-
extraction_parameters : ExtractionParameters
49-
The parameters required to connect to and extract data from Confluence.
46+
query : SummarizerInput
47+
The input data to be summarized.
48+
config : Optional[RunnableConfig], optional
49+
Configuration options for the summarization process, by default None.
5050
5151
Returns
5252
-------
53-
list[InternalInformationPiece]
54-
A list of information pieces extracted from Confluence.
53+
SummarizerOutput
54+
The summarized output.
55+
56+
Raises
57+
------
58+
Exception
59+
If the summary creation fails after the allowed number of tries.
60+
61+
Notes
62+
-----
63+
This method handles chunking of the input document and retries the summarization
64+
process if an error occurs, up to the number of tries specified in the config.
5565
"""
56-
# Convert list of key value pairs to dict
57-
confluence_loader_parameters = {
58-
x.key: int(x.value) if x.value.isdigit() else x.value for x in extraction_parameters.kwargs
59-
}
60-
if not confluence_loader_parameters.get("max_pages") or isinstance(
61-
confluence_loader_parameters.get("max_pages"), str
62-
):
63-
logging.warning(
64-
"max_pages parameter is not set or invalid discarding it. ConfluenceLoader will use default value."
65-
)
66-
confluence_loader_parameters.pop("max_pages")
67-
# Drop the document_name parameter as it is not used by the ConfluenceLoader
68-
if "document_name" in confluence_loader_parameters:
69-
confluence_loader_parameters.pop("document_name", None)
70-
document_loader = ConfluenceLoader(**confluence_loader_parameters)
71-
documents = document_loader.load()
72-
return [self._mapper.map_document2informationpiece(x, extraction_parameters.document_name) for x in documents]
66+
assert query, "Query is empty: %s" % query # noqa S101
67+
config = ensure_config(config)
68+
tries_remaining = config.get("configurable", {}).get("tries_remaining", 3)
69+
logger.debug("Tries remaining %d" % tries_remaining)
70+
71+
if tries_remaining < 0:
72+
raise Exception("Summary creation failed.")
73+
document = Document(page_content=query)
74+
langchain_documents = self._chunker.split_documents([document])
75+
76+
outputs = []
77+
for langchain_document in langchain_documents:
78+
async with self._semaphore:
79+
try:
80+
result = await self._create_chain().ainvoke({"text": langchain_document.page_content}, config)
81+
# Extract content from AIMessage if it's not already a string
82+
content = result.content if hasattr(result, "content") else str(result)
83+
outputs.append(content)
84+
except Exception as e:
85+
logger.error("Error in summarizing langchain doc: %s %s", e, traceback.format_exc())
86+
config["tries_remaining"] = tries_remaining - 1
87+
result = await self._create_chain().ainvoke({"text": langchain_document.page_content}, config)
88+
# Extract content from AIMessage if it's not already a string
89+
content = result.content if hasattr(result, "content") else str(result)
90+
outputs.append(content)
91+
92+
if len(outputs) == 1:
93+
return outputs[0]
94+
summary = " ".join(outputs)
95+
logger.debug(
96+
"Reduced number of chars from %d to %d"
97+
% (len("".join([x.page_content for x in langchain_documents])), len(summary))
98+
)
99+
return await self.ainvoke(summary, config)
100+
101+
def _create_chain(self) -> Runnable:
102+
return self._langfuse_manager.get_base_prompt(self.__class__.__name__) | self._langfuse_manager.get_base_llm(
103+
self.__class__.__name__
104+
)

libs/rag-core-api/src/rag_core_api/impl/evaluator/langfuse_ragas_evaluator.py

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
from asyncio import gather
88
from datetime import datetime
99
from json import JSONDecodeError
10+
from time import sleep
1011
from uuid import uuid4
1112

1213
import ragas
1314
from datasets import Dataset
1415
from langchain_core.runnables import RunnableConfig
1516
from langfuse import Langfuse
17+
from langfuse.api.core.api_error import ApiError
1618
from langfuse.api.resources.commons.errors.not_found_error import NotFoundError
1719
from langfuse._client.datasets import DatasetClient
1820
from ragas.llms import LangchainLLMWrapper
@@ -162,61 +164,67 @@ async def _aevaluate_question(self, item, experiment_name: str, generation_time:
162164
async with self._semaphore:
163165
chat_request = ChatRequest(message=item.input)
164166

165-
# Use item.run context manager for trace
166-
with item.run(
167-
run_name=experiment_name,
168-
run_metadata={"model": self._settings.model},
169-
run_description=f"Evaluation run for {experiment_name}",
170-
) as root_span:
171-
# Use langfuse.start_as_current_generation for generation
172-
try:
173-
response = await self._chat_endpoint.achat(config["metadata"]["session_id"], chat_request)
174-
except Exception as e:
175-
logger.info("Error while answering question %s: %s", item.input, e)
176-
response = None
177-
output = {
178-
"answer": response.answer if response else None,
179-
"documents": (
180-
[x.page_content for x in response.citations] if response and response.citations else None
181-
),
182-
}
183-
with self._langfuse.start_as_current_generation(
184-
name="rag-eval-llm-call",
185-
input={"question": item.input, "context": output["documents"]},
186-
metadata={"item_id": item.id, "run": experiment_name},
187-
model=self._settings.model,
188-
) as generation:
189-
190-
generation.update(output=output["answer"])
191-
generation.update_trace(
192-
input={"question": item.input, "context": output["documents"]},
193-
metadata={"item_id": item.id, "run": experiment_name},
194-
output=output["answer"],
195-
)
167+
try:
168+
response = await self._chat_endpoint.achat(config["metadata"]["session_id"], chat_request)
169+
except Exception as e:
170+
logger.info("Error while answering question %s: %s", item.input, e)
171+
response = None
196172

197-
# Ragas metrics
198-
if response and response.citations:
199-
eval_data = Dataset.from_dict(
200-
{
201-
"question": [item.input],
202-
"answer": [output["answer"]],
203-
"contexts": [output["documents"]],
204-
"ground_truth": [item.expected_output],
205-
}
206-
)
207-
result = ragas.evaluate(
208-
eval_data,
209-
metrics=self.METRICS,
210-
llm=self._llm_wrapped,
211-
embeddings=self._embedder,
173+
if response and response.citations:
174+
output = {"answer": response.answer, "documents": [x.page_content for x in response.citations]}
175+
else:
176+
output = {"answer": None, "documents": None}
177+
178+
langfuse_generation = self._langfuse.generation(
179+
name=self._settings.evaluation_dataset_name,
180+
input=item.input,
181+
output=output,
182+
start_time=generation_time,
183+
end_time=datetime.now(),
184+
)
185+
self._link_item2generation(item, langfuse_generation, experiment_name)
186+
187+
if not (response and response.citations):
188+
for metric in self.METRICS:
189+
langfuse_generation.score(
190+
name=metric.name,
191+
value=self.DEFAULT_SCORE_VALUE,
212192
)
213-
for metric, score in result.scores[0].items():
214-
if math.isnan(score):
215-
score = self.DEFAULT_SCORE_VALUE
216-
root_span.score_trace(name=metric, value=score)
217-
else:
218-
for metric in self.METRICS:
219-
root_span.score_trace(name=metric.name, value=self.DEFAULT_SCORE_VALUE)
193+
return
194+
195+
eval_data = Dataset.from_dict(
196+
{
197+
"question": [item.input],
198+
"answer": [output["answer"]],
199+
"contexts": [output["documents"]],
200+
"ground_truth": [item.expected_output],
201+
}
202+
)
203+
204+
result = ragas.evaluate(
205+
eval_data,
206+
metrics=self.METRICS,
207+
llm=self._llm_wrapped,
208+
embeddings=self._embedder,
209+
)
210+
for metric, score in result.scores[0].items():
211+
if math.isnan(score):
212+
score = self.DEFAULT_SCORE_VALUE
213+
langfuse_generation.score(
214+
name=metric,
215+
value=score,
216+
)
217+
218+
def _link_item2generation(self, item, generation, experiment_name, retries: int = 0):
219+
try:
220+
item.link(generation, experiment_name)
221+
except ApiError as e:
222+
logger.warning("Failed to link item to generation: %s", e)
223+
retries += 1
224+
if retries > self.MAX_RETRIES:
225+
raise e
226+
sleep(1)
227+
self._link_item2generation(item, generation, experiment_name, retries)
220228

221229
def _get_dataset(self, dataset_name: str) -> DatasetClient:
222230
dataset = None
@@ -232,7 +240,7 @@ def _get_dataset(self, dataset_name: str) -> DatasetClient:
232240
return dataset
233241

234242
def _create_dataset(self, dataset_name: str = None):
235-
self._langfuse.create_dataset(name=dataset_name)
243+
self._langfuse.create_dataset(dataset_name)
236244

237245
data = self._load_dataset_items()
238246
self._store_items_in_dataset(data, dataset_name)

libs/rag-core-lib/src/rag_core_lib/tracers/traced_runnable.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,9 @@ async def ainvoke(
6666
config = ensure_config(config)
6767
session_id = self._get_session_id(config)
6868
config_with_tracing = self._add_tracing_callback(config)
69-
with self.langfuse_client.start_as_current_span(name=self._inner_chain.__class__.__name__) as span:
70-
span.update_trace(session_id=session_id, input=chain_input)
71-
output = await self._inner_chain.ainvoke(chain_input, config=config_with_tracing)
72-
span.update_trace(output=output)
73-
return output
69+
with self.langfuse_client.start_as_current_span(name="traced_runnable") as span:
70+
span.update_trace(session_id=session_id)
71+
return await self._inner_chain.ainvoke(chain_input, config=config_with_tracing)
7472

7573
@abstractmethod
7674
def _add_tracing_callback(self, config: Optional[RunnableConfig]) -> RunnableConfig: ...

0 commit comments

Comments
 (0)