diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..9cdf071fd 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -841,7 +841,7 @@ async def stream_async( self._end_agent_trace_span(response=result) - except Exception as e: + except BaseException as e: self._end_agent_trace_span(error=e) raise @@ -1044,7 +1044,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b4af16058..835122b4e 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -138,6 +138,7 @@ async def event_loop_cycle( custom_trace_attributes=agent.trace_attributes, ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(cycle_span, end_on_exit=False): try: @@ -153,15 +154,21 @@ async def event_loop_cycle( model_events = _handle_model_execution( agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + try: + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + finally: + await model_events.aclose() stop_reason, message, *_ = model_event["stop"] yield ModelMessageEvent(message=message) except Exception as e: tracer.end_span_with_error(cycle_span, str(e), e) raise + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -238,6 +245,9 @@ async def event_loop_cycle( yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise async def recurse_event_loop( @@ -323,6 +333,7 @@ async def _handle_model_execution( system_prompt=agent.system_prompt, system_prompt_content=agent._system_prompt_content, ) + streamed_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(model_invoke_span, end_on_exit=False): try: await agent.hooks.invoke_callbacks_async( @@ -338,7 +349,7 @@ async def _handle_model_execution( else: tool_specs = agent.tool_registry.get_all_tool_specs() - async for event in stream_messages( + streamed_events = stream_messages( agent.model, agent.system_prompt, agent.messages, @@ -348,8 +359,12 @@ async def _handle_model_execution( invocation_state=invocation_state, model_state=agent._model_state, cancel_signal=agent._cancel_signal, - ): - yield event + ) + try: + async for event in streamed_events: + yield event + finally: + await streamed_events.aclose() stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -410,6 +425,9 @@ async def _handle_model_execution( # No retry requested, raise the exception yield ForceStopEvent(reason=e) raise e + except BaseException as e: + tracer.end_span_with_error(model_invoke_span, str(e), e) + raise try: # Add message in trace and mark the end of the stream messages trace diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 19a163f5c..1ff968558 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -184,7 +184,7 @@ def _end_span( self, span: Span, attributes: dict[str, AttributeValue] | None = None, - error: Exception | None = None, + error: BaseException | None = None, error_message: str | None = None, ) -> None: """Generic helper method to end a span. @@ -224,7 +224,7 @@ def _end_span( except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None: """End a span with error status. Args: @@ -450,7 +450,9 @@ def start_tool_call_span( return span - def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: + def end_tool_call_span( + self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None + ) -> None: """End a tool call span with results. Args: @@ -650,7 +652,7 @@ def end_agent_span( self, span: Span, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """End an agent span with results and metrics. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a3cce11c..337b269af 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1423,6 +1423,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist): + """Test that stream_async ends the agent span when a BaseException occurs.""" + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + test_exception = KeyboardInterrupt("stop now") + mock_model.mock_stream.side_effect = test_exception + + agent = Agent(model=mock_model) + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = agent.stream_async("test prompt") + await alist(stream) + + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + def test_agent_init_with_state_object(): agent = Agent(state=AgentState({"foo": "bar"})) assert agent.state.get("foo") == "bar" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index f91f7c2af..5903651f4 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -685,6 +685,121 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_stream_aclose( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(10) + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await anext(stream) + await anext(stream) + await anext(stream) + await stream.aclose() + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_task_cancellation( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + blocked_on_stream = asyncio.Event() + release_stream = asyncio.Event() + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + blocked_on_stream.set() + await release_stream.wait() + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + async def consume() -> None: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + async for _ in stream: + pass + + task = asyncio.create_task(consume()) + await blocked_on_stream.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt( + mock_get_tracer, + agent, + model, + mock_tracer, + alist, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + test_exception = KeyboardInterrupt("stop now") + model.stream.side_effect = test_exception + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + assert mock_tracer.end_span_with_error.call_args_list == [ + call(model_span, "stop now", test_exception), + call(cycle_span, "stop now", test_exception), + ] + + @pytest.mark.asyncio async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( agent, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index bcd42b610..f1f26b835 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -140,6 +140,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span): + """Test that empty BaseException messages fall back to the exception type name.""" + tracer = Tracer() + error = KeyboardInterrupt() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_span_with_error_prefers_explicit_message(mock_span): """Test that an explicit error message takes precedence over the exception text.""" tracer = Tracer() @@ -1162,6 +1174,30 @@ def test_force_flush_with_error(mock_span, mock_get_tracer_provider): mock_tracer_provider.force_flush.assert_called_once() +def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that agent spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_agent_span(mock_span, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that tool call spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_tool_call_span(mock_span, None, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer()