Skip to content

Commit ccda750

Browse files
committed
simplify LangChain root metadata tests
1 parent 03e0633 commit ccda750

1 file changed

Lines changed: 9 additions & 105 deletions

File tree

tests/test_langchain.py

Lines changed: 9 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import importlib
21
import random
32
import string
43
import time
54
from time import sleep
65
from typing import Any, Dict, Literal
7-
from uuid import uuid4
86

97
import pytest
108
from langchain.messages import HumanMessage, SystemMessage
@@ -16,116 +14,13 @@
1614
from langgraph.checkpoint.memory import MemorySaver
1715
from langgraph.graph import END, START, MessagesState, StateGraph
1816
from langgraph.prebuilt import ToolNode
19-
from opentelemetry import trace as otel_trace
2017
from pydantic import BaseModel, Field
2118

2219
from langfuse._client.client import Langfuse
2320
from langfuse.langchain import CallbackHandler
2421
from tests.utils import create_uuid, encode_file_to_base64, get_api
2522

2623

27-
class _FakeLangchainObservation:
28-
def __init__(self, recorder, **kwargs):
29-
self._recorder = recorder
30-
self._otel_span = otel_trace.NonRecordingSpan(otel_trace.INVALID_SPAN_CONTEXT)
31-
self.trace_id = "test-trace-id"
32-
self.metadata = kwargs.get("metadata")
33-
self.name = kwargs.get("name")
34-
self.input = kwargs.get("input")
35-
self.as_type = kwargs.get("as_type")
36-
self.updates = []
37-
self.ended = False
38-
39-
def start_observation(self, **kwargs):
40-
return self._recorder.start_observation(**kwargs)
41-
42-
def update(self, **kwargs):
43-
self.updates.append(kwargs)
44-
return self
45-
46-
def end(self):
47-
self.ended = True
48-
return self
49-
50-
51-
class _FakeLangchainClient:
52-
def __init__(self):
53-
self.started_observations = []
54-
55-
def start_observation(self, **kwargs):
56-
observation = _FakeLangchainObservation(self, **kwargs)
57-
self.started_observations.append(observation)
58-
return observation
59-
60-
61-
def _patch_langchain_client(monkeypatch, fake_client):
62-
callback_handler_module = importlib.import_module(
63-
"langfuse.langchain.CallbackHandler"
64-
)
65-
monkeypatch.setattr(
66-
callback_handler_module,
67-
"get_client",
68-
lambda public_key=None: fake_client,
69-
)
70-
71-
72-
def test_root_langchain_chain_sets_is_langchain_root_metadata(monkeypatch):
73-
fake_client = _FakeLangchainClient()
74-
_patch_langchain_client(monkeypatch, fake_client)
75-
handler = CallbackHandler()
76-
77-
root_run_id = uuid4()
78-
child_run_id = uuid4()
79-
80-
handler.on_chain_start(
81-
serialized={"name": "RootChain"},
82-
inputs={"question": "hello"},
83-
run_id=root_run_id,
84-
tags=["root-tag"],
85-
metadata={"foo": "bar"},
86-
)
87-
handler.on_chain_start(
88-
serialized={"name": "ChildChain"},
89-
inputs={"question": "child"},
90-
run_id=child_run_id,
91-
parent_run_id=root_run_id,
92-
metadata={"child": "metadata"},
93-
)
94-
handler.on_chain_end(
95-
{"answer": "child"},
96-
run_id=child_run_id,
97-
parent_run_id=root_run_id,
98-
)
99-
handler.on_chain_end({"answer": "root"}, run_id=root_run_id)
100-
101-
root_observation, child_observation = fake_client.started_observations
102-
103-
assert root_observation.metadata == {
104-
"foo": "bar",
105-
"tags": ["root-tag"],
106-
"is_langchain_root": True,
107-
}
108-
assert child_observation.metadata == {"child": "metadata"}
109-
110-
111-
def test_root_langchain_llm_sets_is_langchain_root_metadata(monkeypatch):
112-
fake_client = _FakeLangchainClient()
113-
_patch_langchain_client(monkeypatch, fake_client)
114-
handler = CallbackHandler()
115-
116-
root_run_id = uuid4()
117-
118-
handler.on_llm_start(
119-
serialized={"name": "ChatOpenAI", "id": ["langchain", "ChatOpenAI"]},
120-
prompts=["hello"],
121-
run_id=root_run_id,
122-
invocation_params={"model_name": "gpt-4o-mini"},
123-
)
124-
handler._detach_observation(root_run_id)
125-
126-
assert fake_client.started_observations[0].metadata == {"is_langchain_root": True}
127-
128-
12924
def test_callback_generated_from_trace_chat():
13025
langfuse = Langfuse()
13126

@@ -172,6 +67,7 @@ def test_callback_generated_from_trace_chat():
17267
assert langchain_generation_span.input != ""
17368
assert langchain_generation_span.output is not None
17469
assert langchain_generation_span.output != ""
70+
assert langchain_generation_span.metadata["is_langchain_root"] is True
17571

17672

17773
def test_callback_generated_from_lcel_chain():
@@ -208,6 +104,11 @@ def test_callback_generated_from_lcel_chain():
208104
trace.observations,
209105
)
210106
)[0]
107+
langchain_root_spans = [
108+
observation
109+
for observation in trace.observations
110+
if observation.metadata and observation.metadata.get("is_langchain_root")
111+
]
211112

212113
assert langchain_generation_span.usage_details["input"] > 1
213114
assert langchain_generation_span.usage_details["output"] > 0
@@ -216,6 +117,9 @@ def test_callback_generated_from_lcel_chain():
216117
assert langchain_generation_span.input != ""
217118
assert langchain_generation_span.output is not None
218119
assert langchain_generation_span.output != ""
120+
assert len(langchain_root_spans) == 1
121+
assert langchain_root_spans[0].type == "CHAIN"
122+
assert langchain_root_spans[0].metadata["is_langchain_root"] is True
219123

220124

221125
@pytest.mark.skip(reason="Flaky")

0 commit comments

Comments
 (0)