Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions langfuse/langchain/CallbackHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,28 @@ def _parse_langfuse_trace_attributes_from_metadata(

return attributes

def _get_langchain_observation_metadata(
self,
*,
parent_run_id: Optional[UUID],
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
keep_langfuse_trace_attributes: bool = False,
) -> Optional[Dict[str, Any]]:
observation_metadata = self.__join_tags_and_metadata(
tags=tags,
metadata=metadata,
keep_langfuse_trace_attributes=keep_langfuse_trace_attributes,
)

if parent_run_id is not None:
return observation_metadata

root_metadata = observation_metadata.copy() if observation_metadata else {}
root_metadata["is_langchain_root"] = True

return root_metadata

def on_chain_start(
self,
serialized: Optional[Dict[str, Any]],
Expand All @@ -314,7 +336,11 @@ def on_chain_start(
)

span_name = self.get_langchain_run_name(serialized, **kwargs)
span_metadata = self.__join_tags_and_metadata(tags, metadata)
span_metadata = self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
)
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None

observation_type = self._get_observation_type_from_serialized(
Expand Down Expand Up @@ -356,7 +382,9 @@ def on_chain_start(
{
"input": inputs,
"name": span_name,
"metadata": span_metadata,
"metadata": self.__join_tags_and_metadata(
tags, metadata
),
},
)
if self.update_trace
Expand Down Expand Up @@ -682,7 +710,11 @@ def on_tool_start(
"on_tool_start", run_id, parent_run_id, input_str=input_str
)

meta = self.__join_tags_and_metadata(tags, metadata)
meta = self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
)

if not meta:
meta = {}
Expand Down Expand Up @@ -726,7 +758,11 @@ def on_retriever_start(
"on_retriever_start", run_id, parent_run_id, query=query
)
span_name = self.get_langchain_run_name(serialized, **kwargs)
span_metadata = self.__join_tags_and_metadata(tags, metadata)
span_metadata = self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
)
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None

observation_type = self._get_observation_type_from_serialized(
Expand Down Expand Up @@ -857,9 +893,10 @@ def __on_llm_action(
content = {
"name": self.get_langchain_run_name(serialized, **kwargs),
"input": prompts,
"metadata": self.__join_tags_and_metadata(
tags,
metadata,
"metadata": self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
# If llm is run isolated and outside chain, keep trace attributes
keep_langfuse_trace_attributes=True
if parent_run_id is None
Expand Down
9 changes: 9 additions & 0 deletions tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_callback_generated_from_trace_chat():
assert langchain_generation_span.input != ""
assert langchain_generation_span.output is not None
assert langchain_generation_span.output != ""
assert langchain_generation_span.metadata["is_langchain_root"] is True


def test_callback_generated_from_lcel_chain():
Expand Down Expand Up @@ -103,6 +104,11 @@ def test_callback_generated_from_lcel_chain():
trace.observations,
)
)[0]
langchain_root_spans = [
observation
for observation in trace.observations
if observation.metadata and observation.metadata.get("is_langchain_root")
]

assert langchain_generation_span.usage_details["input"] > 1
assert langchain_generation_span.usage_details["output"] > 0
Expand All @@ -111,6 +117,9 @@ def test_callback_generated_from_lcel_chain():
assert langchain_generation_span.input != ""
assert langchain_generation_span.output is not None
assert langchain_generation_span.output != ""
assert len(langchain_root_spans) == 1
assert langchain_root_spans[0].type == "CHAIN"
assert langchain_root_spans[0].metadata["is_langchain_root"] is True


@pytest.mark.skip(reason="Flaky")
Expand Down
Loading