33import pydantic
44
55from langfuse ._client .get_client import get_client
6- from langfuse ._client .span import LangfuseGeneration , LangfuseSpan
6+ from langfuse ._client .span import (
7+ LangfuseGeneration ,
8+ LangfuseSpan ,
9+ LangfuseAgent ,
10+ LangfuseChain ,
11+ LangfuseTool ,
12+ LangfuseRetriever ,
13+ LangfuseObservationWrapper ,
14+ )
715from langfuse .logger import langfuse_logger
816
917try :
@@ -59,7 +67,17 @@ class LangchainCallbackHandler(LangchainBaseCallbackHandler):
5967 def __init__ (self , * , public_key : Optional [str ] = None ) -> None :
6068 self .client = get_client (public_key = public_key )
6169
62- self .runs : Dict [UUID , Union [LangfuseSpan , LangfuseGeneration ]] = {}
70+ self .runs : Dict [
71+ UUID ,
72+ Union [
73+ LangfuseSpan ,
74+ LangfuseGeneration ,
75+ LangfuseAgent ,
76+ LangfuseChain ,
77+ LangfuseTool ,
78+ LangfuseRetriever ,
79+ ],
80+ ] = {}
6381 self .prompt_to_parent_run_map : Dict [UUID , Any ] = {}
6482 self .updated_completion_start_time_memo : Set [UUID ] = set ()
6583
@@ -87,6 +105,49 @@ def on_llm_new_token(
87105
88106 self .updated_completion_start_time_memo .add (run_id )
89107
108+ def _get_observation_type_from_serialized (
109+ self , serialized : Optional [Dict [str , Any ]], callback_type : str , ** kwargs : Any
110+ ) -> Union [
111+ Literal ["tool" ],
112+ Literal ["retriever" ],
113+ Literal ["generation" ],
114+ Literal ["agent" ],
115+ Literal ["chain" ],
116+ Literal ["span" ],
117+ ]:
118+ """Determine Langfuse observation type from LangChain component.
119+
120+ Args:
121+ serialized: LangChain's serialized component dict
122+ callback_type: The type of callback (e.g., "chain", "tool", "retriever", "llm")
123+ **kwargs: Additional keyword arguments from the callback
124+
125+ Returns:
126+ The appropriate Langfuse observation type string
127+ """
128+ # Direct mappings based on callback type
129+ if callback_type == "tool" :
130+ return "tool"
131+ elif callback_type == "retriever" :
132+ return "retriever"
133+ elif callback_type == "llm" :
134+ return "generation"
135+ elif callback_type == "chain" :
136+ # Detect if it's an agent by examining class path or name
137+ if serialized and "id" in serialized :
138+ class_path = serialized ["id" ]
139+ if any ("agent" in part .lower () for part in class_path ):
140+ return "agent"
141+
142+ # Check name for agent-related keywords
143+ name = self .get_langchain_run_name (serialized , ** kwargs )
144+ if "agent" in name .lower ():
145+ return "agent"
146+
147+ return "chain"
148+
149+ return "span"
150+
90151 def get_langchain_run_name (
91152 self , serialized : Optional [Dict [str , Any ]], ** kwargs : Any
92153 ) -> str :
@@ -196,9 +257,14 @@ def on_chain_start(
196257 span_metadata = self .__join_tags_and_metadata (tags , metadata )
197258 span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
198259
260+ observation_type = self ._get_observation_type_from_serialized (
261+ serialized , "chain" , ** kwargs
262+ )
263+
199264 if parent_run_id is None :
200- span = self .client .start_span (
265+ span = self .client .start_observation (
201266 name = span_name ,
267+ as_type = observation_type ,
202268 metadata = span_metadata ,
203269 input = inputs ,
204270 level = cast (
@@ -212,9 +278,12 @@ def on_chain_start(
212278 self .runs [run_id ] = span
213279 else :
214280 self .runs [run_id ] = cast (
215- LangfuseSpan , self .runs [parent_run_id ]
216- ).start_span (
281+ # TODO: make this more precise (can be chain or agent here)
282+ LangfuseObservationWrapper ,
283+ self .runs [parent_run_id ],
284+ ).start_observation (
217285 name = span_name ,
286+ as_type = observation_type ,
218287 metadata = span_metadata ,
219288 input = inputs ,
220289 level = cast (
@@ -442,8 +511,6 @@ def on_tool_start(
442511 "on_tool_start" , run_id , parent_run_id , input_str = input_str
443512 )
444513
445- if parent_run_id is None or parent_run_id not in self .runs :
446- raise Exception ("parent run not found" )
447514 meta = self .__join_tags_and_metadata (tags , metadata )
448515
449516 if not meta :
@@ -453,13 +520,31 @@ def on_tool_start(
453520 {key : value for key , value in kwargs .items () if value is not None }
454521 )
455522
456- self .runs [run_id ] = cast (LangfuseSpan , self .runs [parent_run_id ]).start_span (
457- name = self .get_langchain_run_name (serialized , ** kwargs ),
458- input = input_str ,
459- metadata = meta ,
460- level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
523+ observation_type = self ._get_observation_type_from_serialized (
524+ serialized , "tool" , ** kwargs
461525 )
462526
527+ if parent_run_id is None or parent_run_id not in self .runs :
528+ # Create root observation for direct tool calls
529+ self .runs [run_id ] = self .client .start_observation (
530+ name = self .get_langchain_run_name (serialized , ** kwargs ),
531+ as_type = observation_type ,
532+ input = input_str ,
533+ metadata = meta ,
534+ level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
535+ )
536+ else :
537+ # Create child observation for tools within chains/agents
538+ self .runs [run_id ] = cast (
539+ LangfuseChain , self .runs [parent_run_id ]
540+ ).start_observation (
541+ name = self .get_langchain_run_name (serialized , ** kwargs ),
542+ as_type = observation_type ,
543+ input = input_str ,
544+ metadata = meta ,
545+ level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None ,
546+ )
547+
463548 except Exception as e :
464549 langfuse_logger .exception (e )
465550
@@ -482,9 +567,14 @@ def on_retriever_start(
482567 span_metadata = self .__join_tags_and_metadata (tags , metadata )
483568 span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
484569
570+ observation_type = self ._get_observation_type_from_serialized (
571+ serialized , "retriever" , ** kwargs
572+ )
573+
485574 if parent_run_id is None :
486- self .runs [run_id ] = self .client .start_span (
575+ self .runs [run_id ] = self .client .start_observation (
487576 name = span_name ,
577+ as_type = observation_type ,
488578 metadata = span_metadata ,
489579 input = query ,
490580 level = cast (
@@ -494,9 +584,10 @@ def on_retriever_start(
494584 )
495585 else :
496586 self .runs [run_id ] = cast (
497- LangfuseSpan , self .runs [parent_run_id ]
498- ).start_span (
587+ LangfuseRetriever , self .runs [parent_run_id ]
588+ ).start_observation (
499589 name = span_name ,
590+ as_type = observation_type ,
500591 input = query ,
501592 metadata = span_metadata ,
502593 level = cast (
@@ -625,10 +716,12 @@ def __on_llm_action(
625716
626717 if parent_run_id is not None and parent_run_id in self .runs :
627718 self .runs [run_id ] = cast (
628- LangfuseSpan , self .runs [parent_run_id ]
629- ).start_generation ( ** content ) # type: ignore
719+ LangfuseGeneration , self .runs [parent_run_id ]
720+ ).start_observation ( as_type = "generation" , ** content ) # type: ignore
630721 else :
631- self .runs [run_id ] = self .client .start_generation (** content ) # type: ignore
722+ self .runs [run_id ] = self .client .start_observation (
723+ as_type = "generation" , ** content
724+ ) # type: ignore
632725
633726 self .last_trace_id = self .runs [run_id ].trace_id
634727
0 commit comments