Skip to content

Commit 85ed502

Browse files
jsonbaileyclaude
andcommitted
fix: Replace fan-out static edges with LLM-driven handoff routing
Multi-child nodes previously used static add_edge calls, causing LangGraph to fan-out to all children in parallel. Replace with handoff tools (Command(goto=child_key) via @tool + InjectedToolCallId) so the LLM picks exactly one child per turn. Bind handoff nodes with parallel_tool_calls=False to prevent the model from selecting multiple destinations in a single response. Switch WorkflowState.messages to add_messages reducer (deduplicates by ID) and add recursion_limit=25 as a safety cap. Adds test asserting single-child routing in a 3-node orchestrator graph. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8cad7e5 commit 85ed502

2 files changed

Lines changed: 207 additions & 35 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py

Lines changed: 97 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
22

3-
import operator
43
import time
54
from typing import Annotated, Any, Dict, List, Tuple
65

@@ -25,6 +24,40 @@ def _tool_call_id_from_entry(tc: Any) -> Any:
2524
return getattr(tc, 'id', None)
2625

2726

27+
def _make_handoff_tool(child_key: str, description: str) -> Any:
28+
"""
29+
Create a tool that transfers control to ``child_key``.
30+
31+
Uses the ``@tool`` decorator with ``InjectedState`` + ``InjectedToolCallId``
32+
so LangGraph's ToolNode handles the ``Command`` return value correctly.
33+
The tool explicitly creates a ToolMessage in ``Command.update`` to satisfy
34+
the LangChain/OpenAI message-chain contract.
35+
"""
36+
from typing import Annotated as _Annotated
37+
38+
from langchain_core.messages import ToolMessage
39+
from langchain_core.tools import tool
40+
from langchain_core.tools.base import InjectedToolCallId
41+
from langgraph.prebuilt import InjectedState
42+
from langgraph.types import Command
43+
44+
tool_name = f"transfer_to_{child_key.replace('-', '_')}"
45+
46+
@tool(tool_name, description=description)
47+
def handoff(
48+
state: _Annotated[Any, InjectedState], # noqa: ARG001
49+
tool_call_id: _Annotated[str, InjectedToolCallId],
50+
) -> Command:
51+
tool_message = ToolMessage(
52+
content=f'Transferred to {child_key}',
53+
name=tool_name,
54+
tool_call_id=tool_call_id,
55+
)
56+
return Command(goto=child_key, update={'messages': [tool_message]})
57+
58+
return handoff
59+
60+
2861
def _coalesce_tool_messages_for_openai(msgs: List[Any]) -> List[Any]:
2962
"""
3063
Rewind shared LangGraph message state into OpenAI's required shape.
@@ -168,11 +201,12 @@ def _build_graph(self) -> Tuple[Any, Dict[str, str]]:
168201
"""
169202
from langchain_core.messages import SystemMessage
170203
from langgraph.graph import END, START, StateGraph
204+
from langgraph.graph.message import add_messages
171205
from langgraph.prebuilt import ToolNode, tools_condition
172206
from typing_extensions import TypedDict
173207

174208
class WorkflowState(TypedDict):
175-
messages: Annotated[List[Any], operator.add]
209+
messages: Annotated[List[Any], add_messages]
176210

177211
agent_builder: StateGraph = StateGraph(WorkflowState)
178212
root_node = self._graph.root()
@@ -184,22 +218,16 @@ class WorkflowState(TypedDict):
184218
def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
185219
node_config = node.get_config()
186220
node_key = node.get_key()
221+
instructions = node_config.instructions if hasattr(node_config, 'instructions') else None
222+
outgoing_edges = node.get_edges()
187223

224+
lc_model = None
188225
tool_fns: list = []
189-
model = None
190-
instructions = node_config.instructions if hasattr(node_config, 'instructions') else None
191226
if node_config.model:
192227
# We send an empty tool registry to avoid binding tools to the model.
193228
lc_model = create_langchain_model(node_config, tool_registry=None)
194229

195-
# Retrieve tool definitions to build fn_name_to_config_key map
196-
config_dict = node_config.to_dict()
197-
model_dict = config_dict.get('model') or {}
198-
parameters = dict(model_dict.get('parameters') or {})
199-
tool_defs = parameters.get('tools', []) or []
200-
201230
tool_fns = build_structured_tools(node_config, tools_ref)
202-
model = lc_model.bind_tools(tool_fns) if tool_fns else lc_model
203231

204232
# Map tool name -> LD config key for callback attribution.
205233
# build_structured_tools returns StructuredTool instances with tool.name set
@@ -209,6 +237,33 @@ def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
209237
if tool_name:
210238
fn_name_to_config_key[tool_name] = tool_name
211239

