src/openlayer/lib/integrations/langchain_callback.py (104 changes: 97 additions & 7 deletions)
@@ -2,7 +2,7 @@
 
 # pylint: disable=unused-argument
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING, Callable
 from uuid import UUID
 
 try:
@@ -52,6 +52,8 @@ def __init__(self, **kwargs: Any) -> None:
         self.root_steps: set[UUID] = set()  # Track which steps are root
         # Extract inference_id from kwargs if provided
         self._inference_id = kwargs.get("inference_id")
+        # Extract metadata_transformer from kwargs if provided
+        self._metadata_transformer = kwargs.get("metadata_transformer")
 
     def _start_step(
         self,
@@ -207,6 +209,25 @@ def _process_and_upload_trace(self, root_step: steps.Step) -> None:
             # Reset trace context only for standalone traces
             tracer._current_trace.set(None)
 
+    def _process_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
+        """Apply user-defined metadata transformation if provided."""
+        if not metadata:
+            return {}
+
+        # First convert LangChain objects to JSON-serializable format
+        converted_metadata = self._convert_langchain_objects(metadata)
+
+        # Then apply custom transformer if provided
+        if self._metadata_transformer:
+            try:
+                return self._metadata_transformer(converted_metadata)
+            except Exception as e:
+                # Log a warning and fall back to the converted, untransformed metadata
+                tracer.logger.warning(f"Metadata transformer failed: {e}")
+                return converted_metadata
+
+        return converted_metadata
+
     def _convert_step_objects_recursively(self, step: steps.Step) -> None:
         """Convert all LangChain objects in a step and its nested steps."""
         # Convert step attributes
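Aside: any callable with the shape Dict[str, Any] -> Dict[str, Any] works as the transformer. A minimal sketch of one (the redact_metadata name and its redaction rules are hypothetical, not part of this PR):

from typing import Any, Dict

def redact_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Drop private keys and truncate oversized string values."""
    cleaned: Dict[str, Any] = {}
    for key, value in metadata.items():
        if key.startswith("_"):
            # Skip keys treated as internal by convention
            continue
        if isinstance(value, str) and len(value) > 500:
            # Keep uploaded traces small
            value = value[:500]
        cleaned[key] = value
    return cleaned

Because _process_metadata catches any exception the transformer raises, a buggy transformer degrades to the converted metadata rather than dropping the trace.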
@@ -217,7 +238,7 @@ def _convert_step_objects_recursively(self, step: steps.Step) -> None:
             converted_output = self._convert_langchain_objects(step.output)
             step.output = utils.json_serialize(converted_output)
         if step.metadata is not None:
-            step.metadata = self._convert_langchain_objects(step.metadata)
+            step.metadata = self._process_metadata(step.metadata)
 
         # Convert nested steps recursively
         for nested_step in step.steps:
@@ -471,7 +492,16 @@ def _handle_llm_end(
             return
 
         output = self._extract_output(response)
-        token_info = self._extract_token_info(response)
+
+        # Only extract token info if it hasn't been set during streaming
+        step = self.steps[run_id]
+        token_info = {}
+        if not (
+            hasattr(step, "prompt_tokens")
+            and step.prompt_tokens is not None
+            and step.prompt_tokens > 0
+        ):
+            token_info = self._extract_token_info(response)
 
         self._end_step(
             run_id=run_id,
@@ -544,6 +574,23 @@ def _handle_chain_end(
         if run_id not in self.steps:
             return
 
+        # Check if this is a ConversationalRetrievalChain with source documents
+        if isinstance(outputs, dict) and "source_documents" in outputs:
+            source_docs = outputs["source_documents"]
+            if source_docs:
+                # Extract content from source documents
+                context_list = []
+                for doc in source_docs:
+                    if hasattr(doc, "page_content"):
+                        context_list.append(doc.page_content)
+                    else:
+                        context_list.append(str(doc))
+
+                if context_list:
+                    current_trace = tracer.get_current_trace()
+                    if current_trace:
+                        current_trace.update_metadata(context=context_list)
+
         self._end_step(
             run_id=run_id,
             parent_run_id=parent_run_id,
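For reference, "source_documents" is the key classic LangChain retrieval chains emit when built with return_source_documents=True. A usage sketch (llm, retriever, and handler are placeholders, not defined in this PR):

from langchain.chains import ConversationalRetrievalChain

chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,  # makes "source_documents" appear in outputs
)
result = chain.invoke(
    {"question": "What does the handler trace?", "chat_history": []},
    config={"callbacks": [handler]},  # routes chain-end events to this handler
)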
@@ -724,6 +771,10 @@ def _handle_retriever_end(
             else:
                 doc_contents.append(str(doc))
 
+        current_trace = tracer.get_current_trace()
+        if current_trace:
+            current_trace.update_metadata(context=doc_contents)
+
         self._end_step(
             run_id=run_id,
             parent_run_id=parent_run_id,
@@ -742,6 +793,35 @@ def _handle_retriever_error(
         """Common logic for retriever error."""
         self._end_step(run_id=run_id, parent_run_id=parent_run_id, error=str(error))
 
+    def _handle_llm_new_token(self, token: str, **kwargs: Any) -> Any:
+        """Common logic for LLM new token."""
+        # Safely check for chunk and usage_metadata
+        chunk = kwargs.get("chunk")
+        if (
+            chunk
+            and hasattr(chunk, "message")
+            and hasattr(chunk.message, "usage_metadata")
+        ):
+            usage = chunk.message.usage_metadata
+
+            # Only proceed if usage is not None
+            if usage:
+                # Extract run_id from kwargs (should be provided by LangChain)
+                run_id = kwargs.get("run_id")
+                if run_id and run_id in self.steps:
+                    # Convert usage to the expected format like _extract_token_info does
+                    token_info = {
+                        "prompt_tokens": usage.get("input_tokens", 0),
+                        "completion_tokens": usage.get("output_tokens", 0),
+                        "tokens": usage.get("total_tokens", 0),
+                    }
+
+                    # Update the step with token usage information
+                    step = self.steps[run_id]
+                    if isinstance(step, steps.ChatCompletionStep):
+                        step.log(**token_info)
+        return
+
 
 class OpenlayerHandler(OpenlayerHandlerMixin, BaseCallbackHandlerClass):  # type: ignore[misc]
     """LangChain callback handler that logs to Openlayer."""
@@ -754,11 +834,16 @@ def __init__(
         ignore_retriever=False,
         ignore_agent=False,
         inference_id: Optional[Any] = None,
+        metadata_transformer: Optional[
+            Callable[[Dict[str, Any]], Dict[str, Any]]
+        ] = None,
         **kwargs: Any,
     ) -> None:
-        # Add inference_id to kwargs so it gets passed to mixin
+        # Add both inference_id and metadata_transformer to kwargs so they get passed to the mixin
         if inference_id is not None:
             kwargs["inference_id"] = inference_id
+        if metadata_transformer is not None:
+            kwargs["metadata_transformer"] = metadata_transformer
         super().__init__(**kwargs)
         # Store the ignore flags as instance variables
         self._ignore_llm = ignore_llm
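Putting it together, both new constructor arguments are plain pass-throughs to the mixin. A sketch reusing the hypothetical redact_metadata from above:

from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

handler = OpenlayerHandler(
    inference_id="run-2024-001",  # optional; placeholder value
    metadata_transformer=redact_metadata,
)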
@@ -822,7 +907,7 @@ def on_llm_error(
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
-        pass
+        return self._handle_llm_new_token(token, **kwargs)
 
     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
@@ -900,11 +985,16 @@ def __init__(
         ignore_retriever=False,
         ignore_agent=False,
         inference_id: Optional[Any] = None,
+        metadata_transformer: Optional[
+            Callable[[Dict[str, Any]], Dict[str, Any]]
+        ] = None,
         **kwargs: Any,
     ) -> None:
-        # Add inference_id to kwargs so it gets passed to mixin
+        # Add both inference_id and metadata_transformer to kwargs so they get passed to the mixin
         if inference_id is not None:
             kwargs["inference_id"] = inference_id
+        if metadata_transformer is not None:
+            kwargs["metadata_transformer"] = metadata_transformer
         super().__init__(**kwargs)
         # Store the ignore flags as instance variables
         self._ignore_llm = ignore_llm
@@ -1106,7 +1196,7 @@ async def on_llm_error(
         return self._handle_llm_error(error, **kwargs)
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
-        pass
+        return self._handle_llm_new_token(token, **kwargs)
 
     async def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any