Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ async def stream_async(

self._end_agent_trace_span(response=result)

except Exception as e:
except BaseException as e:
self._end_agent_trace_span(error=e)
raise

Expand Down Expand Up @@ -1044,7 +1044,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
def _end_agent_trace_span(
self,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""Ends a trace span for the agent.

Expand Down
30 changes: 24 additions & 6 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def event_loop_cycle(
custom_trace_attributes=agent.trace_attributes,
)
invocation_state["event_loop_cycle_span"] = cycle_span
model_events: AsyncGenerator[TypedEvent, None] | None = None

with trace_api.use_span(cycle_span, end_on_exit=False):
try:
Expand All @@ -153,15 +154,21 @@ async def event_loop_cycle(
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
try:
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
finally:
await model_events.aclose()

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise

try:
if stop_reason == "max_tokens":
Expand Down Expand Up @@ -238,6 +245,9 @@ async def event_loop_cycle(
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise


async def recurse_event_loop(
Expand Down Expand Up @@ -323,6 +333,7 @@ async def _handle_model_execution(
system_prompt=agent.system_prompt,
system_prompt_content=agent._system_prompt_content,
)
streamed_events: AsyncGenerator[TypedEvent, None] | None = None
with trace_api.use_span(model_invoke_span, end_on_exit=False):
try:
await agent.hooks.invoke_callbacks_async(
Expand All @@ -338,7 +349,7 @@ async def _handle_model_execution(
else:
tool_specs = agent.tool_registry.get_all_tool_specs()

async for event in stream_messages(
streamed_events = stream_messages(
agent.model,
agent.system_prompt,
agent.messages,
Expand All @@ -348,8 +359,12 @@ async def _handle_model_execution(
invocation_state=invocation_state,
model_state=agent._model_state,
cancel_signal=agent._cancel_signal,
):
yield event
)
try:
async for event in streamed_events:
yield event
finally:
await streamed_events.aclose()

stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
Expand Down Expand Up @@ -410,6 +425,9 @@ async def _handle_model_execution(
# No retry requested, raise the exception
yield ForceStopEvent(reason=e)
raise e
except BaseException as e:
tracer.end_span_with_error(model_invoke_span, str(e), e)
raise

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down
10 changes: 6 additions & 4 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _end_span(
self,
span: Span,
attributes: dict[str, AttributeValue] | None = None,
error: Exception | None = None,
error: BaseException | None = None,
error_message: str | None = None,
) -> None:
"""Generic helper method to end a span.
Expand Down Expand Up @@ -224,7 +224,7 @@ def _end_span(
except Exception as e:
logger.warning("error=<%s> | failed to force flush tracer provider", e)

def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None:
def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None:
"""End a span with error status.

Args:
Expand Down Expand Up @@ -450,7 +450,9 @@ def start_tool_call_span(

return span

def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None:
def end_tool_call_span(
self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None
) -> None:
"""End a tool call span with results.

Args:
Expand Down Expand Up @@ -650,7 +652,7 @@ def end_agent_span(
self,
span: Span,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""End an agent span with results and metrics.

Expand Down
21 changes: 21 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


@pytest.mark.asyncio
@unittest.mock.patch("strands.agent.agent.get_tracer")
async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist):
"""Test that stream_async ends the agent span when a BaseException occurs."""
mock_tracer = unittest.mock.MagicMock()
mock_span = unittest.mock.MagicMock()
mock_tracer.start_agent_span.return_value = mock_span
mock_get_tracer.return_value = mock_tracer

test_exception = KeyboardInterrupt("stop now")
mock_model.mock_stream.side_effect = test_exception

agent = Agent(model=mock_model)

with pytest.raises(KeyboardInterrupt, match="stop now"):
stream = agent.stream_async("test prompt")
await alist(stream)

mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


def test_agent_init_with_state_object():
agent = Agent(state=AgentState({"foo": "bar"}))
assert agent.state.get("foo") == "bar"
Expand Down
115 changes: 115 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,121 @@ async def test_event_loop_tracing_with_tool_execution(
assert mock_tracer.end_model_invoke_span.call_count == 2


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_stream_aclose(
mock_get_tracer,
agent,
model,
mock_tracer,
):
mock_get_tracer.return_value = mock_tracer
cycle_span = MagicMock()
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
model_span = MagicMock()
mock_tracer.start_model_invoke_span.return_value = model_span

async def interrupted_stream():
yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
await asyncio.sleep(10)
yield {"contentBlockStop": {}}

model.stream.return_value = interrupted_stream()

stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
await anext(stream)
await anext(stream)
await anext(stream)
await stream.aclose()

assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span]
assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [
"",
"",
]


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_task_cancellation(
mock_get_tracer,
agent,
model,
mock_tracer,
):
mock_get_tracer.return_value = mock_tracer
cycle_span = MagicMock()
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
model_span = MagicMock()
mock_tracer.start_model_invoke_span.return_value = model_span

blocked_on_stream = asyncio.Event()
release_stream = asyncio.Event()

async def interrupted_stream():
yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
blocked_on_stream.set()
await release_stream.wait()
yield {"contentBlockStop": {}}

model.stream.return_value = interrupted_stream()

async def consume() -> None:
stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
async for _ in stream:
pass

task = asyncio.create_task(consume())
await blocked_on_stream.wait()
task.cancel()

with pytest.raises(asyncio.CancelledError):
await task

assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span]
assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [
"",
"",
]


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt(
mock_get_tracer,
agent,
model,
mock_tracer,
alist,
):
mock_get_tracer.return_value = mock_tracer
cycle_span = MagicMock()
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
model_span = MagicMock()
mock_tracer.start_model_invoke_span.return_value = model_span

test_exception = KeyboardInterrupt("stop now")
model.stream.side_effect = test_exception

with pytest.raises(KeyboardInterrupt, match="stop now"):
stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
await alist(stream)

assert mock_tracer.end_span_with_error.call_args_list == [
call(model_span, "stop now", test_exception),
call(cycle_span, "stop now", test_exception),
]


@pytest.mark.asyncio
async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle(
agent,
Expand Down
36 changes: 36 additions & 0 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span):
mock_span.end.assert_called_once()


def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span):
"""Test that empty BaseException messages fall back to the exception type name."""
tracer = Tracer()
error = KeyboardInterrupt()

tracer.end_span_with_error(mock_span, "", error)

mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt")
mock_span.record_exception.assert_called_once_with(error)
mock_span.end.assert_called_once()


def test_end_span_with_error_prefers_explicit_message(mock_span):
"""Test that an explicit error message takes precedence over the exception text."""
tracer = Tracer()
Expand Down Expand Up @@ -1162,6 +1174,30 @@ def test_force_flush_with_error(mock_span, mock_get_tracer_provider):
mock_tracer_provider.force_flush.assert_called_once()


def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span):
"""Test that agent spans fall back to the exception type name for empty errors."""
tracer = Tracer()
error = Exception()

tracer.end_agent_span(mock_span, error=error)

mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
mock_span.record_exception.assert_called_once_with(error)
mock_span.end.assert_called_once()


def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span):
"""Test that tool call spans fall back to the exception type name for empty errors."""
tracer = Tracer()
error = Exception()

tracer.end_tool_call_span(mock_span, None, error=error)

mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
mock_span.record_exception.assert_called_once_with(error)
mock_span.end.assert_called_once()


def test_end_tool_call_span_with_none(mock_span):
"""Test ending a tool call span with None result."""
tracer = Tracer()
Expand Down