240+
# For nodes with multiple children, create a handoff tool per child so the
241+
# LLM decides which agent to route to. Uses Command(goto=child_key) so
242+
# LangGraph routes to the target without looping back here.
243+
handoff_fns: list = []
244+
if lc_model and len(outgoing_edges) > 1:
245+
for edge in outgoing_edges:
246+
child_node = self._graph.get_node(edge.target_config)
247+
description = (
248+
(edge.handoff or {}).get('description')
249+
or (
250+
child_node.get_config().instructions[:120]
251+
if child_node and child_node.get_config().instructions
252+
else None
253+
)
254+
or f"Transfer control to {edge.target_config}"
255+
)
256+
handoff_fns.append(_make_handoff_tool(edge.target_config, description))
257+
258+
all_tools = tool_fns + handoff_fns
259+
if lc_model and all_tools:
260+
# When handoff tools are present, disable parallel tool calls so the LLM
261+
# picks exactly one destination rather than routing to multiple children.
262+
bind_kwargs = {'parallel_tool_calls': False} if handoff_fns else {}
263+
model = lc_model.bind_tools(all_tools, **bind_kwargs)
264+
else:
265+
model = lc_model
266+
212267
def make_node_fn(bound_model: Any, node_instructions: Any, nk: str):
213268
async def invoke(state: WorkflowState) -> dict:
214269
if not bound_model:
@@ -234,30 +289,45 @@ async def invoke(state: WorkflowState) -> dict:
234289
if node_key == root_key:
235290
agent_builder.add_edge(START, node_key)
236291

237-
outgoing_edges = node.get_edges()
238-
239292
# Collect node info for graph structure log
240293
tool_names = [str(getattr(t, 'name', None) or getattr(t, '__name__', t)) for t in tool_fns]
241294
edge_targets = [e.target_config for e in outgoing_edges]
242295
node_desc = node_key
243296
if tool_names:
244297
node_desc += f"[tools:{','.join(tool_names)}]"
245-
node_desc += f"→{','.join(edge_targets)}" if edge_targets else "(terminal)"
298+
if handoff_fns:
299+
node_desc += f"[handoff:{','.join(edge_targets)}]"
300+
elif edge_targets:
301+
node_desc += f"→{','.join(edge_targets)}"
302+
else:
303+
node_desc += "(terminal)"
246304
graph_structure.append(node_desc)
247305

248-
if tool_fns:
249-
# Pair this node with a ToolNode and loop it back (standard LangGraph pattern).
250-
# tools_condition routes to "tools" when the response has tool calls,
251-
# and to END otherwise; the path_map redirects those to our named nodes.
306+
if all_tools:
307+
# ToolNode handles Command returns from handoff tools, routing to the target
308+
# node. For functional tools it returns normal ToolMessages and we loop back.
309+
# tools_condition exits to END when no tool is called.
252310
tools_node_key = f"{node_key}__tools"
253-
after_loop = outgoing_edges[0].target_config if outgoing_edges else END
254-
agent_builder.add_node(tools_node_key, ToolNode(tool_fns))
255-
agent_builder.add_edge(tools_node_key, node_key)
256-
agent_builder.add_conditional_edges(
257-
node_key,
258-
tools_condition,
259-
{"tools": tools_node_key, END: after_loop},
260-
)
311+
agent_builder.add_node(tools_node_key, ToolNode(all_tools))
312+
313+
if not handoff_fns:
314+
# No handoff tools: standard loop-back after tool execution.
315+
after_loop = outgoing_edges[0].target_config if outgoing_edges else END
316+
agent_builder.add_edge(tools_node_key, node_key)
317+
agent_builder.add_conditional_edges(
318+
node_key,
319+
tools_condition,
320+
{"tools": tools_node_key, END: after_loop},
321+
)
322+
else:
323+
# Handoff tools use Command(goto=child_key) — LangGraph routes to the
324+
# target directly without any extra edge. The ToolNode does NOT loop
325+
# back here. tools_condition exits to END when no tool is called.
326+
agent_builder.add_conditional_edges(
327+
node_key,
328+
tools_condition,
329+
{"tools": tools_node_key, END: END},
330+
)
261331
else:
262332
if node.is_terminal():
263333
agent_builder.add_edge(node_key, END)
@@ -276,14 +346,6 @@ async def invoke(state: WorkflowState) -> dict:
276346
)
277347

278348
compiled = agent_builder.compile()
279-
# try:
280-
# image_data = compiled.get_graph().draw_mermaid_png()
281-
# out_path = f"{graph_key_str}_langgraph.png"
282-
# with open(out_path, mode='wb') as f:
283-
# f.write(image_data)
284-
# except Exception as exc:
285-
# log.debug('LangGraphAgentGraphRunner: could not write graph PNG (%s)', exc)
286-
287349
return compiled, fn_name_to_config_key
288350

