Skip to content

Commit 41fadbf

Browse files
authored
feat(langchain): backport Langchain root metadata flag to v3-stable (#1605)
mark LangChain roots in metadata
1 parent cfd3ac1 commit 41fadbf

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,28 @@ def _parse_langfuse_trace_attributes_from_metadata(
292292

293293
return attributes
294294

295+
def _get_langchain_observation_metadata(
296+
self,
297+
*,
298+
parent_run_id: Optional[UUID],
299+
tags: Optional[List[str]] = None,
300+
metadata: Optional[Dict[str, Any]] = None,
301+
keep_langfuse_trace_attributes: bool = False,
302+
) -> Optional[Dict[str, Any]]:
303+
observation_metadata = self.__join_tags_and_metadata(
304+
tags=tags,
305+
metadata=metadata,
306+
keep_langfuse_trace_attributes=keep_langfuse_trace_attributes,
307+
)
308+
309+
if parent_run_id is not None:
310+
return observation_metadata
311+
312+
root_metadata = observation_metadata.copy() if observation_metadata else {}
313+
root_metadata["is_langchain_root"] = True
314+
315+
return root_metadata
316+
295317
def on_chain_start(
296318
self,
297319
serialized: Optional[Dict[str, Any]],
@@ -314,7 +336,11 @@ def on_chain_start(
314336
)
315337

316338
span_name = self.get_langchain_run_name(serialized, **kwargs)
317-
span_metadata = self.__join_tags_and_metadata(tags, metadata)
339+
span_metadata = self._get_langchain_observation_metadata(
340+
parent_run_id=parent_run_id,
341+
tags=tags,
342+
metadata=metadata,
343+
)
318344
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
319345

320346
observation_type = self._get_observation_type_from_serialized(
@@ -356,7 +382,9 @@ def on_chain_start(
356382
{
357383
"input": inputs,
358384
"name": span_name,
359-
"metadata": span_metadata,
385+
"metadata": self.__join_tags_and_metadata(
386+
tags, metadata
387+
),
360388
},
361389
)
362390
if self.update_trace
@@ -682,7 +710,11 @@ def on_tool_start(
682710
"on_tool_start", run_id, parent_run_id, input_str=input_str
683711
)
684712

685-
meta = self.__join_tags_and_metadata(tags, metadata)
713+
meta = self._get_langchain_observation_metadata(
714+
parent_run_id=parent_run_id,
715+
tags=tags,
716+
metadata=metadata,
717+
)
686718

687719
if not meta:
688720
meta = {}
@@ -726,7 +758,11 @@ def on_retriever_start(
726758
"on_retriever_start", run_id, parent_run_id, query=query
727759
)
728760
span_name = self.get_langchain_run_name(serialized, **kwargs)
729-
span_metadata = self.__join_tags_and_metadata(tags, metadata)
761+
span_metadata = self._get_langchain_observation_metadata(
762+
parent_run_id=parent_run_id,
763+
tags=tags,
764+
metadata=metadata,
765+
)
730766
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
731767

732768
observation_type = self._get_observation_type_from_serialized(
@@ -857,9 +893,10 @@ def __on_llm_action(
857893
content = {
858894
"name": self.get_langchain_run_name(serialized, **kwargs),
859895
"input": prompts,
860-
"metadata": self.__join_tags_and_metadata(
861-
tags,
862-
metadata,
896+
"metadata": self._get_langchain_observation_metadata(
897+
parent_run_id=parent_run_id,
898+
tags=tags,
899+
metadata=metadata,
863900
# If llm is run isolated and outside chain, keep trace attributes
864901
keep_langfuse_trace_attributes=True
865902
if parent_run_id is None

tests/test_langchain.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_callback_generated_from_trace_chat():
6767
assert langchain_generation_span.input != ""
6868
assert langchain_generation_span.output is not None
6969
assert langchain_generation_span.output != ""
70+
assert langchain_generation_span.metadata["is_langchain_root"] is True
7071

7172

7273
def test_callback_generated_from_lcel_chain():
@@ -103,6 +104,11 @@ def test_callback_generated_from_lcel_chain():
103104
trace.observations,
104105
)
105106
)[0]
107+
langchain_root_spans = [
108+
observation
109+
for observation in trace.observations
110+
if observation.metadata and observation.metadata.get("is_langchain_root")
111+
]
106112

107113
assert langchain_generation_span.usage_details["input"] > 1
108114
assert langchain_generation_span.usage_details["output"] > 0
@@ -111,6 +117,9 @@ def test_callback_generated_from_lcel_chain():
111117
assert langchain_generation_span.input != ""
112118
assert langchain_generation_span.output is not None
113119
assert langchain_generation_span.output != ""
120+
assert len(langchain_root_spans) == 1
121+
assert langchain_root_spans[0].type == "CHAIN"
122+
assert langchain_root_spans[0].metadata["is_langchain_root"] is True
114123

115124

116125
@pytest.mark.skip(reason="Flaky")

0 commit comments

Comments
 (0)