|
55 | 55 | raise DidNotEnable("Anthropic not installed") |
56 | 56 |
|
57 | 57 | if TYPE_CHECKING: |
58 | | - from typing import Any, AsyncIterator, Iterator, List, Optional, Union, Callable |
| 58 | + from typing import Any, AsyncIterator, Iterator, List, Optional, Union |
59 | 59 | from sentry_sdk.tracing import Span |
60 | 60 | from sentry_sdk._types import TextPart |
61 | 61 |
|
@@ -84,155 +84,6 @@ def setup_once() -> None: |
84 | 84 | Messages.create = _wrap_message_create(Messages.create) |
85 | 85 | AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create) |
86 | 86 |
|
87 | | - Stream.__iter__ = _wrap_stream_iter(Stream.__iter__) |
88 | | - AsyncStream.__aiter__ = _wrap_async_stream_aiter(AsyncStream.__aiter__) |
89 | | - |
90 | | - |
91 | | -def _wrap_stream_iter( |
92 | | - f: "Callable[..., Iterator[RawMessageStreamEvent]]", |
93 | | -) -> "Callable[..., Iterator[RawMessageStreamEvent]]": |
94 | | - """ |
95 | | - Sets information received while iterating the response stream on the AI Client Span. |
96 | | - Responsible for closing the AI Client Span. |
97 | | - """ |
98 | | - |
99 | | - @wraps(f) |
100 | | - def _patched_iter(self: "Stream") -> "Iterator[RawMessageStreamEvent]": |
101 | | - if not hasattr(self, "_sentry_span"): |
102 | | - for event in f(self): |
103 | | - yield event |
104 | | - return |
105 | | - |
106 | | - model = None |
107 | | - usage = _RecordedUsage() |
108 | | - content_blocks: "list[str]" = [] |
109 | | - |
110 | | - for event in f(self): |
111 | | - if not isinstance( |
112 | | - event, |
113 | | - ( |
114 | | - MessageStartEvent, |
115 | | - MessageDeltaEvent, |
116 | | - MessageStopEvent, |
117 | | - ContentBlockStartEvent, |
118 | | - ContentBlockDeltaEvent, |
119 | | - ContentBlockStopEvent, |
120 | | - ), |
121 | | - ): |
122 | | - yield event |
123 | | - continue |
124 | | - |
125 | | - ( |
126 | | - model, |
127 | | - usage, |
128 | | - content_blocks, |
129 | | - ) = _collect_ai_data( |
130 | | - event, |
131 | | - model, |
132 | | - usage, |
133 | | - content_blocks, |
134 | | - ) |
135 | | - yield event |
136 | | - |
137 | | - # Anthropic's input_tokens excludes cached/cache_write tokens. |
138 | | - # Normalize to total input tokens for correct cost calculations. |
139 | | - total_input = ( |
140 | | - usage.input_tokens |
141 | | - + (usage.cache_read_input_tokens or 0) |
142 | | - + (usage.cache_write_input_tokens or 0) |
143 | | - ) |
144 | | - |
145 | | - span = self._sentry_span |
146 | | - integration = self._integration |
147 | | - |
148 | | - _set_output_data( |
149 | | - span=span, |
150 | | - integration=integration, |
151 | | - model=model, |
152 | | - input_tokens=total_input, |
153 | | - output_tokens=usage.output_tokens, |
154 | | - cache_read_input_tokens=usage.cache_read_input_tokens, |
155 | | - cache_write_input_tokens=usage.cache_write_input_tokens, |
156 | | - content_blocks=[{"text": "".join(content_blocks), "type": "text"}], |
157 | | - finish_span=True, |
158 | | - ) |
159 | | - |
160 | | - return _patched_iter |
161 | | - |
162 | | - |
163 | | -def _wrap_async_stream_aiter( |
164 | | - f: "Callable[..., AsyncIterator[RawMessageStreamEvent]]", |
165 | | -) -> "Callable[..., AsyncIterator[RawMessageStreamEvent]]": |
166 | | - """ |
167 | | - Sets information received while iterating the response stream on the AI Client Span. |
168 | | - Responsible for closing the AI Client Span. |
169 | | - """ |
170 | | - |
171 | | - @wraps(f) |
172 | | - async def _patched_aiter( |
173 | | - self: "AsyncStream", |
174 | | - ) -> "AsyncIterator[RawMessageStreamEvent]": |
175 | | - if not hasattr(self, "_sentry_span"): |
176 | | - async for event in f(self): |
177 | | - yield event |
178 | | - return |
179 | | - |
180 | | - model = None |
181 | | - usage = _RecordedUsage() |
182 | | - content_blocks: "list[str]" = [] |
183 | | - |
184 | | - async for event in f(self): |
185 | | - if not isinstance( |
186 | | - event, |
187 | | - ( |
188 | | - MessageStartEvent, |
189 | | - MessageDeltaEvent, |
190 | | - MessageStopEvent, |
191 | | - ContentBlockStartEvent, |
192 | | - ContentBlockDeltaEvent, |
193 | | - ContentBlockStopEvent, |
194 | | - ), |
195 | | - ): |
196 | | - yield event |
197 | | - continue |
198 | | - |
199 | | - ( |
200 | | - model, |
201 | | - usage, |
202 | | - content_blocks, |
203 | | - ) = _collect_ai_data( |
204 | | - event, |
205 | | - model, |
206 | | - usage, |
207 | | - content_blocks, |
208 | | - ) |
209 | | - yield event |
210 | | - |
211 | | - # Anthropic's input_tokens excludes cached/cache_write tokens. |
212 | | - # Normalize to total input tokens for correct cost calculations. |
213 | | - total_input = ( |
214 | | - usage.input_tokens |
215 | | - + (usage.cache_read_input_tokens or 0) |
216 | | - + (usage.cache_write_input_tokens or 0) |
217 | | - ) |
218 | | - |
219 | | - span = self._sentry_span |
220 | | - integration = self._integration |
221 | | - |
222 | | - _set_output_data( |
223 | | - span=span, |
224 | | - integration=integration, |
225 | | - model=model, |
226 | | - input_tokens=total_input, |
227 | | - output_tokens=usage.output_tokens, |
228 | | - cache_read_input_tokens=usage.cache_read_input_tokens, |
229 | | - cache_write_input_tokens=usage.cache_write_input_tokens, |
230 | | - content_blocks=[{"text": "".join(content_blocks), "type": "text"}], |
231 | | - finish_span=True, |
232 | | - ) |
233 | | - |
234 | | - return _patched_aiter |
235 | | - |
236 | 87 |
|
237 | 88 | def _capture_exception(exc: "Any") -> None: |
238 | 89 | set_span_errored() |
@@ -499,6 +350,129 @@ def _set_input_data( |
499 | 350 | span.set_data(SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)) |
500 | 351 |
|
501 | 352 |
|
def _wrap_synchronous_message_iterator(
    iterator: "Iterator[RawMessageStreamEvent]",
    span: "Span",
    integration: "AnthropicIntegration",
) -> "Iterator[RawMessageStreamEvent]":
    """
    Pass through every event of the response stream while recording AI data
    (model, token usage, text content) on the AI Client Span.

    Responsible for closing the AI Client Span once the stream is exhausted.
    """
    # Only these event types carry data we record; everything else is
    # forwarded untouched.
    tracked_event_types = (
        MessageStartEvent,
        MessageDeltaEvent,
        MessageStopEvent,
        ContentBlockStartEvent,
        ContentBlockDeltaEvent,
        ContentBlockStopEvent,
    )

    model = None
    usage = _RecordedUsage()
    content_blocks: "list[str]" = []

    for event in iterator:
        if isinstance(event, tracked_event_types):
            model, usage, content_blocks = _collect_ai_data(
                event, model, usage, content_blocks
            )
        yield event

    # Anthropic's input_tokens excludes cached/cache_write tokens.
    # Normalize to total input tokens for correct cost calculations.
    total_input = (
        usage.input_tokens
        + (usage.cache_read_input_tokens or 0)
        + (usage.cache_write_input_tokens or 0)
    )

    _set_output_data(
        span=span,
        integration=integration,
        model=model,
        input_tokens=total_input,
        output_tokens=usage.output_tokens,
        cache_read_input_tokens=usage.cache_read_input_tokens,
        cache_write_input_tokens=usage.cache_write_input_tokens,
        content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
        finish_span=True,
    )
