Skip to content

Commit 49985c9

Browse files
sasha-gitgcopybara-github
authored andcommitted
fix: cancel siblings in parallel function calling on failure
When a tool execution fails in a parallel batch, we now cancel all other pending tools in that batch and propagate the exception, instead of letting them run as orphaned tasks. Co-authored-by: Sasha Sobran <asobran@google.com> PiperOrigin-RevId: 905286925
1 parent 4073238 commit 49985c9

2 files changed

Lines changed: 102 additions & 2 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,14 @@ async def handle_function_call_list_async(
401401
]
402402

403403
# Wait for all tasks to complete
404-
function_response_events = await asyncio.gather(*tasks)
404+
try:
405+
function_response_events = await asyncio.gather(*tasks)
406+
except Exception:
407+
for t in tasks:
408+
if not t.done():
409+
t.cancel()
410+
await asyncio.gather(*tasks, return_exceptions=True)
411+
raise
405412

406413
# Filter out None results
407414
function_response_events = [
@@ -624,7 +631,14 @@ async def handle_function_calls_live(
624631
]
625632

626633
# Wait for all tasks to complete
627-
function_response_events = await asyncio.gather(*tasks)
634+
try:
635+
function_response_events = await asyncio.gather(*tasks)
636+
except Exception:
637+
for t in tasks:
638+
if not t.done():
639+
t.cancel()
640+
await asyncio.gather(*tasks, return_exceptions=True)
641+
raise
628642

629643
# Filter out None results
630644
function_response_events = [
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
from typing import Any
17+
18+
from google.adk.agents.llm_agent import Agent
19+
from google.adk.flows.llm_flows import functions
20+
from google.adk.tools.tool_context import ToolContext
21+
from google.genai import types
22+
import pytest
23+
24+
from ... import testing_utils
25+
26+
27+
def function_call(function_call_id, name, args: dict[str, Any]) -> types.Part:
28+
part = types.Part.from_function_call(name=name, args=args)
29+
part.function_call.id = function_call_id
30+
return part
31+
32+
33+
@pytest.mark.asyncio
34+
async def test_parallel_function_call_error_fail_fast():
35+
id_1 = 'id_1'
36+
id_2 = 'id_2'
37+
responses = [
38+
[
39+
function_call(id_1, 'fail_tool', {}),
40+
function_call(id_2, 'sleep_tool', {}),
41+
],
42+
[
43+
types.Part.from_text(text='final response'),
44+
],
45+
]
46+
47+
mock_model = testing_utils.MockModel.create(responses=responses)
48+
49+
fail_called = False
50+
sleep_started = False
51+
sleep_completed = False
52+
sleep_cancelled = False
53+
54+
async def fail_tool(tool_context: ToolContext) -> str:
55+
nonlocal fail_called
56+
fail_called = True
57+
raise ValueError('Tool failed intentionally')
58+
59+
async def sleep_tool(tool_context: ToolContext) -> str:
60+
nonlocal sleep_started, sleep_completed, sleep_cancelled
61+
sleep_started = True
62+
try:
63+
await asyncio.sleep(10) # Sleep long enough to be cancelled
64+
sleep_completed = True
65+
return 'Tool succeeded'
66+
except asyncio.CancelledError:
67+
sleep_cancelled = True
68+
raise
69+
70+
agent = Agent(
71+
name='root_agent',
72+
model=mock_model,
73+
tools=[fail_tool, sleep_tool],
74+
)
75+
76+
runner = testing_utils.InMemoryRunner(agent)
77+
78+
with pytest.raises(ValueError, match='Tool failed intentionally'):
79+
await runner.run_async(
80+
new_message=types.Content(parts=[types.Part(text='test')]),
81+
)
82+
83+
assert fail_called
84+
assert sleep_started
85+
assert not sleep_completed
86+
assert sleep_cancelled

0 commit comments

Comments
 (0)