Skip to content

Commit 3625d7d

Browse files
fix: throw exceptions from ConcurrentToolExecutor (#1797)
Co-authored-by: Patrick Gray <pgrayy@amazon.com>
1 parent 4cd7eeb commit 3625d7d

File tree

3 files changed

+78
-27
lines changed

3 files changed

+78
-27
lines changed

src/strands/tools/executors/concurrent.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,34 +48,43 @@ async def _execute(
4848
task_events = [asyncio.Event() for _ in tool_uses]
4949
stop_event = object()
5050

51-
tasks = [
52-
asyncio.create_task(
53-
self._task(
54-
agent,
55-
tool_use,
56-
tool_results,
57-
cycle_trace,
58-
cycle_span,
59-
invocation_state,
60-
task_id,
61-
task_queue,
62-
task_events[task_id],
63-
stop_event,
64-
structured_output_context,
51+
tasks = []
52+
try:
53+
for task_id, tool_use in enumerate(tool_uses):
54+
tasks.append(
55+
asyncio.create_task(
56+
self._task(
57+
agent,
58+
tool_use,
59+
tool_results,
60+
cycle_trace,
61+
cycle_span,
62+
invocation_state,
63+
task_id,
64+
task_queue,
65+
task_events[task_id],
66+
stop_event,
67+
structured_output_context,
68+
)
69+
)
6570
)
66-
)
67-
for task_id, tool_use in enumerate(tool_uses)
68-
]
6971

70-
task_count = len(tasks)
71-
while task_count:
72-
task_id, event = await task_queue.get()
73-
if event is stop_event:
74-
task_count -= 1
75-
continue
72+
task_count = len(tasks)
73+
while task_count:
74+
task_id, event = await task_queue.get()
75+
if event is stop_event:
76+
task_count -= 1
77+
continue
7678

77-
yield event
78-
task_events[task_id].set()
79+
if isinstance(event, Exception):
80+
raise event
81+
82+
yield event
83+
task_events[task_id].set()
84+
finally:
85+
for task in tasks:
86+
task.cancel()
87+
await asyncio.gather(*tasks, return_exceptions=True)
7988

8089
async def _task(
8190
self,
@@ -115,5 +124,8 @@ async def _task(
115124
await task_event.wait()
116125
task_event.clear()
117126

127+
except Exception as e:
128+
task_queue.put_nowait((task_id, e))
129+
118130
finally:
119131
task_queue.put_nowait((task_id, stop_event))

tests/strands/tools/executors/conftest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import threading
23
import unittest.mock
34

@@ -90,13 +91,24 @@ def func(tool_context: ToolContext) -> str:
9091

9192

9293
@pytest.fixture
93-
def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool):
94+
def slow_tool():
95+
@strands.tool(name="slow_tool")
96+
async def func():
97+
"""A tool that blocks until cancelled."""
98+
await asyncio.sleep(3)
99+
100+
return func
101+
102+
103+
@pytest.fixture
104+
def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool, slow_tool):
94105
registry = ToolRegistry()
95106
registry.register_tool(weather_tool)
96107
registry.register_tool(temperature_tool)
97108
registry.register_tool(exception_tool)
98109
registry.register_tool(thread_tool)
99110
registry.register_tool(interrupt_tool)
111+
registry.register_tool(slow_tool)
100112
return registry
101113

102114

tests/strands/tools/executors/test_concurrent.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from strands.hooks import BeforeToolCallEvent
3+
from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent
44
from strands.interrupt import Interrupt
55
from strands.tools.executors import ConcurrentToolExecutor
66
from strands.tools.structured_output._structured_output_context import StructuredOutputContext
@@ -76,3 +76,30 @@ def interrupt_callback(event):
7676
tru_results = tool_results
7777
exp_results = [exp_events[1].tool_result]
7878
assert tru_results == exp_results
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_concurrent_executor_reraises_exceptions(
83+
executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist
84+
):
85+
"""Test that hook re-raised exceptions propagate and cancel remaining tasks."""
86+
87+
def reraise_callback(event):
88+
if event.exception is not None:
89+
raise event.exception
90+
91+
agent.hooks.add_callback(AfterToolCallEvent, reraise_callback)
92+
93+
tool_uses = [
94+
{"name": "exception_tool", "toolUseId": "1", "input": {}},
95+
{"name": "slow_tool", "toolUseId": "2", "input": {}},
96+
]
97+
98+
stream = executor._execute(
99+
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context
100+
)
101+
102+
with pytest.raises(RuntimeError, match="Tool error"):
103+
await alist(stream)
104+
105+
assert tool_results == []

0 commit comments

Comments
 (0)