Skip to content

Commit c396a32

Browse files
merge
2 parents fd84837 + 0e06f49 commit c396a32

File tree

1 file changed

+154
-125
lines changed

1 file changed

+154
-125
lines changed

sentry_sdk/integrations/anthropic.py

Lines changed: 154 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
except ImportError:
3838
Omit = None
3939

40+
from anthropic import Stream, AsyncStream
4041
from anthropic.resources import AsyncMessages, Messages
4142

4243
from anthropic.types import (
@@ -54,11 +55,10 @@
5455
raise DidNotEnable("Anthropic not installed")
5556

5657
if TYPE_CHECKING:
57-
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
58+
from typing import Any, AsyncIterator, Iterator, List, Optional, Union, Callable
5859
from sentry_sdk.tracing import Span
5960
from sentry_sdk._types import TextPart
6061

61-
from anthropic import AsyncStream
6262
from anthropic.types import RawMessageStreamEvent
6363

6464

@@ -84,6 +84,155 @@ def setup_once() -> None:
8484
Messages.create = _wrap_message_create(Messages.create)
8585
AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)
8686

87+
Stream.__iter__ = _wrap_stream_iter(Stream.__iter__)
88+
AsyncStream.__aiter__ = _wrap_async_stream_aiter(AsyncStream.__aiter__)
89+
90+
91+
def _wrap_stream_iter(
92+
f: "Callable[..., Iterator[RawMessageStreamEvent]]",
93+
) -> "Callable[..., Iterator[RawMessageStreamEvent]]":
94+
"""
95+
Sets information received while iterating the response stream on the AI Client Span.
96+
Responsible for closing the AI Client Span.
97+
"""
98+
99+
@wraps(f)
100+
def _patched_iter(self: "Stream") -> "Iterator[RawMessageStreamEvent]":
101+
if not hasattr(self, "_sentry_span"):
102+
for event in f(self):
103+
yield event
104+
return
105+
106+
model = None
107+
usage = _RecordedUsage()
108+
content_blocks: "list[str]" = []
109+
110+
for event in f(self):
111+
if not isinstance(
112+
event,
113+
(
114+
MessageStartEvent,
115+
MessageDeltaEvent,
116+
MessageStopEvent,
117+
ContentBlockStartEvent,
118+
ContentBlockDeltaEvent,
119+
ContentBlockStopEvent,
120+
),
121+
):
122+
yield event
123+
continue
124+
125+
(
126+
model,
127+
usage,
128+
content_blocks,
129+
) = _collect_ai_data(
130+
event,
131+
model,
132+
usage,
133+
content_blocks,
134+
)
135+
yield event
136+
137+
# Anthropic's input_tokens excludes cached/cache_write tokens.
138+
# Normalize to total input tokens for correct cost calculations.
139+
total_input = (
140+
usage.input_tokens
141+
+ (usage.cache_read_input_tokens or 0)
142+
+ (usage.cache_write_input_tokens or 0)
143+
)
144+
145+
span = self._sentry_span
146+
integration = self._integration
147+
148+
_set_output_data(
149+
span=span,
150+
integration=integration,
151+
model=model,
152+
input_tokens=total_input,
153+
output_tokens=usage.output_tokens,
154+
cache_read_input_tokens=usage.cache_read_input_tokens,
155+
cache_write_input_tokens=usage.cache_write_input_tokens,
156+
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
157+
finish_span=True,
158+
)
159+
160+
return _patched_iter
161+
162+
163+
def _wrap_async_stream_aiter(
164+
f: "Callable[..., AsyncIterator[RawMessageStreamEvent]]",
165+
) -> "Callable[..., AsyncIterator[RawMessageStreamEvent]]":
166+
"""
167+
Sets information received while iterating the response stream on the AI Client Span.
168+
Responsible for closing the AI Client Span.
169+
"""
170+
171+
@wraps(f)
172+
async def _patched_aiter(
173+
self: "AsyncStream",
174+
) -> "AsyncIterator[RawMessageStreamEvent]":
175+
if not hasattr(self, "_sentry_span"):
176+
async for event in f(self):
177+
yield event
178+
return
179+
180+
model = None
181+
usage = _RecordedUsage()
182+
content_blocks: "list[str]" = []
183+
184+
async for event in f(self):
185+
if not isinstance(
186+
event,
187+
(
188+
MessageStartEvent,
189+
MessageDeltaEvent,
190+
MessageStopEvent,
191+
ContentBlockStartEvent,
192+
ContentBlockDeltaEvent,
193+
ContentBlockStopEvent,
194+
),
195+
):
196+
yield event
197+
continue
198+
199+
(
200+
model,
201+
usage,
202+
content_blocks,
203+
) = _collect_ai_data(
204+
event,
205+
model,
206+
usage,
207+
content_blocks,
208+
)
209+
yield event
210+
211+
# Anthropic's input_tokens excludes cached/cache_write tokens.
212+
# Normalize to total input tokens for correct cost calculations.
213+
total_input = (
214+
usage.input_tokens
215+
+ (usage.cache_read_input_tokens or 0)
216+
+ (usage.cache_write_input_tokens or 0)
217+
)
218+
219+
span = self._sentry_span
220+
integration = self._integration
221+
222+
_set_output_data(
223+
span=span,
224+
integration=integration,
225+
model=model,
226+
input_tokens=total_input,
227+
output_tokens=usage.output_tokens,
228+
cache_read_input_tokens=usage.cache_read_input_tokens,
229+
cache_write_input_tokens=usage.cache_write_input_tokens,
230+
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
231+
finish_span=True,
232+
)
233+
234+
return _patched_aiter
235+
87236

