|
39 | 39 |
|
40 | 40 | from anthropic import Stream, AsyncStream |
41 | 41 | from anthropic.resources import AsyncMessages, Messages |
42 | | - from anthropic.lib.streaming import MessageStreamManager |
| 42 | + from anthropic.lib.streaming import MessageStreamManager, AsyncMessageStreamManager |
43 | 43 |
|
44 | 44 | from anthropic.types import ( |
45 | 45 | MessageStartEvent, |
|
67 | 67 | TextBlockParam, |
68 | 68 | ToolUnionParam, |
69 | 69 | ) |
70 | | - from anthropic.lib.streaming import MessageStream |
| 70 | + from anthropic.lib.streaming import MessageStream, AsyncMessageStream |
71 | 71 |
|
72 | 72 |
|
73 | 73 | class _RecordedUsage: |
@@ -97,6 +97,13 @@ def setup_once() -> None: |
97 | 97 | MessageStreamManager.__enter__ |
98 | 98 | ) |
99 | 99 |
|
| 100 | + AsyncMessages.stream = _wrap_async_message_stream(AsyncMessages.stream) |
| 101 | + AsyncMessageStreamManager.__aenter__ = ( |
| 102 | + _wrap_async_message_stream_manager_aenter( |
| 103 | + AsyncMessageStreamManager.__aenter__ |
| 104 | + ) |
| 105 | + ) |
| 106 | + |
100 | 107 |
|
101 | 108 | def _capture_exception(exc: "Any") -> None: |
102 | 109 | set_span_errored() |
@@ -391,10 +398,10 @@ def _set_create_input_data( |
391 | 398 |
|
392 | 399 |
|
393 | 400 | def _wrap_synchronous_message_iterator( |
394 | | - iterator: "Iterator[RawMessageStreamEvent]", |
| 401 | + iterator: "Iterator[Union[RawMessageStreamEvent, MessageStreamEvent]]", |
395 | 402 | span: "Span", |
396 | 403 | integration: "AnthropicIntegration", |
397 | | -) -> "Iterator[RawMessageStreamEvent]": |
| 404 | +) -> "Iterator[Union[RawMessageStreamEvent, MessageStreamEvent]]": |
398 | 405 | """ |
399 | 406 | Sets information received while iterating the response stream on the AI Client Span. |
400 | 407 | Responsible for closing the AI Client Span. |
@@ -456,10 +463,10 @@ def _wrap_synchronous_message_iterator( |
456 | 463 |
|
457 | 464 |
|
458 | 465 | async def _wrap_asynchronous_message_iterator( |
459 | | - iterator: "AsyncIterator[RawMessageStreamEvent]", |
| 466 | + iterator: "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]", |
460 | 467 | span: "Span", |
461 | 468 | integration: "AnthropicIntegration", |
462 | | -) -> "AsyncIterator[RawMessageStreamEvent]": |
| 469 | +) -> "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]": |
463 | 470 | """ |
464 | 471 | Sets information received while iterating the response stream on the AI Client Span. |
465 | 472 | Responsible for closing the AI Client Span. |
@@ -809,6 +816,90 @@ def _sentry_patched_enter(self: "MessageStreamManager") -> "MessageStream": |
809 | 816 | return _sentry_patched_enter |
810 | 817 |
|
811 | 818 |
|
| 819 | +def _wrap_async_message_stream(f: "Any") -> "Any": |
| 820 | + """ |
| 821 | + Attaches user-provided arguments to the returned context manager. |
| 822 | + The attributes are set on AI Client Spans in the patch for the context manager. |
| 823 | + """ |
| 824 | + |
| 825 | + @wraps(f) |
| 826 | + def _sentry_patched_stream( |
| 827 | + *args: "Any", **kwargs: "Any" |
| 828 | + ) -> "AsyncMessageStreamManager": |
| 829 | + stream_manager = f(*args, **kwargs) |
| 830 | + |
| 831 | + stream_manager._max_tokens = kwargs.get("max_tokens") |
| 832 | + stream_manager._messages = kwargs.get("messages") |
| 833 | + stream_manager._model = kwargs.get("model") |
| 834 | + stream_manager._system = kwargs.get("system") |
| 835 | + stream_manager._temperature = kwargs.get("temperature") |
| 836 | + stream_manager._top_k = kwargs.get("top_k") |
| 837 | + stream_manager._top_p = kwargs.get("top_p") |
| 838 | + stream_manager._tools = kwargs.get("tools") |
| 839 | + |
| 840 | + return stream_manager |
| 841 | + |
| 842 | + return _sentry_patched_stream |
| 843 | + |
| 844 | + |
def _wrap_async_message_stream_manager_aenter(f: "Any") -> "Any":
    """
    Creates and manages AI Client Spans.
    """

    @wraps(f)
    async def _sentry_patched_aenter(
        self: "AsyncMessageStreamManager",
    ) -> "AsyncMessageStream":
        message_stream = await f(self)

        # A manager that was not produced by our patched AsyncMessages.stream
        # carries no recorded request arguments; leave it untouched.
        if not hasattr(self, "_max_tokens"):
            return message_stream

        integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
        if integration is None:
            return message_stream

        if self._messages is None:
            return message_stream

        # Only instrument when the recorded messages are actually iterable.
        try:
            iter(self._messages)
        except TypeError:
            return message_stream

        model_name = self._model
        span = get_start_span_function()(
            op=OP.GEN_AI_CHAT,
            name="chat" if model_name is None else f"chat {model_name}".strip(),
            origin=AnthropicIntegration.origin,
        )
        span.__enter__()
        span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)

        _set_common_input_data(
            span=span,
            integration=integration,
            max_tokens=self._max_tokens,
            messages=self._messages,
            model=model_name,
            system=self._system,
            temperature=self._temperature,
            top_k=self._top_k,
            top_p=self._top_p,
            tools=self._tools,
        )

        # The wrapping iterator records streamed response data on the span
        # and is responsible for closing it when iteration ends.
        message_stream._iterator = _wrap_asynchronous_message_iterator(
            iterator=message_stream._iterator,
            span=span,
            integration=integration,
        )

        return message_stream

    return _sentry_patched_aenter
| 901 | + |
| 902 | + |
812 | 903 | def _is_given(obj: "Any") -> bool: |
813 | 904 | """ |
814 | 905 | Check for givenness safely across different anthropic versions. |
|
0 commit comments