diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 5b5dfe691..8d2c8db90 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -303,6 +303,28 @@ def _parse_langfuse_trace_attributes( return attributes + def _get_langchain_observation_metadata( + self, + *, + parent_run_id: Optional[UUID], + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + keep_langfuse_trace_attributes: bool = False, + ) -> Optional[Dict[str, Any]]: + observation_metadata = self.__join_tags_and_metadata( + tags=tags, + metadata=metadata, + keep_langfuse_trace_attributes=keep_langfuse_trace_attributes, + ) + + if parent_run_id is not None: + return observation_metadata + + root_metadata = observation_metadata.copy() if observation_metadata else {} + root_metadata["is_langchain_root"] = True + + return root_metadata + def on_chain_start( self, serialized: Optional[Dict[str, Any]], @@ -325,7 +347,11 @@ def on_chain_start( ) span_name = self.get_langchain_run_name(serialized, **kwargs) - span_metadata = self.__join_tags_and_metadata(tags, metadata) + span_metadata = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + ) span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None observation_type = self._get_observation_type_from_serialized( @@ -690,7 +716,11 @@ def on_tool_start( "on_tool_start", run_id, parent_run_id, input_str=input_str ) - meta = self.__join_tags_and_metadata(tags, metadata) + meta = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + ) if not meta: meta = {} @@ -734,7 +764,11 @@ def on_retriever_start( "on_retriever_start", run_id, parent_run_id, query=query ) span_name = self.get_langchain_run_name(serialized, **kwargs) - span_metadata = self.__join_tags_and_metadata(tags, metadata) + span_metadata = self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + ) span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None observation_type = self._get_observation_type_from_serialized( @@ -865,9 +899,10 @@ def __on_llm_action( content = { "name": self.get_langchain_run_name(serialized, **kwargs), "input": prompts, - "metadata": self.__join_tags_and_metadata( - tags, - metadata, + "metadata": self._get_langchain_observation_metadata( + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, # If llm is run isolated and outside chain, keep trace attributes keep_langfuse_trace_attributes=True if parent_run_id is None diff --git a/tests/test_langchain.py b/tests/test_langchain.py index 6c9d3eb4d..fa2bcfddb 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -67,6 +67,7 @@ def test_callback_generated_from_trace_chat(): assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None assert langchain_generation_span.output != "" + assert langchain_generation_span.metadata["is_langchain_root"] is True def test_callback_generated_from_lcel_chain(): @@ -103,6 +104,11 @@ def test_callback_generated_from_lcel_chain(): trace.observations, ) )[0] + langchain_root_spans = [ + observation + for observation in trace.observations + if observation.metadata and observation.metadata.get("is_langchain_root") + ] assert langchain_generation_span.usage_details["input"] > 1 assert langchain_generation_span.usage_details["output"] > 0 @@ -111,6 +117,9 @@ def test_callback_generated_from_lcel_chain(): assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None assert langchain_generation_span.output != "" + assert len(langchain_root_spans) == 1 + assert langchain_root_spans[0].type == "CHAIN" + assert langchain_root_spans[0].metadata["is_langchain_root"] is True @pytest.mark.skip(reason="Flaky")