289351
async def run(self, input: Any) -> AgentGraphResult:
@@ -310,7 +372,7 @@ async def run(self, input: Any) -> AgentGraphResult:
310372

311373
result = await compiled.ainvoke( # type: ignore[call-overload]
312374
{'messages': [HumanMessage(content=str(input))]},
313-
config={'callbacks': [handler]},
375+
config={'callbacks': [handler], 'recursion_limit': 25},
314376
)
315377

316378
duration = (time.perf_counter_ns() - start_ns) // 1_000_000

packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,113 @@ def model_factory(node_config, **kwargs):
490490
path_data = ev['$ld:ai:graph:path'][0][0]
491491
assert 'root-agent' in path_data['path']
492492
assert 'child-agent' in path_data['path']
493+
494+
495+
def _make_multi_child_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition':
496+
"""Build a 3-node graph: orchestrator → agent-a, orchestrator → agent-b."""
497+
context = MagicMock()
498+
499+
def _node_tracker(key: str) -> LDAIConfigTracker:
500+
return LDAIConfigTracker(
501+
ld_client=mock_ld_client,
502+
variation_key='test-variation',
503+
config_key=key,
504+
version=1,
505+
model_name='gpt-4',
506+
provider_name='openai',
507+
context=context,
508+
)
509+
510+
graph_tracker = AIGraphTracker(
511+
ld_client=mock_ld_client,
512+
variation_key='test-variation',
513+
graph_key='multi-child-graph',
514+
version=1,
515+
context=context,
516+
)
517+
518+
configs = {
519+
'orchestrator': AIAgentConfig(
520+
key='orchestrator',
521+
enabled=True,
522+
model=ModelConfig(name='gpt-4', parameters={}),
523+
provider=ProviderConfig(name='openai'),
524+
instructions='Route to the appropriate specialist agent.',
525+
tracker=_node_tracker('orchestrator'),
526+
),
527+
'agent-a': AIAgentConfig(
528+
key='agent-a',
529+
enabled=True,
530+
model=ModelConfig(name='gpt-4', parameters={}),
531+
provider=ProviderConfig(name='openai'),
532+
instructions='You handle topic A.',
533+
tracker=_node_tracker('agent-a'),
534+
),
535+
'agent-b': AIAgentConfig(
536+
key='agent-b',
537+
enabled=True,
538+
model=ModelConfig(name='gpt-4', parameters={}),
539+
provider=ProviderConfig(name='openai'),
540+
instructions='You handle topic B.',
541+
tracker=_node_tracker('agent-b'),
542+
),
543+
}
544+
545+
edges = [
546+
Edge(key='orch-to-a', source_config='orchestrator', target_config='agent-a'),
547+
Edge(key='orch-to-b', source_config='orchestrator', target_config='agent-b'),
548+
]
549+
graph_config = AIAgentGraphConfig(
550+
key='multi-child-graph',
551+
root_config_key='orchestrator',
552+
edges=edges,
553+
enabled=True,
554+
)
555+
nodes = AgentGraphDefinition.build_nodes(graph_config, configs)
556+
return AgentGraphDefinition(
557+
agent_graph=graph_config,
558+
nodes=nodes,
559+
context=context,
560+
enabled=True,
561+
tracker=graph_tracker,
562+
)
563+
564+
565+
@pytest.mark.asyncio
566+
async def test_multi_child_routes_via_handoff_not_fan_out():
567+
"""Orchestrator with two children routes to exactly one child via handoff tool,
568+
not a fan-out that invokes both children."""
569+
from langchain_core.messages import AIMessage
570+
571+
mock_ld_client = MagicMock()
572+
graph = _make_multi_child_graph(mock_ld_client)
573+
574+
# Orchestrator calls transfer_to_agent_a (handoff tool name derived from child key)
575+
orchestrator_response = AIMessage(
576+
content='',
577+
tool_calls=[{
578+
'name': 'transfer_to_agent_a',
579+
'args': {},
580+
'id': 'call_handoff_1',
581+
'type': 'tool_call',
582+
}],
583+
)
584+
agent_a_response = _make_fake_response('Agent A handled it.')
585+
agent_b_model = _mock_model(_make_fake_response('Agent B handled it.'))
586+
587+
def model_factory(node_config, **kwargs):
588+
if node_config.key == 'orchestrator':
589+
return _mock_model(orchestrator_response)
590+
if node_config.key == 'agent-a':
591+
return _mock_model(agent_a_response)
592+
return agent_b_model
593+
594+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
595+
side_effect=model_factory):
596+
runner = LangGraphAgentGraphRunner(graph, {})
597+
result = await runner.run('hello')
598+
599+
assert result.metrics.success is True
600+
assert 'Agent A' in result.output
601+
# Agent B's model must never have been invoked — no fan-out
602+
agent_b_model.ainvoke.assert_not_called()

0 commit comments

Comments
 (0)