1- import importlib
21import random
32import string
43import time
54from time import sleep
65from typing import Any , Dict , Literal
7- from uuid import uuid4
86
97import pytest
108from langchain .messages import HumanMessage , SystemMessage
1614from langgraph .checkpoint .memory import MemorySaver
1715from langgraph .graph import END , START , MessagesState , StateGraph
1816from langgraph .prebuilt import ToolNode
19- from opentelemetry import trace as otel_trace
2017from pydantic import BaseModel , Field
2118
2219from langfuse ._client .client import Langfuse
2320from langfuse .langchain import CallbackHandler
2421from tests .utils import create_uuid , encode_file_to_base64 , get_api
2522
2623
class _FakeLangchainObservation:
    """Minimal stand-in for a Langfuse observation.

    Records update()/end() calls and delegates child-observation creation
    back to the fake client that created it, so tests can inspect the tree
    of started observations.
    """

    def __init__(self, recorder, **kwargs):
        self._recorder = recorder
        # Non-recording span: keeps OTEL context propagation a no-op in tests.
        self._otel_span = otel_trace.NonRecordingSpan(otel_trace.INVALID_SPAN_CONTEXT)
        self.trace_id = "test-trace-id"
        # Mirror the subset of constructor kwargs the tests assert against;
        # absent keys default to None, matching dict.get semantics.
        for field in ("metadata", "name", "input", "as_type"):
            setattr(self, field, kwargs.get(field))
        self.updates = []
        self.ended = False

    def start_observation(self, **kwargs):
        # Children are created by the owning client so it can record them all.
        return self._recorder.start_observation(**kwargs)

    def update(self, **kwargs):
        self.updates.append(kwargs)
        return self

    def end(self):
        self.ended = True
        return self
49-
50-
class _FakeLangchainClient:
    """Fake Langfuse client that remembers every observation it starts."""

    def __init__(self):
        # All observations ever started through this client, in creation order.
        self.started_observations = []

    def start_observation(self, **kwargs):
        created = _FakeLangchainObservation(self, **kwargs)
        self.started_observations.append(created)
        return created
59-
60-
def _patch_langchain_client(monkeypatch, fake_client):
    """Redirect CallbackHandler's get_client() lookup to *fake_client*.

    Patches ``get_client`` inside the ``langfuse.langchain.CallbackHandler``
    module so a freshly constructed handler records into the fake client
    instead of a real Langfuse instance.
    """
    target_module = importlib.import_module(
        "langfuse.langchain.CallbackHandler"
    )

    def _fake_get_client(public_key=None):
        return fake_client

    monkeypatch.setattr(target_module, "get_client", _fake_get_client)
70-
71-
def test_root_langchain_chain_sets_is_langchain_root_metadata(monkeypatch):
    """The root chain run — and only the root — gets is_langchain_root metadata.

    Starts a root chain with a nested child chain, ends both, and verifies
    the root observation's metadata carries the flag (plus tags) while the
    child's metadata is passed through untouched.
    """
    fake_client = _FakeLangchainClient()
    _patch_langchain_client(monkeypatch, fake_client)
    handler = CallbackHandler()

    root_id = uuid4()
    child_id = uuid4()

    # Root run: no parent_run_id, so the handler should mark it as root.
    handler.on_chain_start(
        serialized={"name": "RootChain"},
        inputs={"question": "hello"},
        run_id=root_id,
        tags=["root-tag"],
        metadata={"foo": "bar"},
    )
    # Child run: has a parent, so its metadata must stay as provided.
    handler.on_chain_start(
        serialized={"name": "ChildChain"},
        inputs={"question": "child"},
        run_id=child_id,
        parent_run_id=root_id,
        metadata={"child": "metadata"},
    )
    handler.on_chain_end(
        {"answer": "child"},
        run_id=child_id,
        parent_run_id=root_id,
    )
    handler.on_chain_end({"answer": "root"}, run_id=root_id)

    root_obs, child_obs = fake_client.started_observations

    assert root_obs.metadata == {
        "foo": "bar",
        "tags": ["root-tag"],
        "is_langchain_root": True,
    }
    assert child_obs.metadata == {"child": "metadata"}
109-
110-
def test_root_langchain_llm_sets_is_langchain_root_metadata(monkeypatch):
    """A root-level LLM run (no parent) is tagged with is_langchain_root."""
    fake_client = _FakeLangchainClient()
    _patch_langchain_client(monkeypatch, fake_client)
    handler = CallbackHandler()

    run_id = uuid4()

    # LLM started without a parent_run_id — the handler treats it as root.
    handler.on_llm_start(
        serialized={"name": "ChatOpenAI", "id": ["langchain", "ChatOpenAI"]},
        prompts=["hello"],
        run_id=run_id,
        invocation_params={"model_name": "gpt-4o-mini"},
    )
    # NOTE(review): relies on the private _detach_observation API to finalize
    # the run without a full on_llm_end payload.
    handler._detach_observation(run_id)

    first_observation = fake_client.started_observations[0]
    assert first_observation.metadata == {"is_langchain_root": True}
127-
128-
12924def test_callback_generated_from_trace_chat ():
13025 langfuse = Langfuse ()
13126
@@ -172,6 +67,7 @@ def test_callback_generated_from_trace_chat():
17267 assert langchain_generation_span .input != ""
17368 assert langchain_generation_span .output is not None
17469 assert langchain_generation_span .output != ""
70+ assert langchain_generation_span .metadata ["is_langchain_root" ] is True
17571
17672
17773def test_callback_generated_from_lcel_chain ():
@@ -208,6 +104,11 @@ def test_callback_generated_from_lcel_chain():
208104 trace .observations ,
209105 )
210106 )[0 ]
107+ langchain_root_spans = [
108+ observation
109+ for observation in trace .observations
110+ if observation .metadata and observation .metadata .get ("is_langchain_root" )
111+ ]
211112
212113 assert langchain_generation_span .usage_details ["input" ] > 1
213114 assert langchain_generation_span .usage_details ["output" ] > 0
@@ -216,6 +117,9 @@ def test_callback_generated_from_lcel_chain():
216117 assert langchain_generation_span .input != ""
217118 assert langchain_generation_span .output is not None
218119 assert langchain_generation_span .output != ""
120+ assert len (langchain_root_spans ) == 1
121+ assert langchain_root_spans [0 ].type == "CHAIN"
122+ assert langchain_root_spans [0 ].metadata ["is_langchain_root" ] is True
219123
220124
221125@pytest .mark .skip (reason = "Flaky" )
0 commit comments