Skip to content

Commit 9f4cafe

Browse files
committed
simplify observation creation
1 parent 2431192 commit 9f4cafe

1 file changed

Lines changed: 48 additions & 111 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 48 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -276,19 +276,19 @@ def on_chain_start(
276276
serialized, "chain", **kwargs
277277
)
278278

279-
if parent_run_id is None:
280-
span = self.client.start_observation(
281-
name=span_name,
282-
as_type=observation_type,
283-
metadata=span_metadata,
284-
input=inputs,
285-
level=cast(
286-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
287-
span_level,
288-
),
289-
)
290-
self._attach_observation(run_id, span)
279+
span = self.client.start_observation(
280+
name=span_name,
281+
as_type=observation_type,
282+
metadata=span_metadata,
283+
input=inputs,
284+
level=cast(
285+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
286+
span_level,
287+
),
288+
)
289+
self._attach_observation(run_id, span)
291290

291+
if parent_run_id is None:
292292
span.update_trace(
293293
**(
294294
cast(
@@ -304,22 +304,6 @@ def on_chain_start(
304304
),
305305
**self._parse_langfuse_trace_attributes_from_metadata(metadata),
306306
)
307-
else:
308-
span = cast(
309-
LangfuseChain,
310-
self.runs[parent_run_id],
311-
).start_observation(
312-
name=span_name,
313-
as_type=observation_type,
314-
metadata=span_metadata,
315-
input=inputs,
316-
level=cast(
317-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
318-
span_level,
319-
),
320-
)
321-
322-
self._attach_observation(run_id, span)
323307

324308
self.last_trace_id = self.runs[run_id].trace_id
325309

@@ -509,28 +493,22 @@ def on_chain_error(
509493
) -> None:
510494
try:
511495
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
512-
if run_id in self.runs:
513-
if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
514-
level = None
515-
else:
516-
level = "ERROR"
517-
518-
observation = self._detach_observation(run_id)
519-
520-
if observation is not None:
521-
observation.update(
522-
level=cast(
523-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
524-
level,
525-
),
526-
status_message=str(error) if level else None,
527-
input=kwargs.get("inputs"),
528-
).end()
529-
496+
if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES):
497+
level = None
530498
else:
531-
langfuse_logger.warning(
532-
f"Run ID {run_id} already popped from run map. Could not update run with error message"
533-
)
499+
level = "ERROR"
500+
501+
observation = self._detach_observation(run_id)
502+
503+
if observation is not None:
504+
observation.update(
505+
level=cast(
506+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
507+
level,
508+
),
509+
status_message=str(error) if level else None,
510+
input=kwargs.get("inputs"),
511+
).end()
534512

535513
except Exception as e:
536514
langfuse_logger.exception(e)
@@ -623,29 +601,15 @@ def on_tool_start(
623601
serialized, "tool", **kwargs
624602
)
625603

626-
if parent_run_id is None or parent_run_id not in self.runs:
627-
# Create root observation for direct tool calls
628-
span = self.client.start_observation(
629-
name=self.get_langchain_run_name(serialized, **kwargs),
630-
as_type=observation_type,
631-
input=input_str,
632-
metadata=meta,
633-
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
634-
)
635-
636-
self._attach_observation(run_id, span)
637-
638-
else:
639-
# Create child observation for tools within chains/agents
640-
span = cast(LangfuseChain, self.runs[parent_run_id]).start_observation(
641-
name=self.get_langchain_run_name(serialized, **kwargs),
642-
as_type=observation_type,
643-
input=input_str,
644-
metadata=meta,
645-
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
646-
)
604+
span = self.client.start_observation(
605+
name=self.get_langchain_run_name(serialized, **kwargs),
606+
as_type=observation_type,
607+
input=input_str,
608+
metadata=meta,
609+
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
610+
)
647611

648-
self._attach_observation(run_id, span)
612+
self._attach_observation(run_id, span)
649613

650614
except Exception as e:
651615
langfuse_logger.exception(e)
@@ -673,35 +637,18 @@ def on_retriever_start(
673637
serialized, "retriever", **kwargs
674638
)
675639

676-
if parent_run_id is None:
677-
span = self.client.start_observation(
678-
name=span_name,
679-
as_type=observation_type,
680-
metadata=span_metadata,
681-
input=query,
682-
level=cast(
683-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
684-
span_level,
685-
),
686-
)
687-
688-
self._attach_observation(run_id, span)
689-
690-
else:
691-
span = cast(
692-
LangfuseRetriever, self.runs[parent_run_id]
693-
).start_observation(
694-
name=span_name,
695-
as_type=observation_type,
696-
input=query,
697-
metadata=span_metadata,
698-
level=cast(
699-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
700-
span_level,
701-
),
702-
)
640+
span = self.client.start_observation(
641+
name=span_name,
642+
as_type=observation_type,
643+
metadata=span_metadata,
644+
input=query,
645+
level=cast(
646+
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
647+
span_level,
648+
),
649+
)
703650

704-
self._attach_observation(run_id, span)
651+
self._attach_observation(run_id, span)
705652

706653
except Exception as e:
707654
langfuse_logger.exception(e)
@@ -816,18 +763,8 @@ def __on_llm_action(
816763
"prompt": registered_prompt,
817764
}
818765

819-
if parent_run_id is not None and parent_run_id in self.runs:
820-
generation = cast(
821-
LangfuseGeneration, self.runs[parent_run_id]
822-
).start_observation(as_type="generation", **content) # type: ignore
823-
824-
self._attach_observation(run_id, generation)
825-
else:
826-
generation = self.client.start_observation(
827-
as_type="generation", **content
828-
) # type: ignore
829-
830-
self._attach_observation(run_id, generation)
766+
generation = self.client.start_observation(as_type="generation", **content)
767+
self._attach_observation(run_id, generation)
831768

832769
self.last_trace_id = self.runs[run_id].trace_id
833770

0 commit comments

Comments
 (0)