4343 MessageStreamManager ,
4444 MessageStream ,
4545 AsyncMessageStreamManager ,
46+ AsyncMessageStream ,
4647 )
4748
4849 from anthropic .types import (
6061 raise DidNotEnable ("Anthropic not installed" )
6162
6263if TYPE_CHECKING :
63- from typing import Any , AsyncIterator , Iterator , Optional , Union , Callable
64+ from typing import (
65+ Any ,
66+ AsyncIterator ,
67+ Iterator ,
68+ Optional ,
69+ Union ,
70+ Callable ,
71+ Awaitable ,
72+ )
6473 from sentry_sdk .tracing import Span
6574 from sentry_sdk ._types import TextPart
6675
7180 TextBlockParam ,
7281 ToolUnionParam ,
7382 )
74- from anthropic .lib .streaming import AsyncMessageStream
7583
7684
7785class _RecordedUsage :
@@ -95,6 +103,7 @@ def setup_once() -> None:
95103
96104 """
97105 client.messages.create(stream=True) can return an instance of the Stream class, which implements the iterator protocol.
106+ Analogously, the function can return an AsyncStream, which implements the asynchronous iterator protocol.
98107 The private _iterator variable and the close() method are patched. During iteration over the _iterator generator,
99108 information from intercepted events is accumulated and used to populate output attributes on the AI Client Span.
100109
@@ -108,6 +117,7 @@ def setup_once() -> None:
108117 Stream .close = _wrap_close (Stream .close )
109118
110119 AsyncMessages .create = _wrap_message_create_async (AsyncMessages .create )
120+ AsyncStream .close = _wrap_async_close (AsyncStream .close )
111121
112122 """
113123 client.messages.stream() patches are analogous to the patches for client.messages.create(stream=True) described above.
@@ -129,6 +139,11 @@ def setup_once() -> None:
129139 )
130140 )
131141
142+ # Before https://github.com/anthropics/anthropic-sdk-python/commit/b1a1c0354a9aca450a7d512fdbdeb59c0ead688a
143+ # AsyncMessageStream inherits from AsyncStream, so patching AsyncStream is sufficient on these versions.
144+ if not issubclass (AsyncMessageStream , AsyncStream ):
145+ AsyncMessageStream .close = _wrap_async_close (AsyncMessageStream .close )
146+
132147
133148def _capture_exception (exc : "Any" ) -> None :
134149 set_span_errored ()
@@ -467,20 +482,13 @@ def _wrap_synchronous_message_iterator(
467482
468483
469484async def _wrap_asynchronous_message_iterator (
485+ stream : "Union[AsyncStream, AsyncMessageStream]" ,
470486 iterator : "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]" ,
471- span : "Span" ,
472- integration : "AnthropicIntegration" ,
473487) -> "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]" :
474488 """
475489 Sets information received while iterating the response stream on the AI Client Span.
476- Responsible for closing the AI Client Span.
490+ Responsible for closing the AI Client Span unless the span has already been closed in the close() patch.
477491 """
478- model = None
479- usage = _RecordedUsage ()
480- content_blocks : "list[str]" = []
481- response_id = None
482- finish_reason = None
483-
484492 try :
485493 async for event in iterator :
486494 # Message and content types are aliases for corresponding Raw* types, introduced in
@@ -499,44 +507,21 @@ async def _wrap_asynchronous_message_iterator(
499507 yield event
500508 continue
501509
502- (
503- model ,
504- usage ,
505- content_blocks ,
506- response_id ,
507- finish_reason ,
508- ) = _collect_ai_data (
509- event ,
510- model ,
511- usage ,
512- content_blocks ,
513- response_id ,
514- finish_reason ,
515- )
510+ _accumulate_event_data (stream , event )
516511 yield event
517512 finally :
518513 with capture_internal_exceptions ():
519- # Anthropic's input_tokens excludes cached/cache_write tokens.
520- # Normalize to total input tokens for correct cost calculations.
521- total_input = (
522- usage .input_tokens
523- + (usage .cache_read_input_tokens or 0 )
524- + (usage .cache_write_input_tokens or 0 )
525- )
526-
527- _set_output_data (
528- span = span ,
529- integration = integration ,
530- model = model ,
531- input_tokens = total_input ,
532- output_tokens = usage .output_tokens ,
533- cache_read_input_tokens = usage .cache_read_input_tokens ,
534- cache_write_input_tokens = usage .cache_write_input_tokens ,
535- content_blocks = [{"text" : "" .join (content_blocks ), "type" : "text" }],
536- finish_span = True ,
537- response_id = response_id ,
538- finish_reason = finish_reason ,
539- )
514+ if hasattr (stream , "_span" ):
515+ _finish_streaming_span (
516+ span = stream ._span ,
517+ integration = stream ._integration ,
518+ model = stream ._model ,
519+ usage = stream ._usage ,
520+ content_blocks = stream ._content_blocks ,
521+ response_id = stream ._response_id ,
522+ finish_reason = stream ._finish_reason ,
523+ )
524+ del stream ._span
540525
541526
542527def _set_output_data (
@@ -635,9 +620,15 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
635620 return result
636621
637622 if isinstance (result , AsyncStream ):
623+ result ._span = span
624+ result ._integration = integration
625+
626+ _initialize_data_accumulation_state (result )
638627 result ._iterator = _wrap_asynchronous_message_iterator (
639- result ._iterator , span , integration
628+ result ,
629+ result ._iterator ,
640630 )
631+
641632 return result
642633
643634 with capture_internal_exceptions ():
@@ -856,6 +847,38 @@ async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
856847 return _sentry_patched_create_async
857848
858849
850+ def _wrap_async_close (
851+ f : "Callable[..., Awaitable[None]]" ,
852+ ) -> "Callable[..., Awaitable[None]]" :
853+ """
854+ Closes the AI Client Span unless the finally block in `_wrap_asynchronous_message_iterator()` runs first.
855+ """
856+
857+ async def close (self : "AsyncStream" ) -> None :
858+ if not hasattr (self , "_span" ):
859+ return await f (self )
860+
861+ if not hasattr (self , "_model" ):
862+ self ._span .__exit__ (None , None , None )
863+ del self ._span
864+ return await f (self )
865+
866+ _finish_streaming_span (
867+ span = self ._span ,
868+ integration = self ._integration ,
869+ model = self ._model ,
870+ usage = self ._usage ,
871+ content_blocks = self ._content_blocks ,
872+ response_id = self ._response_id ,
873+ finish_reason = self ._finish_reason ,
874+ )
875+ del self ._span
876+
877+ return await f (self )
878+
879+ return close
880+
881+
859882def _wrap_message_stream (f : "Any" ) -> "Any" :
860883 """
861884 Attaches user-provided arguments to the returned context manager.
@@ -1012,10 +1035,13 @@ async def _sentry_patched_aenter(
10121035 tools = self ._tools ,
10131036 )
10141037
1038+ stream ._span = span
1039+ stream ._integration = integration
1040+
1041+ _initialize_data_accumulation_state (stream )
10151042 stream ._iterator = _wrap_asynchronous_message_iterator (
1016- iterator = stream ._iterator ,
1017- span = span ,
1018- integration = integration ,
1043+ stream ,
1044+ stream ._iterator ,
10191045 )
10201046
10211047 return stream
0 commit comments