Skip to content

Commit 6d33d36

Browse files
fix(anthropic): Patch AsyncStream.close() and AsyncMessageStream.close() to finish spans (#5675)
Finish AI Client Spans when `close()` is called and not merely when the streamed response iterator is consumed or the GC collects the `_iterator` instance variable. The `close()` method closes the HTTP connection.
1 parent 8203912 commit 6d33d36

File tree

2 files changed: +292 additions, −50 deletions

sentry_sdk/integrations/anthropic.py

Lines changed: 76 additions & 50 deletions
Original line | New line | Change
```diff
@@ -43,6 +43,7 @@
     MessageStreamManager,
     MessageStream,
     AsyncMessageStreamManager,
+    AsyncMessageStream,
 )

 from anthropic.types import (
```
```diff
@@ -60,7 +61,15 @@
     raise DidNotEnable("Anthropic not installed")

 if TYPE_CHECKING:
-    from typing import Any, AsyncIterator, Iterator, Optional, Union, Callable
+    from typing import (
+        Any,
+        AsyncIterator,
+        Iterator,
+        Optional,
+        Union,
+        Callable,
+        Awaitable,
+    )
     from sentry_sdk.tracing import Span
     from sentry_sdk._types import TextPart

```
```diff
@@ -71,7 +80,6 @@
         TextBlockParam,
         ToolUnionParam,
     )
-    from anthropic.lib.streaming import AsyncMessageStream


 class _RecordedUsage:
```
```diff
@@ -95,6 +103,7 @@ def setup_once() -> None:

     """
     client.messages.create(stream=True) can return an instance of the Stream class, which implements the iterator protocol.
+    Analogously, the function can return an AsyncStream, which implements the asynchronous iterator protocol.
     The private _iterator variable and the close() method are patched. During iteration over the _iterator generator,
     information from intercepted events is accumulated and used to populate output attributes on the AI Client Span.
```
```diff
@@ -108,6 +117,7 @@ def setup_once() -> None:
     Stream.close = _wrap_close(Stream.close)

     AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)
+    AsyncStream.close = _wrap_async_close(AsyncStream.close)

     """
     client.messages.stream() patches are analogous to the patches for client.messages.create(stream=True) described above.
```
```diff
@@ -129,6 +139,11 @@ def setup_once() -> None:
         )
     )

+    # Before https://github.com/anthropics/anthropic-sdk-python/commit/b1a1c0354a9aca450a7d512fdbdeb59c0ead688a
+    # AsyncMessageStream inherits from AsyncStream, so patching Stream is sufficient on these versions.
+    if not issubclass(AsyncMessageStream, AsyncStream):
+        AsyncMessageStream.close = _wrap_async_close(AsyncMessageStream.close)
+

 def _capture_exception(exc: "Any") -> None:
     set_span_errored()
```
```diff
@@ -467,20 +482,13 @@ def _wrap_synchronous_message_iterator(


 async def _wrap_asynchronous_message_iterator(
+    stream: "Union[AsyncStream, AsyncMessageStream]",
     iterator: "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]",
-    span: "Span",
-    integration: "AnthropicIntegration",
 ) -> "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]":
     """
     Sets information received while iterating the response stream on the AI Client Span.
-    Responsible for closing the AI Client Span.
+    Responsible for closing the AI Client Span unless the span has already been closed in the close() patch.
     """
-    model = None
-    usage = _RecordedUsage()
-    content_blocks: "list[str]" = []
-    response_id = None
-    finish_reason = None
-
     try:
         async for event in iterator:
             # Message and content types are aliases for corresponding Raw* types, introduced in
```
```diff
@@ -499,44 +507,21 @@ async def _wrap_asynchronous_message_iterator(
                 yield event
                 continue

-            (
-                model,
-                usage,
-                content_blocks,
-                response_id,
-                finish_reason,
-            ) = _collect_ai_data(
-                event,
-                model,
-                usage,
-                content_blocks,
-                response_id,
-                finish_reason,
-            )
+            _accumulate_event_data(stream, event)
             yield event
     finally:
         with capture_internal_exceptions():
-            # Anthropic's input_tokens excludes cached/cache_write tokens.
-            # Normalize to total input tokens for correct cost calculations.
-            total_input = (
-                usage.input_tokens
-                + (usage.cache_read_input_tokens or 0)
-                + (usage.cache_write_input_tokens or 0)
-            )
-
-            _set_output_data(
-                span=span,
-                integration=integration,
-                model=model,
-                input_tokens=total_input,
-                output_tokens=usage.output_tokens,
-                cache_read_input_tokens=usage.cache_read_input_tokens,
-                cache_write_input_tokens=usage.cache_write_input_tokens,
-                content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
-                finish_span=True,
-                response_id=response_id,
-                finish_reason=finish_reason,
-            )
+            if hasattr(stream, "_span"):
+                _finish_streaming_span(
+                    span=stream._span,
+                    integration=stream._integration,
+                    model=stream._model,
+                    usage=stream._usage,
+                    content_blocks=stream._content_blocks,
+                    response_id=stream._response_id,
+                    finish_reason=stream._finish_reason,
+                )
+                del stream._span


 def _set_output_data(
```
```diff
@@ -635,9 +620,15 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
         return result

     if isinstance(result, AsyncStream):
+        result._span = span
+        result._integration = integration
+
+        _initialize_data_accumulation_state(result)
         result._iterator = _wrap_asynchronous_message_iterator(
-            result._iterator, span, integration
+            result,
+            result._iterator,
         )
+
        return result

     with capture_internal_exceptions():
```
```diff
@@ -856,6 +847,38 @@ async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
     return _sentry_patched_create_async


+def _wrap_async_close(
+    f: "Callable[..., Awaitable[None]]",
+) -> "Callable[..., Awaitable[None]]":
+    """
+    Closes the AI Client Span unless the finally block in `_wrap_asynchronous_message_iterator()` runs first.
+    """
+
+    async def close(self: "AsyncStream") -> None:
+        if not hasattr(self, "_span"):
+            return await f(self)
+
+        if not hasattr(self, "_model"):
+            self._span.__exit__(None, None, None)
+            del self._span
+            return await f(self)
+
+        _finish_streaming_span(
+            span=self._span,
+            integration=self._integration,
+            model=self._model,
+            usage=self._usage,
+            content_blocks=self._content_blocks,
+            response_id=self._response_id,
+            finish_reason=self._finish_reason,
+        )
+        del self._span
+
+        return await f(self)
+
+    return close
+
+
 def _wrap_message_stream(f: "Any") -> "Any":
     """
     Attaches user-provided arguments to the returned context manager.
```
```diff
@@ -1012,10 +1035,13 @@ async def _sentry_patched_aenter(
             tools=self._tools,
         )

+        stream._span = span
+        stream._integration = integration
+
+        _initialize_data_accumulation_state(stream)
         stream._iterator = _wrap_asynchronous_message_iterator(
-            iterator=stream._iterator,
-            span=span,
-            integration=integration,
+            stream,
+            stream._iterator,
         )

         return stream
```

0 commit comments

Comments
 (0)