Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions langfuse/langchain/CallbackHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,28 @@ def _parse_langfuse_trace_attributes(

return attributes

def _get_langchain_observation_metadata(
    self,
    *,
    parent_run_id: Optional[UUID],
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    keep_langfuse_trace_attributes: bool = False,
) -> Optional[Dict[str, Any]]:
    """Build observation metadata for a LangChain callback event.

    Merges ``tags`` and ``metadata`` via ``__join_tags_and_metadata``. When the
    event has no parent run (``parent_run_id is None``) it is the LangChain
    root, so an ``is_langchain_root: True`` marker is added to a copy of the
    merged metadata (the merged dict itself is never mutated).

    Args:
        parent_run_id: Parent LangChain run id; ``None`` marks the root run.
        tags: Optional LangChain tags to fold into the metadata.
        metadata: Optional metadata dict from the callback event.
        keep_langfuse_trace_attributes: Passed through to the join helper to
            preserve Langfuse trace attributes in the merged result.

    Returns:
        The merged metadata (possibly ``None``) for child runs, or a dict
        containing ``is_langchain_root=True`` for root runs.
    """
    merged = self.__join_tags_and_metadata(
        tags=tags,
        metadata=metadata,
        keep_langfuse_trace_attributes=keep_langfuse_trace_attributes,
    )

    if parent_run_id is None:
        # Root run: copy before flagging so callers' dicts stay untouched.
        flagged = dict(merged) if merged else {}
        flagged["is_langchain_root"] = True
        return flagged

    return merged

def on_chain_start(
self,
serialized: Optional[Dict[str, Any]],
Expand All @@ -325,7 +347,11 @@ def on_chain_start(
)

span_name = self.get_langchain_run_name(serialized, **kwargs)
span_metadata = self.__join_tags_and_metadata(tags, metadata)
span_metadata = self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
)
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None

observation_type = self._get_observation_type_from_serialized(
Expand Down Expand Up @@ -690,7 +716,11 @@ def on_tool_start(
"on_tool_start", run_id, parent_run_id, input_str=input_str
)

meta = self.__join_tags_and_metadata(tags, metadata)
meta = self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
)

if not meta:
meta = {}
Expand Down Expand Up @@ -734,7 +764,11 @@ def on_retriever_start(
"on_retriever_start", run_id, parent_run_id, query=query
)
span_name = self.get_langchain_run_name(serialized, **kwargs)
span_metadata = self.__join_tags_and_metadata(tags, metadata)
span_metadata = self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
)
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None

observation_type = self._get_observation_type_from_serialized(
Expand Down Expand Up @@ -865,9 +899,10 @@ def __on_llm_action(
content = {
"name": self.get_langchain_run_name(serialized, **kwargs),
"input": prompts,
"metadata": self.__join_tags_and_metadata(
tags,
metadata,
"metadata": self._get_langchain_observation_metadata(
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
# If llm is run isolated and outside chain, keep trace attributes
keep_langfuse_trace_attributes=True
if parent_run_id is None
Expand Down
9 changes: 9 additions & 0 deletions tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_callback_generated_from_trace_chat():
assert langchain_generation_span.input != ""
assert langchain_generation_span.output is not None
assert langchain_generation_span.output != ""
assert langchain_generation_span.metadata["is_langchain_root"] is True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 The test test_callback_generated_from_trace_chat incorrectly asserts is_langchain_root on the GENERATION span rather than the CHAIN wrapper that is the actual LangChain root. When ChatOpenAI.invoke() is called, LangChain fires on_chain_start with parent_run_id=None first (creating a CHAIN wrapper), then on_chat_model_start with a non-None parent_run_id, so the flag lands on the CHAIN, not the GENERATION. The test should filter for the root CHAIN span (as test_callback_generated_from_lcel_chain correctly does) rather than asserting is_langchain_root on the GENERATION span.

Extended reasoning...

What the bug is and how it manifests

Line 70 asserts langchain_generation_span.metadata["is_langchain_root"] is True on the ChatOpenAI GENERATION span. The _get_langchain_observation_metadata helper only sets is_langchain_root=True when parent_run_id is None. When ChatOpenAI.invoke() is called, LangChain fires two callbacks: on_chain_start (with parent_run_id=None, creating a CHAIN wrapper as the LangChain root) and then on_chat_model_start (with parent_run_id=<chain_run_id>, a non-None value). So the CHAIN wrapper gets is_langchain_root=True, while the GENERATION does not.

The specific code path that triggers it

In __on_llm_action, the metadata is set via self._get_langchain_observation_metadata(parent_run_id=parent_run_id, ...). When on_chat_model_start fires for a ChatOpenAI.invoke() call, parent_run_id equals the run ID of the previously created CHAIN wrapper (not None). The helper function reaches if parent_run_id is not None: return observation_metadata without setting the is_langchain_root key.

Why existing evidence confirms this, not refutes it

The test itself already asserts len(trace.observations) == 3 (line 54), which means there are exactly 3 observations: (1) the Langfuse parent from start_as_current_observation, (2) a CHAIN wrapper created by on_chain_start, and (3) a GENERATION created by on_chat_model_start. If LangChain only fired on_chat_model_start with parent_run_id=None (as refuters claim), there would be only 2 observations total, contradicting line 54. This count of 3 is also independently confirmed by test_multimodal, which similarly calls model.invoke() directly and asserts len(trace.observations) == 3. Furthermore, git history shows commit 840cf2a explicitly changed the observation count in this test from 2 to 3, documenting the LangChain behavioral change where on_chain_start now fires before on_chat_model_start even for direct model invocations.

Step-by-step proof

  1. chat.invoke(messages, config={"callbacks": [handler]}) is called.
  2. LangChain fires on_chain_start(run_id=UUID_A, parent_run_id=None) → _get_langchain_observation_metadata sees parent_run_id is None, sets is_langchain_root=True on the CHAIN observation.
  3. LangChain fires on_chat_model_start(run_id=UUID_B, parent_run_id=UUID_A) → _get_langchain_observation_metadata sees parent_run_id is not None, returns metadata without is_langchain_root.
  4. The test filters for o.type == "GENERATION" and o.name == "ChatOpenAI", finding observation UUID_B.
  5. langchain_generation_span.metadata["is_langchain_root"] is either missing or None, causing the assertion on line 70 to fail (either KeyError or AssertionError).

How to fix it

Mirror the pattern used in test_callback_generated_from_lcel_chain: filter all observations for those with observation.metadata.get("is_langchain_root") and assert the single result has type == "CHAIN". Remove the incorrect line 70 assertion from test_callback_generated_from_trace_chat and replace it with a check that the CHAIN wrapper (not the GENERATION) carries the flag.



def test_callback_generated_from_lcel_chain():
Expand Down Expand Up @@ -103,6 +104,11 @@ def test_callback_generated_from_lcel_chain():
trace.observations,
)
)[0]
langchain_root_spans = [
observation
for observation in trace.observations
if observation.metadata and observation.metadata.get("is_langchain_root")
]

assert langchain_generation_span.usage_details["input"] > 1
assert langchain_generation_span.usage_details["output"] > 0
Expand All @@ -111,6 +117,9 @@ def test_callback_generated_from_lcel_chain():
assert langchain_generation_span.input != ""
assert langchain_generation_span.output is not None
assert langchain_generation_span.output != ""
assert len(langchain_root_spans) == 1
assert langchain_root_spans[0].type == "CHAIN"
assert langchain_root_spans[0].metadata["is_langchain_root"] is True


@pytest.mark.skip(reason="Flaky")
Expand Down
Loading