Skip to content

Commit 4953c85

Browse files
fix: resolve regression on thought process display for ReAct Agent LLM
Thought process events were not displaying because LangchainProfilerHandler was not being invoked during LLM calls. Passing the handler via ainvoke() config was bypassed by _runnable_config, which was built once at graph construction time with no callbacks. Replace the stored _runnable_config attribute with _make_runnable_config(), which instantiates callback classes fresh on each LLM/tool call. Pass LangchainProfilerHandler as a class reference to the graph constructor so each invocation gets an isolated handler, also fixing a memory leak and concurrency issues from the previously shared instance. Signed-off-by: Patrick Chin <8509935+thepatrickchin@users.noreply.github.com>
1 parent 27797af commit 4953c85

3 files changed

Lines changed: 17 additions & 15 deletions

File tree

packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
from abc import ABC
2020
from abc import abstractmethod
21+
from collections.abc import Callable
2122
from enum import Enum
2223
from typing import Any
2324

@@ -71,7 +72,7 @@ class BaseAgent(ABC):
7172
def __init__(self,
7273
llm: BaseChatModel,
7374
tools: list[BaseTool],
74-
callbacks: list[AsyncCallbackHandler] | None = None,
75+
callbacks: list[Callable[[], AsyncCallbackHandler]] | None = None,
7576
detailed_logs: bool = False,
7677
log_response_max_chars: int = 1000) -> None:
7778
logger.debug("Initializing Agent Graph")
@@ -81,8 +82,14 @@ def __init__(self,
8182
self.detailed_logs = detailed_logs
8283
self.log_response_max_chars = log_response_max_chars
8384
self.graph = None
84-
self._runnable_config = RunnableConfig(callbacks=self.callbacks,
85-
configurable={"__pregel_runtime": DEFAULT_RUNTIME})
85+
86+
@property
87+
def _runnable_config(self) -> RunnableConfig:
88+
return self._make_runnable_config()
89+
90+
def _make_runnable_config(self) -> RunnableConfig:
91+
return RunnableConfig(callbacks=[c() for c in self.callbacks],
92+
configurable={"__pregel_runtime": DEFAULT_RUNTIME})
8693

8794
async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage:
8895
"""
@@ -102,7 +109,7 @@ async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage:
102109
"""
103110
content_parts = []
104111
reasoning_parts = []
105-
async for event in runnable.astream(inputs, config=self._runnable_config):
112+
async for event in runnable.astream(inputs, config=self._make_runnable_config()):
106113
content_parts.append(event.content)
107114
extra = getattr(event, 'additional_kwargs', None)
108115
if isinstance(extra, dict):
@@ -132,7 +139,7 @@ async def _call_llm(self, llm: Runnable, inputs: dict[str, Any]) -> AIMessage:
132139
AIMessage
133140
The LLM response
134141
"""
135-
response = await llm.ainvoke(inputs, config=self._runnable_config)
142+
response = await llm.ainvoke(inputs, config=self._make_runnable_config())
136143
return AIMessage(content=str(response.content))
137144

138145
async def _call_tool(self, tool: BaseTool, tool_input: dict[str, Any] | str, max_retries: int = 3) -> ToolMessage:
@@ -157,7 +164,7 @@ async def _call_tool(self, tool: BaseTool, tool_input: dict[str, Any] | str, max
157164

158165
for attempt in range(1, max_retries + 1):
159166
try:
160-
response = await tool.ainvoke(tool_input, config=self._runnable_config)
167+
response = await tool.ainvoke(tool_input, config=self._make_runnable_config())
161168

162169
# Handle empty responses
163170
if response is None or (isinstance(response, str) and response == ""):

packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
117117
llm=llm,
118118
prompt=prompt,
119119
tools=tools,
120-
callbacks=[],
120+
callbacks=[LangchainProfilerHandler],
121121
use_tool_schema=config.include_tool_input_schema_in_tool_description,
122122
detailed_logs=config.verbose,
123123
log_response_max_chars=config.log_response_max_chars,
@@ -154,12 +154,8 @@ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatRes
154154

155155
state = ReActGraphState(messages=messages)
156156

157-
# run the ReAct Agent Graph with a new callback handler instance per request
158-
state = await graph.ainvoke(state,
159-
config={
160-
'recursion_limit': (config.max_tool_calls + 1) * 2,
161-
'callbacks': [LangchainProfilerHandler()]
162-
})
157+
# run the ReAct Agent Graph
158+
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
163159
# setting recursion_limit: 4 allows 1 tool call
164160
# - allows the ReAct Agent to perform 1 cycle / call 1 single tool,
165161
# - but stops the agent when it tries to call a tool a second time

packages/nvidia_nat_langchain/tests/agent/test_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from langchain_core.messages import AIMessage
2323
from langchain_core.messages import HumanMessage
2424
from langchain_core.messages import ToolMessage
25-
from langchain_core.runnables import RunnableConfig
2625
from langgraph.graph.state import CompiledStateGraph
2726

2827
from nat.plugins.langchain.agent.base import BaseAgent
@@ -40,7 +39,7 @@ def __init__(self, detailed_logs=True, log_response_max_chars=1000):
4039
self.callbacks = []
4140
self.detailed_logs = detailed_logs
4241
self.log_response_max_chars = log_response_max_chars
43-
self._runnable_config = RunnableConfig()
42+
self.graph = None
4443

4544
async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
4645
"""Mock implementation."""

0 commit comments

Comments (0)