Skip to content

Commit f28fe5f

Browse files
jsonbaileyclaude
andcommitted
fix: Address code review findings in graph runner and callback handler
- Cache compiled graph in _ensure_compiled() so _build_graph is not called on every run() invocation - Collect node_keys during traversal instead of reaching into _graph._nodes - Fix unconditional break in make_after_tools_router that made reversed() a no-op; replace broken loop with a direct last-message check - Fix duration tracking key collision: key _node_start_ns by run_id instead of node_key so concurrent invocations of the same node don't clobber each other's start times - Warn when a functional-tool node has multiple outgoing edges since only the first edge is reachable after the tool loop exits Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9dec085 commit f28fe5f

2 files changed

Lines changed: 38 additions & 19 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
22

33
import time
4-
from typing import Annotated, Any, Dict, List, Tuple
4+
from typing import Annotated, Any, Dict, List, Optional, Set, Tuple
55

66
from ldai import log
77
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
@@ -76,13 +76,25 @@ def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry):
7676
"""
7777
self._graph = graph
7878
self._tools = tools
79-
80-
def _build_graph(self) -> Tuple[Any, Dict[str, str]]:
79+
self._compiled: Any = None
80+
self._fn_name_to_config_key: Dict[str, str] = {}
81+
self._node_keys: Set[str] = set()
82+
83+
def _ensure_compiled(self) -> None:
84+
"""Build and cache the compiled graph if not already done."""
85+
if self._compiled is None:
86+
compiled, fn_name_to_config_key, node_keys = self._build_graph()
87+
self._compiled = compiled
88+
self._fn_name_to_config_key = fn_name_to_config_key
89+
self._node_keys = node_keys
90+
91+
def _build_graph(self) -> Tuple[Any, Dict[str, str], Set[str]]:
8192
"""
8293
Build and compile the LangGraph StateGraph from the AgentGraphDefinition.
8394
84-
:return: Tuple of (compiled_graph, fn_name_to_config_key) where
85-
fn_name_to_config_key maps tool function __name__ to LD config key.
95+
:return: Tuple of (compiled_graph, fn_name_to_config_key, node_keys) where
96+
fn_name_to_config_key maps tool function __name__ to LD config key, and
97+
node_keys is the set of all agent node keys in the graph.
8698
"""
8799
from langchain_core.messages import SystemMessage
88100
from langgraph.graph import END, START, StateGraph
@@ -99,10 +111,12 @@ class WorkflowState(TypedDict):
99111
tools_ref = self._tools
100112
graph_structure: List[str] = []
101113
fn_name_to_config_key: Dict[str, str] = {}
114+
node_keys: Set[str] = set()
102115

103116
def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
104117
node_config = node.get_config()
105118
node_key = node.get_key()
119+
node_keys.add(node_key)
106120
instructions = node_config.instructions if hasattr(node_config, 'instructions') else None
107121
outgoing_edges = node.get_edges()
108122

@@ -190,6 +204,12 @@ async def invoke(state: WorkflowState) -> dict:
190204
if not handoff_fns:
191205
# No handoff tools: standard loop-back after tool execution.
192206
after_loop = outgoing_edges[0].target_config if outgoing_edges else END
207+
if len(outgoing_edges) > 1:
208+
log.warning(
209+
f"Node '{node_key}' has {len(outgoing_edges)} outgoing edges but no handoff "
210+
"tools; only the first edge will be used after the tool loop. "
211+
"Use handoff tools for multi-child routing."
212+
)
193213
agent_builder.add_edge(tools_node_key, node_key)
194214
agent_builder.add_conditional_edges(
195215
node_key,
@@ -212,10 +232,11 @@ async def invoke(state: WorkflowState) -> dict:
212232

213233
def make_after_tools_router(parent_key: str, ht_names: frozenset):
214234
def route(state: WorkflowState) -> str:
215-
for msg in reversed(state['messages']):
216-
if hasattr(msg, 'name') and msg.name:
217-
return END if msg.name in ht_names else parent_key
218-
break
235+
msgs = state['messages']
236+
if msgs:
237+
last = msgs[-1]
238+
if hasattr(last, 'name') and last.name in ht_names:
239+
return END
219240
return parent_key
220241
return route
221242

@@ -247,7 +268,7 @@ def route(state: WorkflowState) -> str:
247268
)
248269

249270
compiled = agent_builder.compile()
250-
return compiled, fn_name_to_config_key
271+
return compiled, fn_name_to_config_key, node_keys
251272

252273
async def run(self, input: Any) -> AgentGraphResult:
253274
"""
@@ -266,12 +287,10 @@ async def run(self, input: Any) -> AgentGraphResult:
266287
try:
267288
from langchain_core.messages import HumanMessage
268289

269-
compiled, fn_name_to_config_key = self._build_graph()
270-
271-
node_keys = {node.get_key() for node in self._graph._nodes.values()}
272-
handler = LDMetricsCallbackHandler(node_keys, fn_name_to_config_key)
290+
self._ensure_compiled()
291+
handler = LDMetricsCallbackHandler(self._node_keys, self._fn_name_to_config_key)
273292

274-
result = await compiled.ainvoke( # type: ignore[call-overload]
293+
result = await self._compiled.ainvoke( # type: ignore[call-overload]
275294
{'messages': [HumanMessage(content=str(input))]},
276295
config={'callbacks': [handler], 'recursion_limit': 25},
277296
)

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
4242
self._node_tokens: Dict[str, TokenUsage] = {}
4343
# tool config keys called per node
4444
self._node_tool_calls: Dict[str, List[str]] = {}
45-
# start time (ns) per node — only set while running
46-
self._node_start_ns: Dict[str, int] = {}
45+
# start time (ns) per active run_id — keyed by run_id to handle re-entrant nodes
46+
self._node_start_ns: Dict[UUID, int] = {}
4747
# accumulated duration (ms) per node
4848
self._node_duration_ms: Dict[str, int] = {}
4949
# execution path in order (deduplicated)
@@ -96,7 +96,7 @@ def on_chain_start(
9696

9797
if name in self._node_keys:
9898
self._run_to_node[run_id] = name
99-
self._node_start_ns[name] = time.perf_counter_ns()
99+
self._node_start_ns[run_id] = time.perf_counter_ns()
100100
if name not in self._path_set:
101101
self._path.append(name)
102102
self._path_set.add(name)
@@ -117,7 +117,7 @@ def on_chain_end(
117117
node_key = self._run_to_node.get(run_id)
118118
if node_key is None:
119119
return
120-
start_ns = self._node_start_ns.pop(node_key, None)
120+
start_ns = self._node_start_ns.pop(run_id, None)
121121
if start_ns is not None:
122122
elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
123123
self._node_duration_ms[node_key] = (

0 commit comments

Comments
 (0)