From 6dd8400b869d459c7316bc581f133c1799106130 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:36:29 +0200 Subject: [PATCH 1/2] mark LangChain roots in metadata --- langfuse/langchain/CallbackHandler.py | 47 ++++++++++-- tests/test_langchain.py | 105 ++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 6 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 5b5dfe691..8d2c8db90 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -303,6 +303,28 @@ def _parse_langfuse_trace_attributes( return attributes + def _get_langchain_observation_metadata( + self, + *, + parent_run_id: Optional[UUID], + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + keep_langfuse_trace_attributes: bool = False, + ) -> Optional[Dict[str, Any]]: + observation_metadata = self.__join_tags_and_metadata( + tags=tags, + metadata=metadata, + keep_langfuse_trace_attributes=keep_langfuse_trace_attributes, + ) + + if parent_run_id is not None: + return observation_metadata + + root_metadata = observation_metadata.copy() if observation_metadata else {} + root_metadata["is_langchain_root"] = True + + return root_metadata + def on_chain_start( self, serialized: Optional[Dict[str, Any]], @@ -325,7 +347,11 @@ def on_chain_start( ) span_name = self.get_langchain_run_name(serialized, **kwargs) - span_metadata = self.__join_tags_and_metadata(tags, metadata) + span_metadata = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + ) span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None observation_type = self._get_observation_type_from_serialized( @@ -690,7 +716,11 @@ def on_tool_start( "on_tool_start", run_id, parent_run_id, input_str=input_str ) - meta = self.__join_tags_and_metadata(tags, metadata) + meta = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + ) if not meta: meta = {} @@ -734,7 +764,11 @@ def on_retriever_start( "on_retriever_start", run_id, parent_run_id, query=query ) span_name = self.get_langchain_run_name(serialized, **kwargs) - span_metadata = self.__join_tags_and_metadata(tags, metadata) + span_metadata = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + ) span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None observation_type = self._get_observation_type_from_serialized( @@ -865,9 +899,10 @@ def __on_llm_action( content = { "name": self.get_langchain_run_name(serialized, **kwargs), "input": prompts, - "metadata": self.__join_tags_and_metadata( - tags, - metadata, + "metadata": self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, # If llm is run isolated and outside chain, keep trace attributes keep_langfuse_trace_attributes=True if parent_run_id is None diff --git a/tests/test_langchain.py b/tests/test_langchain.py index 6c9d3eb4d..a3d36845f 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -1,8 +1,10 @@ +import importlib import random import string import time from time import sleep from typing import Any, Dict, Literal +from uuid import uuid4 import pytest from langchain.messages import HumanMessage, SystemMessage @@ -14,6 +16,7 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.prebuilt import ToolNode +from opentelemetry import trace as otel_trace from pydantic import BaseModel, Field from langfuse._client.client import Langfuse @@ -21,6 +24,108 @@ from tests.utils import create_uuid, encode_file_to_base64, get_api +class _FakeLangchainObservation: + def __init__(self, recorder, **kwargs): + self._recorder = recorder + self._otel_span = otel_trace.NonRecordingSpan(otel_trace.INVALID_SPAN_CONTEXT) + self.trace_id = "test-trace-id" + self.metadata = kwargs.get("metadata") + self.name = kwargs.get("name") + self.input = kwargs.get("input") + self.as_type = kwargs.get("as_type") + self.updates = [] + self.ended = False + + def start_observation(self, **kwargs): + return self._recorder.start_observation(**kwargs) + + def update(self, **kwargs): + self.updates.append(kwargs) + return self + + def end(self): + self.ended = True + return self + + +class _FakeLangchainClient: + def __init__(self): + self.started_observations = [] + + def start_observation(self, **kwargs): + observation = _FakeLangchainObservation(self, **kwargs) + self.started_observations.append(observation) + return observation + + +def _patch_langchain_client(monkeypatch, fake_client): + callback_handler_module = importlib.import_module( + "langfuse.langchain.CallbackHandler" + ) + monkeypatch.setattr( + callback_handler_module, + "get_client", + lambda public_key=None: fake_client, + ) + + +def test_root_langchain_chain_sets_is_langchain_root_metadata(monkeypatch): + fake_client = _FakeLangchainClient() + _patch_langchain_client(monkeypatch, fake_client) + handler = CallbackHandler() + + root_run_id = uuid4() + child_run_id = uuid4() + + handler.on_chain_start( + serialized={"name": "RootChain"}, + inputs={"question": "hello"}, + run_id=root_run_id, + tags=["root-tag"], + metadata={"foo": "bar"}, + ) + handler.on_chain_start( + serialized={"name": "ChildChain"}, + inputs={"question": "child"}, + run_id=child_run_id, + parent_run_id=root_run_id, + metadata={"child": "metadata"}, + ) + handler.on_chain_end( + {"answer": "child"}, + run_id=child_run_id, + parent_run_id=root_run_id, + ) + handler.on_chain_end({"answer": "root"}, run_id=root_run_id) + + root_observation, child_observation = fake_client.started_observations + + assert root_observation.metadata == { + "foo": "bar", + "tags": ["root-tag"], + "is_langchain_root": True, + } + assert child_observation.metadata == {"child": "metadata"} + + +def test_root_langchain_llm_sets_is_langchain_root_metadata(monkeypatch): + fake_client = _FakeLangchainClient() + _patch_langchain_client(monkeypatch, fake_client) + handler = CallbackHandler() + + root_run_id = uuid4() + + handler.on_llm_start( + serialized={"name": "ChatOpenAI", "id": ["langchain", "ChatOpenAI"]}, + prompts=["hello"], + run_id=root_run_id, + invocation_params={"model_name": "gpt-4o-mini"}, + ) + handler._detach_observation(root_run_id) + + assert fake_client.started_observations[0].metadata == {"is_langchain_root": True} + + def test_callback_generated_from_trace_chat(): langfuse = Langfuse() From ccda75009f6688fc8742ac67d7c05afbec92496f Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:45:00 +0200 Subject: [PATCH 2/2] simplify LangChain root metadata tests --- tests/test_langchain.py | 114 ++++------------------------------------ 1 file changed, 9 insertions(+), 105 deletions(-) diff --git a/tests/test_langchain.py b/tests/test_langchain.py index a3d36845f..fa2bcfddb 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -1,10 +1,8 @@ -import importlib import random import string import time from time import sleep from typing import Any, Dict, Literal -from uuid import uuid4 import pytest from langchain.messages import HumanMessage, SystemMessage @@ -16,7 +14,6 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.prebuilt import ToolNode -from opentelemetry import trace as otel_trace from pydantic import BaseModel, Field from langfuse._client.client import Langfuse @@ -24,108 +21,6 @@ from tests.utils import create_uuid, encode_file_to_base64, get_api -class _FakeLangchainObservation: - def __init__(self, recorder, **kwargs): - self._recorder = recorder - self._otel_span = otel_trace.NonRecordingSpan(otel_trace.INVALID_SPAN_CONTEXT) - self.trace_id = "test-trace-id" - self.metadata = kwargs.get("metadata") - self.name = kwargs.get("name") - self.input = kwargs.get("input") - self.as_type = kwargs.get("as_type") - self.updates = [] - self.ended = False - - def start_observation(self, **kwargs): - return self._recorder.start_observation(**kwargs) - - def update(self, **kwargs): - self.updates.append(kwargs) - return self - - def end(self): - self.ended = True - return self - - -class _FakeLangchainClient: - def __init__(self): - self.started_observations = [] - - def start_observation(self, **kwargs): - observation = _FakeLangchainObservation(self, **kwargs) - self.started_observations.append(observation) - return observation - - -def _patch_langchain_client(monkeypatch, fake_client): - callback_handler_module = importlib.import_module( - "langfuse.langchain.CallbackHandler" - ) - monkeypatch.setattr( - callback_handler_module, - "get_client", - lambda public_key=None: fake_client, - ) - - -def test_root_langchain_chain_sets_is_langchain_root_metadata(monkeypatch): - fake_client = _FakeLangchainClient() - _patch_langchain_client(monkeypatch, fake_client) - handler = CallbackHandler() - - root_run_id = uuid4() - child_run_id = uuid4() - - handler.on_chain_start( - serialized={"name": "RootChain"}, - inputs={"question": "hello"}, - run_id=root_run_id, - tags=["root-tag"], - metadata={"foo": "bar"}, - ) - handler.on_chain_start( - serialized={"name": "ChildChain"}, - inputs={"question": "child"}, - run_id=child_run_id, - parent_run_id=root_run_id, - metadata={"child": "metadata"}, - ) - handler.on_chain_end( - {"answer": "child"}, - run_id=child_run_id, - parent_run_id=root_run_id, - ) - handler.on_chain_end({"answer": "root"}, run_id=root_run_id) - - root_observation, child_observation = fake_client.started_observations - - assert root_observation.metadata == { - "foo": "bar", - "tags": ["root-tag"], - "is_langchain_root": True, - } - assert child_observation.metadata == {"child": "metadata"} - - -def test_root_langchain_llm_sets_is_langchain_root_metadata(monkeypatch): - fake_client = _FakeLangchainClient() - _patch_langchain_client(monkeypatch, fake_client) - handler = CallbackHandler() - - root_run_id = uuid4() - - handler.on_llm_start( - serialized={"name": "ChatOpenAI", "id": ["langchain", "ChatOpenAI"]}, - prompts=["hello"], - run_id=root_run_id, - invocation_params={"model_name": "gpt-4o-mini"}, - ) - handler._detach_observation(root_run_id) - - assert fake_client.started_observations[0].metadata == {"is_langchain_root": True} - - def test_callback_generated_from_trace_chat(): langfuse = Langfuse() @@ -172,6 +67,7 @@ def test_callback_generated_from_trace_chat(): assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None assert langchain_generation_span.output != "" + assert langchain_generation_span.metadata["is_langchain_root"] is True def test_callback_generated_from_lcel_chain(): @@ -208,6 +104,11 @@ def test_callback_generated_from_lcel_chain(): trace.observations, ) )[0] + langchain_root_spans = [ + observation + for observation in trace.observations + if observation.metadata and observation.metadata.get("is_langchain_root") + ] assert langchain_generation_span.usage_details["input"] > 1 assert langchain_generation_span.usage_details["output"] > 0 @@ -216,6 +117,9 @@ def test_callback_generated_from_lcel_chain(): assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None assert langchain_generation_span.output != "" + assert len(langchain_root_spans) == 1 + assert langchain_root_spans[0].type == "CHAIN" + assert langchain_root_spans[0].metadata["is_langchain_root"] is True @pytest.mark.skip(reason="Flaky")