Skip to content

Commit 8203912

Browse files
fix(anthropic): Patch Stream.close() and MessageStream.close() to finish spans (#5674)
Finish AI Client Spans when `close()` is called and not merely when the streamed response iterator is consumed or the GC collects the `_iterator` instance variable. The `close()` method closes the HTTP connection.
1 parent c4df76b commit 8203912

File tree

2 files changed

+369
-48
lines changed

2 files changed

+369
-48
lines changed

sentry_sdk/integrations/anthropic.py

Lines changed: 157 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939

4040
from anthropic import Stream, AsyncStream
4141
from anthropic.resources import AsyncMessages, Messages
42-
from anthropic.lib.streaming import MessageStreamManager, AsyncMessageStreamManager
42+
from anthropic.lib.streaming import (
43+
MessageStreamManager,
44+
MessageStream,
45+
AsyncMessageStreamManager,
46+
)
4347

4448
from anthropic.types import (
4549
MessageStartEvent,
@@ -56,7 +60,7 @@
5660
raise DidNotEnable("Anthropic not installed")
5761

5862
if TYPE_CHECKING:
59-
from typing import Any, AsyncIterator, Iterator, Optional, Union
63+
from typing import Any, AsyncIterator, Iterator, Optional, Union, Callable
6064
from sentry_sdk.tracing import Span
6165
from sentry_sdk._types import TextPart
6266

@@ -67,7 +71,7 @@
6771
TextBlockParam,
6872
ToolUnionParam,
6973
)
70-
from anthropic.lib.streaming import MessageStream, AsyncMessageStream
74+
from anthropic.lib.streaming import AsyncMessageStream
7175

7276

7377
class _RecordedUsage:
@@ -89,14 +93,35 @@ def setup_once() -> None:
8993
version = package_version("anthropic")
9094
_check_minimum_version(AnthropicIntegration, version)
9195

96+
"""
97+
client.messages.create(stream=True) can return an instance of the Stream class, which implements the iterator protocol.
98+
The private _iterator variable and the close() method are patched. During iteration over the _iterator generator,
99+
information from intercepted events is accumulated and used to populate output attributes on the AI Client Span.
100+
101+
The span can be finished in two places:
102+
- When the user exits the context manager or directly calls close(), the patched close() finishes the span.
103+
- When iteration ends, the finally block in the _iterator wrapper finishes the span.
104+
105+
Both paths may run. For example, the context manager exit can follow iterator exhaustion.
106+
"""
92107
Messages.create = _wrap_message_create(Messages.create)
108+
Stream.close = _wrap_close(Stream.close)
109+
93110
AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)
94111

112+
"""
113+
client.messages.stream() patches are analogous to the patches for client.messages.create(stream=True) described above.
114+
"""
95115
Messages.stream = _wrap_message_stream(Messages.stream)
96116
MessageStreamManager.__enter__ = _wrap_message_stream_manager_enter(
97117
MessageStreamManager.__enter__
98118
)
99119