| 414 | + |
async def _wrap_asynchronous_message_iterator(
    iterator: "AsyncIterator[RawMessageStreamEvent]",
    span: "Span",
    integration: "AnthropicIntegration",
) -> "AsyncIterator[RawMessageStreamEvent]":
    """
    Pass through every event of the async response stream while recording AI
    data (model, token usage, text content) on the AI Client Span.

    Responsible for closing the AI Client Span once the stream is exhausted.

    Note: this is an async generator, so the parameter and return types are
    ``AsyncIterator`` (the previous ``Iterator`` annotations were incorrect
    for ``async for`` consumption).
    """
    model = None
    usage = _RecordedUsage()
    content_blocks: "list[str]" = []

    async for event in iterator:
        # Only these event types carry data we record; everything else is
        # forwarded untouched.
        if isinstance(
            event,
            (
                MessageStartEvent,
                MessageDeltaEvent,
                MessageStopEvent,
                ContentBlockStartEvent,
                ContentBlockDeltaEvent,
                ContentBlockStopEvent,
            ),
        ):
            (
                model,
                usage,
                content_blocks,
            ) = _collect_ai_data(
                event,
                model,
                usage,
                content_blocks,
            )
        yield event

    # Anthropic's input_tokens excludes cached/cache_write tokens.
    # Normalize to total input tokens for correct cost calculations.
    total_input = (
        usage.input_tokens
        + (usage.cache_read_input_tokens or 0)
        + (usage.cache_write_input_tokens or 0)
    )

    _set_output_data(
        span=span,
        integration=integration,
        model=model,
        input_tokens=total_input,
        output_tokens=usage.output_tokens,
        cache_read_input_tokens=usage.cache_read_input_tokens,
        cache_write_input_tokens=usage.cache_write_input_tokens,
        content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
        finish_span=True,
    )
| 475 | + |
502 | 476 | def _set_output_data( |
503 | 477 | span: "Span", |
504 | 478 | integration: "AnthropicIntegration", |
@@ -576,9 +550,16 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A |
576 | 550 |
|
577 | 551 | result = yield f, args, kwargs |
578 | 552 |
|
579 | | - if isinstance(result, Stream) or isinstance(result, AsyncStream): |
580 | | - result._sentry_span = span |
581 | | - result._integration = integration |
| 553 | + if isinstance(result, Stream): |
| 554 | + result._iterator = _wrap_synchronous_message_iterator( |
| 555 | + result._iterator, span, integration |
| 556 | + ) |
| 557 | + return result |
| 558 | + |
| 559 | + if isinstance(result, AsyncStream): |
| 560 | + result._iterator = _wrap_asynchronous_message_iterator( |
| 561 | + result._iterator, span, integration |
| 562 | + ) |
582 | 563 | return result |
583 | 564 |
|
584 | 565 | with capture_internal_exceptions(): |
|
0 commit comments