Skip to content

Commit 287c5b6

Browse files
mattdai01 and daimatth authored
fix: propagate tool exceptions to spans so StatusCode.ERROR is set correctly (#2046)
Co-authored-by: Matthew Dai <daimatth@amazon.com>
1 parent a19e73d commit 287c5b6

3 files changed

Lines changed: 98 additions & 7 deletions

File tree

src/strands/tools/executors/_executor.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -171,9 +171,15 @@ async def _stream(
171171
}
172172

173173
after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
174-
agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message
174+
agent,
175+
None,
176+
tool_use,
177+
invocation_state,
178+
cancel_result,
179+
exception=Exception(cancel_message),
180+
cancel_message=cancel_message,
175181
)
176-
yield ToolResultEvent(after_event.result)
182+
yield ToolResultEvent(after_event.result, exception=after_event.exception)
177183
tool_results.append(after_event.result)
178184
return
179185

@@ -202,15 +208,16 @@ async def _stream(
202208
"content": [{"text": f"Unknown tool: {tool_name}"}],
203209
}
204210

211+
unknown_tool_error = Exception(f"Unknown tool: {tool_name}")
205212
after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
206-
agent, selected_tool, tool_use, invocation_state, result
213+
agent, selected_tool, tool_use, invocation_state, result, exception=unknown_tool_error
207214
)
208215
# Check if retry requested for unknown tool error
209216
# Use getattr because BidiAfterToolCallEvent doesn't have retry attribute
210217
if getattr(after_event, "retry", False):
211218
logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name)
212219
continue
213-
yield ToolResultEvent(after_event.result)
220+
yield ToolResultEvent(after_event.result, exception=after_event.exception)
214221
tool_results.append(after_event.result)
215222
return
216223
if structured_output_context.is_enabled:
@@ -258,7 +265,7 @@ async def _stream(
258265
logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name)
259266
continue
260267

261-
yield ToolResultEvent(after_event.result)
268+
yield ToolResultEvent(after_event.result, exception=after_event.exception)
262269
tool_results.append(after_event.result)
263270
return
264271

@@ -277,7 +284,7 @@ async def _stream(
277284
if getattr(after_event, "retry", False):
278285
logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name)
279286
continue
280-
yield ToolResultEvent(after_event.result)
287+
yield ToolResultEvent(after_event.result, exception=after_event.exception)
281288
tool_results.append(after_event.result)
282289
return
283290

@@ -338,7 +345,7 @@ async def _stream_with_trace(
338345
agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
339346
cycle_trace.add_child(tool_trace)
340347

341-
tracer.end_tool_call_span(tool_call_span, result)
348+
tracer.end_tool_call_span(tool_call_span, result, error=result_event.exception)
342349

343350
@abc.abstractmethod
344351
# pragma: no cover

tests/strands/telemetry/test_tracer.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -707,6 +707,20 @@ def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch):
707707
mock_span.end.assert_called_once()
708708

709709

710+
def test_end_tool_call_span_with_error(mock_span):
711+
"""Test ending a tool call span with an explicit error sets StatusCode.ERROR."""
712+
tracer = Tracer()
713+
error = ValueError("tool exploded")
714+
tool_result = {"status": "error", "content": [{"text": "Error: tool exploded"}]}
715+
716+
tracer.end_tool_call_span(mock_span, tool_result, error=error)
717+
718+
mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "error"})
719+
mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "tool exploded")
720+
mock_span.record_exception.assert_called_once_with(error)
721+
mock_span.end.assert_called_once()
722+
723+
710724
def test_start_event_loop_cycle_span(mock_tracer):
711725
"""Test starting an event loop cycle span."""
712726
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):

tests/strands/tools/executors/test_executor.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -189,6 +189,7 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results
189189
tool_use=tool_use,
190190
invocation_state=invocation_state,
191191
result=exp_results[0],
192+
exception=unittest.mock.ANY,
192193
)
193194
assert tru_hook_after_event == exp_hook_after_event
194195

@@ -216,6 +217,7 @@ async def test_executor_stream_with_trace(
216217
tracer.end_tool_call_span.assert_called_once_with(
217218
tracer.start_tool_call_span.return_value,
218219
{"content": [{"text": "sunny"}], "status": "success", "toolUseId": "1"},
220+
error=None,
219221
)
220222

221223
cycle_trace.add_child.assert_called_once()
@@ -901,3 +903,71 @@ def retry_once_on_unknown(event):
901903
assert len(tru_events) == 1
902904
assert tru_events[0].tool_result["status"] == "error"
903905
assert "Unknown tool" in tru_events[0].tool_result["content"][0]["text"]
906+
907+
908+
@pytest.mark.asyncio
909+
async def test_executor_stream_with_trace_error(
910+
executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
911+
):
912+
"""Test that _stream_with_trace passes the exception to end_tool_call_span when a tool fails."""
913+
tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "1", "input": {}}
914+
stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state)
915+
916+
await alist(stream)
917+
918+
tracer.end_tool_call_span.assert_called_once()
919+
call_args = tracer.end_tool_call_span.call_args
920+
assert call_args[0][1]["status"] == "error"
921+
error_arg = call_args[1].get("error")
922+
assert error_arg is not None
923+
assert isinstance(error_arg, RuntimeError)
924+
assert "Tool error" in str(error_arg)
925+
926+
927+
@pytest.mark.asyncio
928+
async def test_executor_stream_error_preserves_exception(executor, agent, tool_results, invocation_state, alist):
929+
"""Test that _stream yields a ToolResultEvent with the exception preserved."""
930+
tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "1", "input": {}}
931+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
932+
933+
events = await alist(stream)
934+
result_event = events[-1]
935+
assert isinstance(result_event, ToolResultEvent)
936+
assert result_event.tool_result["status"] == "error"
937+
assert result_event.exception is not None
938+
assert isinstance(result_event.exception, RuntimeError)
939+
assert "Tool error" in str(result_event.exception)
940+
941+
942+
@pytest.mark.asyncio
943+
async def test_executor_stream_unknown_tool_has_exception(executor, agent, tool_results, invocation_state, alist):
944+
"""Test that _stream yields a ToolResultEvent with exception for unknown tools."""
945+
tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "1", "input": {}}
946+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
947+
948+
events = await alist(stream)
949+
result_event = events[-1]
950+
assert isinstance(result_event, ToolResultEvent)
951+
assert result_event.tool_result["status"] == "error"
952+
assert result_event.exception is not None
953+
assert "Unknown tool" in str(result_event.exception)
954+
955+
956+
@pytest.mark.asyncio
957+
async def test_executor_stream_cancel_has_exception(executor, agent, tool_results, invocation_state, alist):
958+
"""Test that _stream yields a ToolResultEvent with exception for cancelled tools."""
959+
960+
def cancel_callback(event):
961+
event.cancel_tool = True
962+
return event
963+
964+
agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback)
965+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
966+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
967+
968+
events = await alist(stream)
969+
result_event = events[-1]
970+
assert isinstance(result_event, ToolResultEvent)
971+
assert result_event.tool_result["status"] == "error"
972+
assert result_event.exception is not None
973+
assert "cancelled" in str(result_event.exception)

0 commit comments

Comments (0)