Skip to content

Commit c5342cc

Browse files
committed
fix: handle BaseException in trace spans to prevent span leaks on KeyboardInterrupt
Trace spans were not properly closed when BaseException (e.g. KeyboardInterrupt, asyncio.CancelledError) was raised. Add explicit BaseException handlers to close spans and aclose() calls to ensure async generators are cleaned up.
1 parent b8bd925 commit c5342cc

File tree

6 files changed

+204
-12
lines changed

6 files changed

+204
-12
lines changed

src/strands/agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ async def stream_async(
785785

786786
self._end_agent_trace_span(response=result)
787787

788-
except Exception as e:
788+
except BaseException as e:
789789
self._end_agent_trace_span(error=e)
790790
raise
791791

@@ -988,7 +988,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
988988
def _end_agent_trace_span(
989989
self,
990990
response: AgentResult | None = None,
991-
error: Exception | None = None,
991+
error: BaseException | None = None,
992992
) -> None:
993993
"""Ends a trace span for the agent.
994994

src/strands/event_loop/event_loop.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ async def event_loop_cycle(
138138
custom_trace_attributes=agent.trace_attributes,
139139
)
140140
invocation_state["event_loop_cycle_span"] = cycle_span
141+
model_events: AsyncGenerator[TypedEvent, None] | None = None
141142

142143
with trace_api.use_span(cycle_span, end_on_exit=False):
143144
try:
@@ -153,15 +154,21 @@ async def event_loop_cycle(
153154
model_events = _handle_model_execution(
154155
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
155156
)
156-
async for model_event in model_events:
157-
if not isinstance(model_event, ModelStopReason):
158-
yield model_event
157+
try:
158+
async for model_event in model_events:
159+
if not isinstance(model_event, ModelStopReason):
160+
yield model_event
161+
finally:
162+
await model_events.aclose()
159163

160164
stop_reason, message, *_ = model_event["stop"]
161165
yield ModelMessageEvent(message=message)
162166
except Exception as e:
163167
tracer.end_span_with_error(cycle_span, str(e), e)
164168
raise
169+
except BaseException as e:
170+
tracer.end_span_with_error(cycle_span, str(e), e)
171+
raise
165172

166173
try:
167174
if stop_reason == "max_tokens":
@@ -241,6 +248,9 @@ async def event_loop_cycle(
241248
yield ForceStopEvent(reason=e)
242249
logger.exception("cycle failed")
243250
raise EventLoopException(e, invocation_state["request_state"]) from e
251+
except BaseException as e:
252+
tracer.end_span_with_error(cycle_span, str(e), e)
253+
raise
244254

245255

246256
async def recurse_event_loop(
@@ -324,6 +334,7 @@ async def _handle_model_execution(
324334
model_id=model_id,
325335
custom_trace_attributes=agent.trace_attributes,
326336
)
337+
streamed_events: AsyncGenerator[TypedEvent, None] | None = None
327338
with trace_api.use_span(model_invoke_span, end_on_exit=False):
328339
try:
329340
await agent.hooks.invoke_callbacks_async(
@@ -339,7 +350,7 @@ async def _handle_model_execution(
339350
else:
340351
tool_specs = agent.tool_registry.get_all_tool_specs()
341352

342-
async for event in stream_messages(
353+
streamed_events = stream_messages(
343354
agent.model,
344355
agent.system_prompt,
345356
agent.messages,
@@ -348,8 +359,12 @@ async def _handle_model_execution(
348359
tool_choice=structured_output_context.tool_choice,
349360
invocation_state=invocation_state,
350361
cancel_signal=agent._cancel_signal,
351-
):
352-
yield event
362+
)
363+
try:
364+
async for event in streamed_events:
365+
yield event
366+
finally:
367+
await streamed_events.aclose()
353368

354369
stop_reason, message, usage, metrics = event["stop"]
355370
invocation_state.setdefault("request_state", {})
@@ -410,6 +425,9 @@ async def _handle_model_execution(
410425
# No retry requested, raise the exception
411426
yield ForceStopEvent(reason=e)
412427
raise e
428+
except BaseException as e:
429+
tracer.end_span_with_error(model_invoke_span, str(e), e)
430+
raise
413431

414432
try:
415433
# Add message in trace and mark the end of the stream messages trace

src/strands/telemetry/tracer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _end_span(
184184
self,
185185
span: Span,
186186
attributes: dict[str, AttributeValue] | None = None,
187-
error: Exception | None = None,
187+
error: BaseException | None = None,
188188
error_message: str | None = None,
189189
) -> None:
190190
"""Generic helper method to end a span.
@@ -224,7 +224,7 @@ def _end_span(
224224
except Exception as e:
225225
logger.warning("error=<%s> | failed to force flush tracer provider", e)
226226

227-
def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None:
227+
def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None:
228228
"""End a span with error status.
229229
230230
Args:
@@ -445,7 +445,9 @@ def start_tool_call_span(
445445

446446
return span
447447

448-
def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None:
448+
def end_tool_call_span(
449+
self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None
450+
) -> None:
449451
"""End a tool call span with results.
450452
451453
Args:
@@ -645,7 +647,7 @@ def end_agent_span(
645647
self,
646648
span: Span,
647649
response: AgentResult | None = None,
648-
error: Exception | None = None,
650+
error: BaseException | None = None,
649651
) -> None:
650652
"""End an agent span with results and metrics.
651653

tests/strands/agent/test_agent.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
14151415
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
14161416

14171417

1418+
@pytest.mark.asyncio
1419+
@unittest.mock.patch("strands.agent.agent.get_tracer")
1420+
async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist):
1421+
"""Test that stream_async ends the agent span when a BaseException occurs."""
1422+
mock_tracer = unittest.mock.MagicMock()
1423+
mock_span = unittest.mock.MagicMock()
1424+
mock_tracer.start_agent_span.return_value = mock_span
1425+
mock_get_tracer.return_value = mock_tracer
1426+
1427+
test_exception = KeyboardInterrupt("stop now")
1428+
mock_model.mock_stream.side_effect = test_exception
1429+
1430+
agent = Agent(model=mock_model)
1431+
1432+
with pytest.raises(KeyboardInterrupt, match="stop now"):
1433+
stream = agent.stream_async("test prompt")
1434+
await alist(stream)
1435+
1436+
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
1437+
1438+
14181439
def test_agent_init_with_state_object():
14191440
agent = Agent(state=AgentState({"foo": "bar"}))
14201441
assert agent.state.get("foo") == "bar"

tests/strands/event_loop/test_event_loop.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,121 @@ async def test_event_loop_tracing_with_tool_execution(
680680
assert mock_tracer.end_model_invoke_span.call_count == 2
681681

682682

683+
@patch("strands.event_loop.event_loop.get_tracer")
684+
@pytest.mark.asyncio
685+
async def test_event_loop_cycle_closes_spans_on_stream_aclose(
686+
mock_get_tracer,
687+
agent,
688+
model,
689+
mock_tracer,
690+
):
691+
mock_get_tracer.return_value = mock_tracer
692+
cycle_span = MagicMock()
693+
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
694+
model_span = MagicMock()
695+
mock_tracer.start_model_invoke_span.return_value = model_span
696+
697+
async def interrupted_stream():
698+
yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
699+
await asyncio.sleep(10)
700+
yield {"contentBlockStop": {}}
701+
702+
model.stream.return_value = interrupted_stream()
703+
704+
stream = strands.event_loop.event_loop.event_loop_cycle(
705+
agent=agent,
706+
invocation_state={},
707+
)
708+
await anext(stream)
709+
await anext(stream)
710+
await anext(stream)
711+
await stream.aclose()
712+
713+
assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span]
714+
assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [
715+
"",
716+
"",
717+
]
718+
719+
720+
@patch("strands.event_loop.event_loop.get_tracer")
721+
@pytest.mark.asyncio
722+
async def test_event_loop_cycle_closes_spans_on_task_cancellation(
723+
mock_get_tracer,
724+
agent,
725+
model,
726+
mock_tracer,
727+
):
728+
mock_get_tracer.return_value = mock_tracer
729+
cycle_span = MagicMock()
730+
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
731+
model_span = MagicMock()
732+
mock_tracer.start_model_invoke_span.return_value = model_span
733+
734+
blocked_on_stream = asyncio.Event()
735+
release_stream = asyncio.Event()
736+
737+
async def interrupted_stream():
738+
yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
739+
blocked_on_stream.set()
740+
await release_stream.wait()
741+
yield {"contentBlockStop": {}}
742+
743+
model.stream.return_value = interrupted_stream()
744+
745+
async def consume() -> None:
746+
stream = strands.event_loop.event_loop.event_loop_cycle(
747+
agent=agent,
748+
invocation_state={},
749+
)
750+
async for _ in stream:
751+
pass
752+
753+
task = asyncio.create_task(consume())
754+
await blocked_on_stream.wait()
755+
task.cancel()
756+
757+
with pytest.raises(asyncio.CancelledError):
758+
await task
759+
760+
assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span]
761+
assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [
762+
"",
763+
"",
764+
]
765+
766+
767+
@patch("strands.event_loop.event_loop.get_tracer")
768+
@pytest.mark.asyncio
769+
async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt(
770+
mock_get_tracer,
771+
agent,
772+
model,
773+
mock_tracer,
774+
alist,
775+
):
776+
mock_get_tracer.return_value = mock_tracer
777+
cycle_span = MagicMock()
778+
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
779+
model_span = MagicMock()
780+
mock_tracer.start_model_invoke_span.return_value = model_span
781+
782+
test_exception = KeyboardInterrupt("stop now")
783+
model.stream.side_effect = test_exception
784+
785+
with pytest.raises(KeyboardInterrupt, match="stop now"):
786+
stream = strands.event_loop.event_loop.event_loop_cycle(
787+
agent=agent,
788+
invocation_state={},
789+
)
790+
await alist(stream)
791+
792+
assert mock_tracer.end_span_with_error.call_args_list == [
793+
call(model_span, "stop now", test_exception),
794+
call(cycle_span, "stop now", test_exception),
795+
]
796+
797+
683798
@pytest.mark.asyncio
684799
async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle(
685800
agent,

tests/strands/telemetry/test_tracer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span):
140140
mock_span.end.assert_called_once()
141141

142142

143+
def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span):
144+
"""Test that empty BaseException messages fall back to the exception type name."""
145+
tracer = Tracer()
146+
error = KeyboardInterrupt()
147+
148+
tracer.end_span_with_error(mock_span, "", error)
149+
150+
mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt")
151+
mock_span.record_exception.assert_called_once_with(error)
152+
mock_span.end.assert_called_once()
153+
154+
143155
def test_end_span_with_error_prefers_explicit_message(mock_span):
144156
"""Test that an explicit error message takes precedence over the exception text."""
145157
tracer = Tracer()
@@ -1092,6 +1104,30 @@ def test_force_flush_with_error(mock_span, mock_get_tracer_provider):
10921104
mock_tracer_provider.force_flush.assert_called_once()
10931105

10941106

1107+
def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span):
1108+
"""Test that agent spans fall back to the exception type name for empty errors."""
1109+
tracer = Tracer()
1110+
error = Exception()
1111+
1112+
tracer.end_agent_span(mock_span, error=error)
1113+
1114+
mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
1115+
mock_span.record_exception.assert_called_once_with(error)
1116+
mock_span.end.assert_called_once()
1117+
1118+
1119+
def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span):
1120+
"""Test that tool call spans fall back to the exception type name for empty errors."""
1121+
tracer = Tracer()
1122+
error = Exception()
1123+
1124+
tracer.end_tool_call_span(mock_span, None, error=error)
1125+
1126+
mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
1127+
mock_span.record_exception.assert_called_once_with(error)
1128+
mock_span.end.assert_called_once()
1129+
1130+
10951131
def test_end_tool_call_span_with_none(mock_span):
10961132
"""Test ending a tool call span with None result."""
10971133
tracer = Tracer()

0 commit comments

Comments
 (0)