Skip to content

Commit e537231

Browse files
committed
langchain
1 parent b6b1b01 commit e537231

3 files changed

Lines changed: 113 additions & 20 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 111 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
import pydantic
44

55
from langfuse._client.get_client import get_client
6-
from langfuse._client.span import LangfuseGeneration, LangfuseSpan
6+
from langfuse._client.span import (
7+
LangfuseGeneration,
8+
LangfuseSpan,
9+
LangfuseAgent,
10+
LangfuseChain,
11+
LangfuseTool,
12+
LangfuseRetriever,
13+
LangfuseObservationWrapper,
14+
)
715
from langfuse.logger import langfuse_logger
816

917
try:
@@ -59,7 +67,17 @@ class LangchainCallbackHandler(LangchainBaseCallbackHandler):
5967
def __init__(self, *, public_key: Optional[str] = None) -> None:
6068
self.client = get_client(public_key=public_key)
6169

62-
self.runs: Dict[UUID, Union[LangfuseSpan, LangfuseGeneration]] = {}
70+
self.runs: Dict[
71+
UUID,
72+
Union[
73+
LangfuseSpan,
74+
LangfuseGeneration,
75+
LangfuseAgent,
76+
LangfuseChain,
77+
LangfuseTool,
78+
LangfuseRetriever,
79+
],
80+
] = {}
6381
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
6482
self.updated_completion_start_time_memo: Set[UUID] = set()
6583

