Skip to content

Commit d7aa7b5

Browse files
committed
fix(workflow): Prevent incorrect chat agent wiring in graphs
Add validation to raise a ValueError at graph compilation time when a chat-mode agent has an incoming edge from a non-START node. This prevents silent dropping of node inputs. Fixes #5868 Change-Id: I575a7594a8912b4ea6eb4c1ed539c6c7f16eadd1
1 parent bc45ee6 commit d7aa7b5

3 files changed

Lines changed: 66 additions & 0 deletions

File tree

src/google/adk/workflow/_graph.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,26 @@ def _validate_static_schemas(self) -> None:
517517
f" input schema {to_node.input_schema}."
518518
)
519519

520+
def _validate_chat_agent_wiring(self) -> None:
521+
"""Validates that chat-mode agents do not have incoming edges from non-START nodes."""
522+
from ..agents.llm_agent import LlmAgent
523+
524+
for edge in self.edges:
525+
to_node = edge.to_node
526+
if (
527+
isinstance(to_node, LlmAgent)
528+
and getattr(to_node, "mode", None) == "chat"
529+
):
530+
if edge.from_node.name != START.name:
531+
raise ValueError(
532+
f"The agent '{to_node.name}' has been added to the workflow with"
533+
f" mode='chat' following node '{edge.from_node.name}'. This is"
534+
" not supported because chat-mode agents rely on conversational"
535+
" history (session events) and cannot consume direct node inputs"
536+
" from preceding nodes. Please change the agent's mode to"
537+
" 'single_turn'"
538+
)
539+
520540
def _compute_terminal_nodes(self) -> None:
521541
"""Computes terminal nodes (no outgoing edges)."""
522542
from_names = {edge.from_node.name for edge in self.edges}
@@ -535,4 +555,5 @@ def validate_graph(self) -> None:
535555
self._validate_default_routes()
536556
self._detect_unconditional_cycles(node_names)
537557
self._validate_static_schemas()
558+
self._validate_chat_agent_wiring()
538559
self._compute_terminal_nodes()

tests/unittests/workflow/test_graph.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,19 @@ def test_get_next_pending_nodes() -> None:
775775
'NodeA', routes_to_match=['route1', 'unknown_route']
776776
)
777777
assert set(next_nodes) == {'NodeB', 'NodeC'}
778+
779+
780+
def test_chat_agent_wiring_validation_only_runs_on_llm_agent() -> None:
781+
"""Tests that _validate_chat_agent_wiring checks non-LlmAgent nodes safely."""
782+
node_a = TestingNode(name='NodeA')
783+
node_b = TestingNode(name='NodeB')
784+
# Set mode='chat' on a non-LlmAgent node
785+
object.__setattr__(node_b, 'mode', 'chat')
786+
787+
graph = Graph(
788+
edges=[
789+
Edge(from_node=START, to_node=node_a),
790+
Edge(from_node=node_a, to_node=node_b),
791+
],
792+
)
793+
graph.validate_graph() # Should not raise because node_b is a TestingNode, not LlmAgent

tests/unittests/workflow/test_llm_agent_as_node.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,35 @@ async def mock_run_async(*args, **kwargs):
10931093
assert events[0].output is None
10941094

10951095

1096+
def test_chat_mode_agent_following_non_start_raises_validation_error():
1097+
"""Wiring a chat-mode agent following a non-START node raises ValueError."""
1098+
agent = _make_v1_agent(mode='chat')
1099+
predecessor = TestingNode(name='pred', output='some output')
1100+
1101+
with pytest.raises(ValueError) as exc_info:
1102+
Workflow(
1103+
name='wf',
1104+
edges=[('START', predecessor), (predecessor, agent)],
1105+
)
1106+
1107+
assert (
1108+
"The agent 'test_v1_agent' has been added to the workflow with"
1109+
" mode='chat' following node 'pred'."
1110+
in str(exc_info.value)
1111+
)
1112+
1113+
1114+
def test_chat_mode_agent_from_start_allowed():
1115+
"""Wiring a chat-mode agent directly from START is allowed and validated without error."""
1116+
agent = _make_v1_agent(mode='chat')
1117+
1118+
wf = Workflow(
1119+
name='wf',
1120+
edges=[('START', agent)],
1121+
)
1122+
assert wf.graph is not None
1123+
1124+
10961125
@pytest.mark.asyncio
10971126
async def test_three_layer_llm_agent_transfer_round_trip(
10981127
request: pytest.FixtureRequest,

0 commit comments

Comments
 (0)