88237
def _capture_exception(exc: "Any") -> None:
89238
set_span_errored()
@@ -401,126 +550,6 @@ def _set_output_data(
401550
span.__exit__(None, None, None)
402551

403552

404-
def _patch_streaming_response_iterator(
405-
result: "AsyncStream[RawMessageStreamEvent]",
406-
span: "sentry_sdk.tracing.Span",
407-
integration: "AnthropicIntegration",
408-
) -> None:
409-
"""
410-
Responsible for closing the `gen_ai.chat` span and setting attributes acquired during response consumption.
411-
"""
412-
old_iterator = result._iterator
413-
414-
def new_iterator() -> "Iterator[MessageStreamEvent]":
415-
model = None
416-
usage = _RecordedUsage()
417-
content_blocks: "list[str]" = []
418-
419-
for event in old_iterator:
420-
if not isinstance(
421-
event,
422-
(
423-
MessageStartEvent,
424-
MessageDeltaEvent,
425-
MessageStopEvent,
426-
ContentBlockStartEvent,
427-
ContentBlockDeltaEvent,
428-
ContentBlockStopEvent,
429-
),
430-
):
431-
yield event
432-
continue
433-
434-
(
435-
model,
436-
usage,
437-
content_blocks,
438-
) = _collect_ai_data(
439-
event,
440-
model,
441-
usage,
442-
content_blocks,
443-
)
444-
yield event
445-
446-
# Anthropic's input_tokens excludes cached/cache_write tokens.
447-
# Normalize to total input tokens for correct cost calculations.
448-
total_input = (
449-
usage.input_tokens
450-
+ (usage.cache_read_input_tokens or 0)
451-
+ (usage.cache_write_input_tokens or 0)
452-
)
453-
454-
_set_output_data(
455-
span=span,
456-
integration=integration,
457-
model=model,
458-
input_tokens=total_input,
459-
output_tokens=usage.output_tokens,
460-
cache_read_input_tokens=usage.cache_read_input_tokens,
461-
cache_write_input_tokens=usage.cache_write_input_tokens,
462-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
463-
finish_span=True,
464-
)
465-
466-
async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
467-
model = None
468-
usage = _RecordedUsage()
469-
content_blocks: "list[str]" = []
470-
471-
async for event in old_iterator:
472-
if not isinstance(
473-
event,
474-
(
475-
MessageStartEvent,
476-
MessageDeltaEvent,
477-
MessageStopEvent,
478-
ContentBlockStartEvent,
479-
ContentBlockDeltaEvent,
480-
ContentBlockStopEvent,
481-
),
482-
):
483-
yield event
484-
continue
485-
486-
(
487-
model,
488-
usage,
489-
content_blocks,
490-
) = _collect_ai_data(
491-
event,
492-
model,
493-
usage,
494-
content_blocks,
495-
)
496-
yield event
497-
498-
# Anthropic's input_tokens excludes cached/cache_write tokens.
499-
# Normalize to total input tokens for correct cost calculations.
500-
total_input = (
501-
usage.input_tokens
502-
+ (usage.cache_read_input_tokens or 0)
503-
+ (usage.cache_write_input_tokens or 0)
504-
)
505-
506-
_set_output_data(
507-
span=span,
508-
integration=integration,
509-
model=model,
510-
input_tokens=total_input,
511-
output_tokens=usage.output_tokens,
512-
cache_read_input_tokens=usage.cache_read_input_tokens,
513-
cache_write_input_tokens=usage.cache_write_input_tokens,
514-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
515-
finish_span=True,
516-
)
517-
518-
if str(type(result._iterator)) == "<class 'async_generator'>":
519-
result._iterator = new_iterator_async()
520-
else:
521-
result._iterator = new_iterator()
522-
523-
524553
def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
525554
integration = kwargs.pop("integration")
526555
if integration is None:
@@ -547,9 +576,9 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
547576

548577
result = yield f, args, kwargs
549578

550-
is_streaming_response = kwargs.get("stream", False)
551-
if is_streaming_response:
552-
_patch_streaming_response_iterator(result, span, integration)
579+
if isinstance(result, Stream) or isinstance(result, AsyncStream):
580+
result._sentry_span = span
581+
result._integration = integration
553582
return result
554583

555584
with capture_internal_exceptions():

0 commit comments

Comments
 (0)