Skip to content

Commit 88f5a06

Browse files
fix: resolve regression on thought process display for ReAct Agent LLM
Thought process events were not displaying because LangchainProfilerHandler was not being invoked during LLM calls. Passing the handler via ainvoke() config was bypassed by _runnable_config, which was built once at graph construction time with no callbacks. Replace the stored _runnable_config attribute with _make_runnable_config(), which instantiates callback classes fresh on each LLM/tool call. Pass LangchainProfilerHandler as a class reference to the graph constructor so each invocation gets an isolated handler, also fixing a memory leak and concurrency issues from the previously shared instance.

Signed-off-by: Patrick Chin <8509935+thepatrickchin@users.noreply.github.com>
1 parent 27797af commit 88f5a06

4 files changed

Lines changed: 23 additions & 17 deletions

File tree

packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
from abc import ABC
2020
from abc import abstractmethod
21+
from collections.abc import Callable
2122
from enum import Enum
2223
from typing import Any
2324

@@ -71,7 +72,7 @@ class BaseAgent(ABC):
7172
def __init__(self,
7273
llm: BaseChatModel,
7374
tools: list[BaseTool],
74-
callbacks: list[AsyncCallbackHandler] | None = None,
75+
callbacks: list[Callable[[], AsyncCallbackHandler]] | None = None,
7576
detailed_logs: bool = False,
7677
log_response_max_chars: int = 1000) -> None:
7778
logger.debug("Initializing Agent Graph")
@@ -81,8 +82,17 @@ def __init__(self,
8182
self.detailed_logs = detailed_logs
8283
self.log_response_max_chars = log_response_max_chars
8384
self.graph = None
84-
self._runnable_config = RunnableConfig(callbacks=self.callbacks,
85-
configurable={"__pregel_runtime": DEFAULT_RUNTIME})
85+
86+
@property
87+
def _runnable_config(self) -> RunnableConfig:
88+
return self._make_runnable_config()
89+
90+
def _make_runnable_config(self) -> RunnableConfig:
91+
"""
92+
Create a fresh RunnableConfig with isolated callback instances per invocation.
93+
"""
94+
return RunnableConfig(callbacks=[c() for c in self.callbacks],
95+
configurable={"__pregel_runtime": DEFAULT_RUNTIME})
8696

8797
async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage:
8898
"""

packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/dual_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
from abc import abstractmethod
18+
from collections.abc import Callable
1819

1920
from langchain_core.callbacks import AsyncCallbackHandler
2021
from langchain_core.language_models import BaseChatModel
@@ -34,7 +35,7 @@ class DualNodeAgent(BaseAgent):
3435
def __init__(self,
3536
llm: BaseChatModel,
3637
tools: list[BaseTool],
37-
callbacks: list[AsyncCallbackHandler] | None = None,
38+
callbacks: list[Callable[[], AsyncCallbackHandler]] | None = None,
3839
detailed_logs: bool = False,
3940
log_response_max_chars: int = 1000):
4041
super().__init__(llm=llm,

packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
117117
llm=llm,
118118
prompt=prompt,
119119
tools=tools,
120-
callbacks=[],
120+
callbacks=[LangchainProfilerHandler],
121121
use_tool_schema=config.include_tool_input_schema_in_tool_description,
122122
detailed_logs=config.verbose,
123123
log_response_max_chars=config.log_response_max_chars,
@@ -154,12 +154,8 @@ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatRes
154154

155155
state = ReActGraphState(messages=messages)
156156

157-
# run the ReAct Agent Graph with a new callback handler instance per request
158-
state = await graph.ainvoke(state,
159-
config={
160-
'recursion_limit': (config.max_tool_calls + 1) * 2,
161-
'callbacks': [LangchainProfilerHandler()]
162-
})
157+
# run the ReAct Agent Graph
158+
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
163159
# setting recursion_limit: 4 allows 1 tool call
164160
# - allows the ReAct Agent to perform 1 cycle / call 1 single tool,
165161
# - but stops the agent when it tries to call a tool a second time

packages/nvidia_nat_langchain/tests/agent/test_base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from langchain_core.messages import AIMessage
2323
from langchain_core.messages import HumanMessage
2424
from langchain_core.messages import ToolMessage
25-
from langchain_core.runnables import RunnableConfig
2625
from langgraph.graph.state import CompiledStateGraph
2726

2827
from nat.plugins.langchain.agent.base import BaseAgent
@@ -40,21 +39,21 @@ def __init__(self, detailed_logs=True, log_response_max_chars=1000):
4039
self.callbacks = []
4140
self.detailed_logs = detailed_logs
4241
self.log_response_max_chars = log_response_max_chars
43-
self._runnable_config = RunnableConfig()
42+
self.graph = None
4443

4544
async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
4645
"""Mock implementation."""
4746
return Mock(spec=CompiledStateGraph)
4847

4948

50-
@pytest.fixture
51-
def base_agent():
49+
@pytest.fixture(name="base_agent")
50+
def fixture_base_agent():
5251
"""Create a mock agent for testing with detailed logs enabled."""
5352
return MockBaseAgent(detailed_logs=True)
5453

5554

56-
@pytest.fixture
57-
def base_agent_no_logs():
55+
@pytest.fixture(name="base_agent_no_logs")
56+
def fixture_base_agent_no_logs():
5857
"""Create a mock agent for testing with detailed logs disabled."""
5958
return MockBaseAgent(detailed_logs=False)
6059

0 commit comments

Comments (0)