Skip to content

Commit 6f9eaf2

Browse files
authored
feat(langchain): mark LangChain root observations in metadata (#1604)
1 parent b680136 commit 6f9eaf2

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,28 @@ def _parse_langfuse_trace_attributes(
303303

304304
return attributes
305305

306+
def _get_langchain_observation_metadata(
    self,
    *,
    parent_run_id: Optional[UUID],
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    keep_langfuse_trace_attributes: bool = False,
) -> Optional[Dict[str, Any]]:
    """Build observation metadata and flag runs that have no parent.

    The tags and metadata are first combined by the private
    ``__join_tags_and_metadata`` helper. For a run with a parent, that
    result is returned as-is. For a run without a parent
    (``parent_run_id is None``), a copy of the result is returned with the
    key ``"is_langchain_root"`` set to ``True``.

    Args:
        parent_run_id: Identifier of the parent run, or ``None`` when the
            run has no parent.
        tags: Optional tags forwarded to the join helper.
        metadata: Optional metadata forwarded to the join helper.
        keep_langfuse_trace_attributes: Forwarded unchanged to the join
            helper.

    Returns:
        The joined metadata (possibly ``None``, per the return annotation)
        for runs with a parent; otherwise a dict that always contains
        ``"is_langchain_root": True``.
    """
    joined = self.__join_tags_and_metadata(
        tags=tags,
        metadata=metadata,
        keep_langfuse_trace_attributes=keep_langfuse_trace_attributes,
    )

    has_parent = parent_run_id is not None
    if has_parent:
        return joined

    # Work on a copy so the object returned by the join helper is never
    # mutated; an empty or missing result starts from a fresh dict.
    if joined:
        marked = joined.copy()
    else:
        marked = {}

    marked["is_langchain_root"] = True
    return marked
327+
306328
def on_chain_start(
307329
self,
308330
serialized: Optional[Dict[str, Any]],
@@ -325,7 +347,11 @@ def on_chain_start(
325347
)
326348

327349
span_name = self.get_langchain_run_name(serialized, **kwargs)
328-
span_metadata = self.__join_tags_and_metadata(tags, metadata)
350+
span_metadata = self._get_langchain_observation_metadata(
351+
parent_run_id=parent_run_id,
352+
tags=tags,
353+
metadata=metadata,
354+
)
329355
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
330356

331357
observation_type = self._get_observation_type_from_serialized(
@@ -690,7 +716,11 @@ def on_tool_start(
690716
"on_tool_start", run_id, parent_run_id, input_str=input_str
691717
)
692718

693-
meta = self.__join_tags_and_metadata(tags, metadata)
719+
meta = self._get_langchain_observation_metadata(
720+
parent_run_id=parent_run_id,
721+
tags=tags,
722+
metadata=metadata,
723+
)
694724

695725
if not meta:
696726
meta = {}
@@ -734,7 +764,11 @@ def on_retriever_start(
734764
"on_retriever_start", run_id, parent_run_id, query=query
735765
)
736766
span_name = self.get_langchain_run_name(serialized, **kwargs)
737-
span_metadata = self.__join_tags_and_metadata(tags, metadata)
767+
span_metadata = self._get_langchain_observation_metadata(
768+
parent_run_id=parent_run_id,
769+
tags=tags,
770+
metadata=metadata,
771+
)
738772
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
739773

740774
observation_type = self._get_observation_type_from_serialized(
@@ -865,9 +899,10 @@ def __on_llm_action(
865899
content = {
866900
"name": self.get_langchain_run_name(serialized, **kwargs),
867901
"input": prompts,
868-
"metadata": self.__join_tags_and_metadata(
869-
tags,
870-
metadata,
902+
"metadata": self._get_langchain_observation_metadata(
903+
parent_run_id=parent_run_id,
904+
tags=tags,
905+
metadata=metadata,
871906
# If llm is run isolated and outside chain, keep trace attributes
872907
keep_langfuse_trace_attributes=True
873908
if parent_run_id is None

tests/test_langchain.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_callback_generated_from_trace_chat():
6767
assert langchain_generation_span.input != ""
6868
assert langchain_generation_span.output is not None
6969
assert langchain_generation_span.output != ""
70+
assert langchain_generation_span.metadata["is_langchain_root"] is True
7071

7172

7273
def test_callback_generated_from_lcel_chain():
@@ -103,6 +104,11 @@ def test_callback_generated_from_lcel_chain():
103104
trace.observations,
104105
)
105106
)[0]
107+
langchain_root_spans = [
108+
observation
109+
for observation in trace.observations
110+
if observation.metadata and observation.metadata.get("is_langchain_root")
111+
]
106112

107113
assert langchain_generation_span.usage_details["input"] > 1
108114
assert langchain_generation_span.usage_details["output"] > 0
@@ -111,6 +117,9 @@ def test_callback_generated_from_lcel_chain():
111117
assert langchain_generation_span.input != ""
112118
assert langchain_generation_span.output is not None
113119
assert langchain_generation_span.output != ""
120+
assert len(langchain_root_spans) == 1
121+
assert langchain_root_spans[0].type == "CHAIN"
122+
assert langchain_root_spans[0].metadata["is_langchain_root"] is True
114123

115124

116125
@pytest.mark.skip(reason="Flaky")

0 commit comments

Comments
 (0)