Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ async def stream_async(

self._end_agent_trace_span(response=result)

except Exception as e:
except BaseException as e:
self._end_agent_trace_span(error=e)
raise

Expand Down Expand Up @@ -1044,7 +1044,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
def _end_agent_trace_span(
self,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""Ends a trace span for the agent.

Expand Down
48 changes: 33 additions & 15 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def event_loop_cycle(
custom_trace_attributes=agent.trace_attributes,
)
invocation_state["event_loop_cycle_span"] = cycle_span
model_events: AsyncGenerator[TypedEvent, None] | None = None

with trace_api.use_span(cycle_span, end_on_exit=False):
try:
Expand All @@ -153,15 +154,21 @@ async def event_loop_cycle(
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
try:
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
finally:
await model_events.aclose()

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: The except Exception and except BaseException handlers here have identical bodies — both call tracer.end_span_with_error and re-raise. Since Exception is a subclass of BaseException, the except BaseException handler alone would catch both.

Suggestion: Consolidate into a single handler:

except BaseException as e:
    tracer.end_span_with_error(cycle_span, str(e), e)
    raise

Note: The other two locations (lines 233-250 and 398-430) are correctly separated since the except Exception handlers do additional work (wrapping in EventLoopException, retry logic, etc.) that should not apply to BaseException subclasses.


try:
if stop_reason == "max_tokens":
Expand Down Expand Up @@ -238,6 +245,9 @@ async def event_loop_cycle(
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise


async def recurse_event_loop(
Expand Down Expand Up @@ -323,6 +333,7 @@ async def _handle_model_execution(
system_prompt=agent.system_prompt,
system_prompt_content=agent._system_prompt_content,
)
streamed_events: AsyncGenerator[TypedEvent, None] | None = None
with trace_api.use_span(model_invoke_span, end_on_exit=False):
try:
await agent.hooks.invoke_callbacks_async(
Expand All @@ -338,18 +349,22 @@ async def _handle_model_execution(
else:
tool_specs = agent.tool_registry.get_all_tool_specs()

async for event in stream_messages(
agent.model,
agent.system_prompt,
agent.messages,
tool_specs,
system_prompt_content=agent._system_prompt_content,
tool_choice=structured_output_context.tool_choice,
invocation_state=invocation_state,
model_state=agent._model_state,
cancel_signal=agent._cancel_signal,
):
yield event
try:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would it be better to move try on top of streamed_events = stream_messages(...)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed it.

streamed_events = stream_messages(
agent.model,
agent.system_prompt,
agent.messages,
tool_specs,
system_prompt_content=agent._system_prompt_content,
tool_choice=structured_output_context.tool_choice,
invocation_state=invocation_state,
model_state=agent._model_state,
cancel_signal=agent._cancel_signal,
)
async for event in streamed_events:
yield event
finally:
await streamed_events.aclose()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: streamed_events is initialized to None (line 336), but the finally block unconditionally calls await streamed_events.aclose(). If stream_messages() raises before returning the generator object (e.g., due to a future signature change or dynamic error), this would raise AttributeError: 'NoneType' object has no attribute 'aclose' — masking the original exception.

Currently this is safe because stream_messages is an async generator function (calling it just creates the object without executing any code), but the code is fragile.

Suggestion: Add a guard:

finally:
    if streamed_events is not None:
        await streamed_events.aclose()

The same pattern at line 162 for model_events has the same concern, though it's slightly safer because model_events is assigned before the try block.


stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
Expand Down Expand Up @@ -410,6 +425,9 @@ async def _handle_model_execution(
# No retry requested, raise the exception
yield ForceStopEvent(reason=e)
raise e
except BaseException as e:
tracer.end_span_with_error(model_invoke_span, str(e), e)
raise

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down
10 changes: 6 additions & 4 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _end_span(
self,
span: Span,
attributes: dict[str, AttributeValue] | None = None,
error: Exception | None = None,
error: BaseException | None = None,
error_message: str | None = None,
) -> None:
"""Generic helper method to end a span.
Expand Down Expand Up @@ -224,7 +224,7 @@ def _end_span(
except Exception as e:
logger.warning("error=<%s> | failed to force flush tracer provider", e)

def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None:
def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None:
"""End a span with error status.

Args:
Expand Down Expand Up @@ -450,7 +450,9 @@ def start_tool_call_span(

return span

def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None:
def end_tool_call_span(
self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None
) -> None:
"""End a tool call span with results.

Args:
Expand Down Expand Up @@ -650,7 +652,7 @@ def end_agent_span(
self,
span: Span,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""End an agent span with results and metrics.

Expand Down
21 changes: 21 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


@pytest.mark.asyncio
@unittest.mock.patch("strands.agent.agent.get_tracer")
async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist):
    """Check that stream_async ends the agent span when a BaseException escapes.

    KeyboardInterrupt derives from BaseException rather than Exception, so it
    exercises the non-``Exception`` error path of the agent span handling.
    """
    agent_span = unittest.mock.MagicMock()
    tracer_mock = unittest.mock.MagicMock()
    tracer_mock.start_agent_span.return_value = agent_span
    mock_get_tracer.return_value = tracer_mock

    # The mocked model raises the interrupt as soon as it is streamed from.
    interrupt = KeyboardInterrupt("stop now")
    mock_model.mock_stream.side_effect = interrupt

    agent = Agent(model=mock_model)

    with pytest.raises(KeyboardInterrupt, match="stop now"):
        await alist(agent.stream_async("test prompt"))

    # The span must be closed exactly once, carrying the very same exception object.
    tracer_mock.end_agent_span.assert_called_once_with(span=agent_span, error=interrupt)


def test_agent_init_with_state_object():
agent = Agent(state=AgentState({"foo": "bar"}))
assert agent.state.get("foo") == "bar"
Expand Down
115 changes: 115 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,121 @@ async def test_event_loop_tracing_with_tool_execution(
assert mock_tracer.end_model_invoke_span.call_count == 2


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_stream_aclose(
    mock_get_tracer,
    agent,
    model,
    mock_tracer,
):
    """Check that closing the cycle generator early ends both spans with an error.

    The consumer pulls three events and then calls ``aclose()``; the model
    invoke span is expected to be ended first, followed by the cycle span.
    """
    mock_get_tracer.return_value = mock_tracer
    model_span = MagicMock()
    cycle_span = MagicMock()
    mock_tracer.start_model_invoke_span.return_value = model_span
    mock_tracer.start_event_loop_cycle_span.return_value = cycle_span

    async def interrupted_stream():
        yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
        await asyncio.sleep(10)
        yield {"contentBlockStop": {}}

    model.stream.return_value = interrupted_stream()

    stream = strands.event_loop.event_loop.event_loop_cycle(
        agent=agent,
        invocation_state={},
    )
    # Consume exactly three events before abandoning the generator.
    for _ in range(3):
        await anext(stream)
    await stream.aclose()

    error_calls = mock_tracer.end_span_with_error.call_args_list
    assert [recorded.args[0] for recorded in error_calls] == [model_span, cycle_span]
    # Empty messages: presumably str() of the exception raised by aclose() is "" -- confirm.
    assert [recorded.args[1] for recorded in error_calls] == [
        "",
        "",
    ]


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_task_cancellation(
    mock_get_tracer,
    agent,
    model,
    mock_tracer,
):
    """Check that cancelling the consuming task ends both spans with an error.

    The model stream blocks on an event that is never set, so the task is
    guaranteed to be suspended inside the stream when it is cancelled.
    """
    mock_get_tracer.return_value = mock_tracer
    model_span = MagicMock()
    cycle_span = MagicMock()
    mock_tracer.start_model_invoke_span.return_value = model_span
    mock_tracer.start_event_loop_cycle_span.return_value = cycle_span

    stream_is_blocked = asyncio.Event()
    unblock_stream = asyncio.Event()  # never set; keeps the stream suspended

    async def interrupted_stream():
        yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
        stream_is_blocked.set()
        await unblock_stream.wait()
        yield {"contentBlockStop": {}}

    model.stream.return_value = interrupted_stream()

    async def consume() -> None:
        cycle = strands.event_loop.event_loop.event_loop_cycle(
            agent=agent,
            invocation_state={},
        )
        async for _ in cycle:
            pass

    consumer = asyncio.create_task(consume())
    # Only cancel once the stream has signalled that it is about to block.
    await stream_is_blocked.wait()
    consumer.cancel()

    with pytest.raises(asyncio.CancelledError):
        await consumer

    error_calls = mock_tracer.end_span_with_error.call_args_list
    assert [recorded.args[0] for recorded in error_calls] == [model_span, cycle_span]
    # Empty messages: presumably str() of the CancelledError is "" -- confirm.
    assert [recorded.args[1] for recorded in error_calls] == [
        "",
        "",
    ]


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt(
    mock_get_tracer,
    agent,
    model,
    mock_tracer,
    alist,
):
    """Check that a KeyboardInterrupt from the model ends both spans with that error.

    Unlike the cancellation cases, the interrupt carries a message, so the
    recorded error text is expected to be "stop now" for both spans.
    """
    mock_get_tracer.return_value = mock_tracer
    model_span = MagicMock()
    cycle_span = MagicMock()
    mock_tracer.start_model_invoke_span.return_value = model_span
    mock_tracer.start_event_loop_cycle_span.return_value = cycle_span

    interrupt = KeyboardInterrupt("stop now")
    model.stream.side_effect = interrupt

    with pytest.raises(KeyboardInterrupt, match="stop now"):
        await alist(
            strands.event_loop.event_loop.event_loop_cycle(
                agent=agent,
                invocation_state={},
            )
        )

    # Model span is ended first, then the enclosing cycle span.
    expected_calls = [
        call(model_span, "stop now", interrupt),
        call(cycle_span, "stop now", interrupt),
    ]
    assert mock_tracer.end_span_with_error.call_args_list == expected_calls


@pytest.mark.asyncio
async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle(
agent,
Expand Down
36 changes: 36 additions & 0 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span):
mock_span.end.assert_called_once()


def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span):
    """Check that a message-less BaseException is reported by its type name."""
    interrupt = KeyboardInterrupt()

    Tracer().end_span_with_error(mock_span, "", interrupt)

    # With no explicit message and an empty str(exception), the class name is used.
    mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt")
    mock_span.record_exception.assert_called_once_with(interrupt)
    mock_span.end.assert_called_once()


def test_end_span_with_error_prefers_explicit_message(mock_span):
"""Test that an explicit error message takes precedence over the exception text."""
tracer = Tracer()
Expand Down Expand Up @@ -1162,6 +1174,30 @@ def test_force_flush_with_error(mock_span, mock_get_tracer_provider):
mock_tracer_provider.force_flush.assert_called_once()


def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span):
    """Check that an agent span reports a message-less error by its type name."""
    failure = Exception()

    Tracer().end_agent_span(mock_span, error=failure)

    # str(Exception()) is empty, so the status description is the class name.
    mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
    mock_span.record_exception.assert_called_once_with(failure)
    mock_span.end.assert_called_once()


def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span):
    """Check that a tool call span reports a message-less error by its type name."""
    failure = Exception()

    # No tool result is supplied; only the error is passed.
    Tracer().end_tool_call_span(mock_span, None, error=failure)

    mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
    mock_span.record_exception.assert_called_once_with(failure)
    mock_span.end.assert_called_once()


def test_end_tool_call_span_with_none(mock_span):
"""Test ending a tool call span with None result."""
tracer = Tracer()
Expand Down