@@ -87,6 +105,49 @@ def on_llm_new_token(
87105

88106
self.updated_completion_start_time_memo.add(run_id)
89107

108+
def _get_observation_type_from_serialized(
    self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any
) -> Literal["tool", "retriever", "generation", "agent", "chain", "span"]:
    """Determine the Langfuse observation type for a LangChain component.

    Tool, retriever and LLM callbacks map directly onto an observation
    type. Chain callbacks are inspected further: a chain is reported as an
    agent when the word "agent" appears (case-insensitively) in any segment
    of its serialized ``id`` or in its resolved run name.

    Args:
        serialized: LangChain's serialized component dict, or ``None``.
        callback_type: The kind of callback being handled
            (e.g. "chain", "tool", "retriever", "llm").
        **kwargs: Additional keyword arguments from the callback; forwarded
            to ``get_langchain_run_name`` when the run name is needed.

    Returns:
        "tool", "retriever" or "generation" for the matching callback
        types, "agent" or "chain" for chain callbacks, and "span" for any
        other callback type.
    """
    # Direct mappings based on callback type.
    if callback_type == "tool":
        return "tool"
    if callback_type == "retriever":
        return "retriever"
    if callback_type == "llm":
        return "generation"
    if callback_type != "chain":
        return "span"

    # Chain callbacks: look for an agent in the serialized class path.
    class_path = serialized.get("id") if serialized else None
    if isinstance(class_path, str):
        # Treat a bare string id as a single segment; iterating it would
        # compare "agent" against individual characters and never match.
        class_path = [class_path]
    if isinstance(class_path, (list, tuple)) and any(
        # Skip non-string segments instead of failing on ``.lower()``.
        isinstance(part, str) and "agent" in part.lower()
        for part in class_path
    ):
        return "agent"

    # Fall back to the run name for agent-related keywords.
    name = self.get_langchain_run_name(serialized, **kwargs)
    if "agent" in name.lower():
        return "agent"

    return "chain"
150+
90151
def get_langchain_run_name(
91152
self, serialized: Optional[Dict[str, Any]], **kwargs: Any
92153
) -> str:
@@ -196,9 +257,14 @@ def on_chain_start(
196257
span_metadata = self.__join_tags_and_metadata(tags, metadata)
197258
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
198259

260+
observation_type = self._get_observation_type_from_serialized(
261+
serialized, "chain", **kwargs
262+
)
263+
199264
if parent_run_id is None:
200-
span = self.client.start_span(
265+
span = self.client.start_observation(
201266
name=span_name,
267+
as_type=observation_type,
202268
metadata=span_metadata,
203269
input=inputs,
204270
level=cast(
@@ -212,9 +278,12 @@ def on_chain_start(
212278
self.runs[run_id] = span
213279
else:
214280
self.runs[run_id] = cast(
215-
LangfuseSpan, self.runs[parent_run_id]
216-
).start_span(
281+
# TODO: make this more precise (can be chain or agent here)
282+
LangfuseObservationWrapper,
283+
self.runs[parent_run_id],
284+
).start_observation(
217285
name=span_name,
286+
as_type=observation_type,
218287
metadata=span_metadata,
219288
input=inputs,
220289
level=cast(
@@ -442,8 +511,6 @@ def on_tool_start(
442511
"on_tool_start", run_id, parent_run_id, input_str=input_str
443512
)
444513

445-
if parent_run_id is None or parent_run_id not in self.runs:
446-
raise Exception("parent run not found")
447514
meta = self.__join_tags_and_metadata(tags, metadata)
448515

449516
if not meta:
@@ -453,13 +520,31 @@ def on_tool_start(
453520
{key: value for key, value in kwargs.items() if value is not None}
454521
)
455522

456-
self.runs[run_id] = cast(LangfuseSpan, self.runs[parent_run_id]).start_span(
457-
name=self.get_langchain_run_name(serialized, **kwargs),
458-
input=input_str,
459-
metadata=meta,
460-
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
523+
observation_type = self._get_observation_type_from_serialized(
524+
serialized, "tool", **kwargs
461525
)
462526

527+
if parent_run_id is None or parent_run_id not in self.runs:
528+
# Create root observation for direct tool calls
529+
self.runs[run_id] = self.client.start_observation(
530+
name=self.get_langchain_run_name(serialized, **kwargs),
531+
as_type=observation_type,
532+
input=input_str,
533+
metadata=meta,
534+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
535+
)
536+
else:
537+
# Create child observation for tools within chains/agents
538+
self.runs[run_id] = cast(
539+
LangfuseChain, self.runs[parent_run_id]
540+
).start_observation(
541+
name=self.get_langchain_run_name(serialized, **kwargs),
542+
as_type=observation_type,
543+
input=input_str,
544+
metadata=meta,
545+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
546+
)
547+
463548
except Exception as e:
464549
langfuse_logger.exception(e)
465550

@@ -482,9 +567,14 @@ def on_retriever_start(
482567
span_metadata = self.__join_tags_and_metadata(tags, metadata)
483568
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
484569

570+
observation_type = self._get_observation_type_from_serialized(
571+
serialized, "retriever", **kwargs
572+
)
573+
485574
if parent_run_id is None:
486-
self.runs[run_id] = self.client.start_span(
575+
self.runs[run_id] = self.client.start_observation(
487576
name=span_name,
577+
as_type=observation_type,
488578
metadata=span_metadata,
489579
input=query,
490580
level=cast(
@@ -494,9 +584,10 @@ def on_retriever_start(
494584
)
495585
else:
496586
self.runs[run_id] = cast(
497-
LangfuseSpan, self.runs[parent_run_id]
498-
).start_span(
587+
LangfuseRetriever, self.runs[parent_run_id]
588+
).start_observation(
499589
name=span_name,
590+
as_type=observation_type,
500591
input=query,
501592
metadata=span_metadata,
502593
level=cast(
@@ -625,10 +716,12 @@ def __on_llm_action(
625716

626717
if parent_run_id is not None and parent_run_id in self.runs:
627718
self.runs[run_id] = cast(
628-
LangfuseSpan, self.runs[parent_run_id]
629-
).start_generation(**content) # type: ignore
719+
LangfuseGeneration, self.runs[parent_run_id]
720+
).start_observation(as_type="generation", **content) # type: ignore
630721
else:
631-
self.runs[run_id] = self.client.start_generation(**content) # type: ignore
722+
self.runs[run_id] = self.client.start_observation(
723+
as_type="generation", **content
724+
) # type: ignore
632725

633726
self.last_trace_id = self.runs[run_id].trace_id
634727

tests/test_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def sorted_dependencies_from_trace(trace):
319319

320320
if len(sorted_observations) >= 2:
321321
assert sorted_observations[1].name == "RunnableSequence"
322-
assert sorted_observations[1].type == "SPAN"
322+
assert sorted_observations[1].type == "CHAIN"
323323
assert sorted_observations[1].input is not None
324324
assert sorted_observations[1].output is not None
325325
assert sorted_observations[1].input != ""

tests/test_langchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_callback_generated_from_trace_chain():
6060

6161
langchain_span = list(
6262
filter(
63-
lambda o: o.type == "SPAN" and o.name == "LLMChain",
63+
lambda o: o.type == "CHAIN" and o.name == "LLMChain",
6464
trace.observations,
6565
)
6666
)[0]

0 commit comments

Comments
 (0)