120+
# Before https://github.com/anthropics/anthropic-sdk-python/commit/b1a1c0354a9aca450a7d512fdbdeb59c0ead688a
121+
# MessageStream inherits from Stream, so patching Stream is sufficient on these versions.
122+
if not issubclass(MessageStream, Stream):
123+
MessageStream.close = _wrap_close(MessageStream.close)
124+
100125
AsyncMessages.stream = _wrap_async_message_stream(AsyncMessages.stream)
101126
AsyncMessageStreamManager.__aenter__ = (
102127
_wrap_async_message_stream_manager_aenter(
@@ -399,21 +424,13 @@ def _set_create_input_data(
399424

400425

401426
def _wrap_synchronous_message_iterator(
427+
stream: "Union[Stream, MessageStream]",
402428
iterator: "Iterator[Union[RawMessageStreamEvent, MessageStreamEvent]]",
403-
span: "Span",
404-
integration: "AnthropicIntegration",
405429
) -> "Iterator[Union[RawMessageStreamEvent, MessageStreamEvent]]":
406430
"""
407431
Sets information received while iterating the response stream on the AI Client Span.
408-
Responsible for closing the AI Client Span.
432+
Responsible for closing the AI Client Span unless the span has already been closed in the close() patch.
409433
"""
410-
411-
model = None
412-
usage = _RecordedUsage()
413-
content_blocks: "list[str]" = []
414-
response_id = None
415-
finish_reason = None
416-
417434
try:
418435
for event in iterator:
419436
# Message and content types are aliases for corresponding Raw* types, introduced in
@@ -432,40 +449,21 @@ def _wrap_synchronous_message_iterator(
432449
yield event
433450
continue
434451

435-
(model, usage, content_blocks, response_id, finish_reason) = (
436-
_collect_ai_data(
437-
event,
438-
model,
439-
usage,
440-
content_blocks,
441-
response_id,
442-
finish_reason,
443-
)
444-
)
452+
_accumulate_event_data(stream, event)
445453
yield event
446454
finally:
447455
with capture_internal_exceptions():
448-
# Anthropic's input_tokens excludes cached/cache_write tokens.
449-
# Normalize to total input tokens for correct cost calculations.
450-
total_input = (
451-
usage.input_tokens
452-
+ (usage.cache_read_input_tokens or 0)
453-
+ (usage.cache_write_input_tokens or 0)
454-
)
455-
456-
_set_output_data(
457-
span=span,
458-
integration=integration,
459-
model=model,
460-
input_tokens=total_input,
461-
output_tokens=usage.output_tokens,
462-
cache_read_input_tokens=usage.cache_read_input_tokens,
463-
cache_write_input_tokens=usage.cache_write_input_tokens,
464-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
465-
finish_span=True,
466-
response_id=response_id,
467-
finish_reason=finish_reason,
468-
)
456+
if hasattr(stream, "_span"):
457+
_finish_streaming_span(
458+
span=stream._span,
459+
integration=stream._integration,
460+
model=stream._model,
461+
usage=stream._usage,
462+
content_blocks=stream._content_blocks,
463+
response_id=stream._response_id,
464+
finish_reason=stream._finish_reason,
465+
)
466+
del stream._span
469467

470468

471469
async def _wrap_asynchronous_message_iterator(
@@ -625,9 +623,15 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
625623
result = yield f, args, kwargs
626624

627625
if isinstance(result, Stream):
626+
result._span = span
627+
result._integration = integration
628+
629+
_initialize_data_accumulation_state(result)
628630
result._iterator = _wrap_synchronous_message_iterator(
629-
result._iterator, span, integration
631+
result,
632+
result._iterator,
630633
)
634+
631635
return result
632636

633637
if isinstance(result, AsyncStream):
@@ -712,6 +716,108 @@ def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
712716
return _sentry_patched_create_sync
713717

714718

719+
def _initialize_data_accumulation_state(stream: "Union[Stream, MessageStream]") -> None:
    """
    Seed the per-stream accumulator attributes used while iterating events.

    No-op when the state already exists (detected via ``_model``), so repeated
    initialization never discards partially accumulated data.
    """
    if hasattr(stream, "_model"):
        return

    stream._model = None
    stream._usage = _RecordedUsage()
    stream._content_blocks = []
    stream._response_id = None
    stream._finish_reason = None
729+
730+
731+
def _accumulate_event_data(
    stream: "Union[Stream, MessageStream]",
    event: "Union[RawMessageStreamEvent, MessageStreamEvent]",
) -> None:
    """
    Fold a single stream event into the accumulator state stored on *stream*.

    ``_collect_ai_data()`` takes the current accumulated values and returns the
    updated ones; they are written straight back onto the stream instance.
    """
    (
        stream._model,
        stream._usage,
        stream._content_blocks,
        stream._response_id,
        stream._finish_reason,
    ) = _collect_ai_data(
        event,
        stream._model,
        stream._usage,
        stream._content_blocks,
        stream._response_id,
        stream._finish_reason,
    )
752+
753+
754+
def _finish_streaming_span(
    span: "Span",
    integration: "AnthropicIntegration",
    model: "Optional[str]",
    usage: "_RecordedUsage",
    content_blocks: "list[str]",
    response_id: "Optional[str]",
    finish_reason: "Optional[str]",
) -> None:
    """
    Write the accumulated output attributes onto the AI Client Span and end it.
    """
    # Anthropic's input_tokens excludes cached/cache_write tokens.
    # Normalize to total input tokens for correct cost calculations.
    cache_read = usage.cache_read_input_tokens or 0
    cache_write = usage.cache_write_input_tokens or 0
    normalized_input_tokens = usage.input_tokens + cache_read + cache_write

    _set_output_data(
        span=span,
        integration=integration,
        model=model,
        input_tokens=normalized_input_tokens,
        output_tokens=usage.output_tokens,
        cache_read_input_tokens=usage.cache_read_input_tokens,
        cache_write_input_tokens=usage.cache_write_input_tokens,
        content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
        finish_span=True,
        response_id=response_id,
        finish_reason=finish_reason,
    )
787+
788+
789+
def _wrap_close(
790+
f: "Callable[..., None]",
791+
) -> "Callable[..., None]":
792+
"""
793+
Closes the AI Client Span unless the finally block in `_wrap_synchronous_message_iterator()` runs first.
794+
"""
795+
796+
def close(self: "Union[Stream, MessageStream]") -> None:
797+
if not hasattr(self, "_span"):
798+
return f(self)
799+
800+
if not hasattr(self, "_model"):
801+
self._span.__exit__(None, None, None)
802+
del self._span
803+
return f(self)
804+
805+
_finish_streaming_span(
806+
span=self._span,
807+
integration=self._integration,
808+
model=self._model,
809+
usage=self._usage,
810+
content_blocks=self._content_blocks,
811+
response_id=self._response_id,
812+
finish_reason=self._finish_reason,
813+
)
814+
del self._span
815+
816+
return f(self)
817+
818+
return close
819+
820+
715821
def _wrap_message_create_async(f: "Any") -> "Any":
716822
async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
717823
gen = _sentry_patched_create_common(f, *args, **kwargs)
@@ -819,10 +925,13 @@ def _sentry_patched_enter(self: "MessageStreamManager") -> "MessageStream":
819925
tools=self._tools,
820926
)
821927

928+
stream._span = span
929+
stream._integration = integration
930+
931+
_initialize_data_accumulation_state(stream)
822932
stream._iterator = _wrap_synchronous_message_iterator(
823-
iterator=stream._iterator,
824-
span=span,
825-
integration=integration,
933+
stream,
934+
stream._iterator,
826935
)
827936

828937
return stream

0 commit comments

Comments
 (0)