1919from types import TracebackType
2020from typing import (
2121 TYPE_CHECKING ,
22- AsyncIterator ,
22+ Any ,
2323 Callable ,
2424 Generator ,
2525 Generic ,
2626 Iterator ,
27- Protocol ,
2827 TypeVar ,
2928 cast ,
3029)
4544 _sdk_accumulate_event = None
4645
4746if TYPE_CHECKING :
47+ from anthropic ._streaming import AsyncStream , Stream
4848 from anthropic .lib .streaming ._messages import ( # pylint: disable=no-name-in-module
49+ AsyncMessageStream ,
4950 AsyncMessageStreamManager ,
51+ MessageStream ,
5052 MessageStreamManager ,
5153 )
5254 from anthropic .lib .streaming ._types import ( # pylint: disable=no-name-in-module
53- MessageStreamEvent ,
55+ ParsedMessageStreamEvent ,
5456 )
5557 from anthropic .types import (
5658 Message ,
5759 RawMessageStreamEvent ,
5860 )
61+ from anthropic .types .parsed_message import ParsedMessage
5962
6063
6164_logger = logging .getLogger (__name__ )
62- SyncResponseT = TypeVar ("SyncResponseT" , bound = "_SupportsClose" )
63- AsyncResponseT = TypeVar ("AsyncResponseT" , bound = "_SupportsAclose" )
64- StreamEventT = TypeVar (
65- "StreamEventT" , "RawMessageStreamEvent" , "MessageStreamEvent"
66- )
67- StreamEventT_co = TypeVar (
68- "StreamEventT_co" ,
69- "RawMessageStreamEvent" ,
70- "MessageStreamEvent" ,
71- covariant = True ,
72- )
65+ ResponseT = TypeVar ("ResponseT" )
66+ ResponseFormatT = TypeVar ("ResponseFormatT" )
7367accumulate_event = cast ("Callable[..., Message] | None" , _sdk_accumulate_event )
7468
7569
76- class _SupportsClose (Protocol ):
77- def close (self ) -> None : ...
78-
79-
80- class _SupportsAclose (_SupportsClose , Protocol ):
81- async def aclose (self ) -> None : ...
82-
83-
84- class _SyncStream (Protocol [StreamEventT_co ]):
85- @property
86- def response (self ) -> _SupportsClose : ...
87-
88- def __iter__ (self ) -> Iterator [StreamEventT_co ]: ...
89-
90- def __next__ (self ) -> StreamEventT_co : ...
91-
92- def close (self ) -> None : ...
93-
94-
95- class _AsyncStream (Protocol [StreamEventT_co ]):
96- @property
97- def response (self ) -> _SupportsAclose : ...
98-
99- def __aiter__ (self ) -> AsyncIterator [StreamEventT_co ]: ...
100-
101- async def __anext__ (self ) -> StreamEventT_co : ...
102-
103- async def close (self ) -> None : ...
104-
105-
10670def _set_response_attributes (
10771 invocation : LLMInvocation ,
10872 result : "Message | None" ,
@@ -111,9 +75,9 @@ def _set_response_attributes(
11175 set_invocation_response_attributes (invocation , result , capture_content )
11276
11377
114- class _ResponseProxy (Generic [SyncResponseT ]):
115- def __init__ (self , response : SyncResponseT , finalize : Callable [[], None ]):
116- self ._response = response
78+ class _ResponseProxy (Generic [ResponseT ]):
79+ def __init__ (self , response : ResponseT , finalize : Callable [[], None ]):
80+ self ._response : Any = response
11781 self ._finalize = finalize
11882
11983 def close (self ) -> None :
@@ -126,17 +90,11 @@ def __getattr__(self, name: str):
12690 return getattr (self ._response , name )
12791
12892
129- class _AsyncResponseProxy (Generic [AsyncResponseT ]):
130- def __init__ (self , response : AsyncResponseT , finalize : Callable [[], None ]):
131- self ._response = response
93+ class _AsyncResponseProxy (Generic [ResponseT ]):
94+ def __init__ (self , response : ResponseT , finalize : Callable [[], None ]):
95+ self ._response : Any = response
13296 self ._finalize = finalize
13397
134- def close (self ) -> None :
135- try :
136- self ._response .close ()
137- finally :
138- self ._finalize ()
139-
14098 async def aclose (self ) -> None :
14199 try :
142100 await self ._response .aclose ()
@@ -166,24 +124,29 @@ def message(self) -> Message:
166124 return self ._message
167125
168126
169- class MessagesStreamWrapper (Generic [StreamEventT ], Iterator [StreamEventT ]):
127+ class MessagesStreamWrapper (
128+ Generic [ResponseFormatT ],
129+ Iterator [
130+ "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]"
131+ ],
132+ ):
170133 """Wrapper for Anthropic Stream that handles telemetry."""
171134
172135 def __init__ (
173136 self ,
174- stream : _SyncStream [ StreamEventT ] ,
137+ stream : "Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT]" ,
175138 handler : TelemetryHandler ,
176139 invocation : LLMInvocation ,
177140 capture_content : bool ,
178141 ):
179142 self .stream = stream
180143 self .handler = handler
181144 self .invocation = invocation
182- self ._message : "Message | None" = None
145+ self ._message : "Message | ParsedMessage[ResponseFormatT] | None" = None
183146 self ._capture_content = capture_content
184147 self ._finalized = False
185148
186- def __enter__ (self ) -> "MessagesStreamWrapper[StreamEventT ]" :
149+ def __enter__ (self ) -> "MessagesStreamWrapper[ResponseFormatT ]" :
187150 return self
188151
189152 def __exit__ (
@@ -207,10 +170,12 @@ def close(self) -> None:
207170 finally :
208171 self ._stop ()
209172
210- def __iter__ (self ) -> "MessagesStreamWrapper[StreamEventT ]" :
173+ def __iter__ (self ) -> "MessagesStreamWrapper[ResponseFormatT ]" :
211174 return self
212175
213- def __next__ (self ) -> StreamEventT :
176+ def __next__ (
177+ self ,
178+ ) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" :
214179 try :
215180 chunk = next (self .stream )
216181 except StopIteration :
@@ -227,9 +192,7 @@ def __getattr__(self, name: str) -> object:
227192 return getattr (self .stream , name )
228193
229194 @property
230- def response (
231- self ,
232- ) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None" :
195+ def response (self ):
233196 return _ResponseProxy (self .stream .response , self ._stop )
234197
235198 def _stop (self ) -> None :
@@ -266,10 +229,13 @@ def _safe_instrumentation(
266229 exc_info = True ,
267230 )
268231
269- def _process_chunk (self , chunk : StreamEventT ) -> None :
232+ def _process_chunk (
233+ self ,
234+ chunk : "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" ,
235+ ) -> None :
270236 """Accumulate a final message snapshot from a streaming chunk."""
271237 snapshot = cast (
272- "Message | None" ,
238+ "ParsedMessage[ResponseFormatT] | None" ,
273239 getattr (self .stream , "current_message_snapshot" , None ),
274240 )
275241 if snapshot is not None :
@@ -279,30 +245,32 @@ def _process_chunk(self, chunk: StreamEventT) -> None:
279245 return
280246 self ._message = accumulate_event (
281247 event = cast ("RawMessageStreamEvent" , chunk ),
282- current_snapshot = self ._message ,
248+ current_snapshot = cast (
249+ "ParsedMessage[ResponseFormatT] | None" , self ._message
250+ ),
283251 )
284252
285253
286- class AsyncMessagesStreamWrapper (MessagesStreamWrapper [StreamEventT ]):
254+ class AsyncMessagesStreamWrapper (MessagesStreamWrapper [ResponseFormatT ]):
287255 """Wrapper for async Anthropic Stream that handles telemetry."""
288256
289- stream : _AsyncStream [StreamEventT ]
290-
291257 def __init__ (
292258 self ,
293- stream : _AsyncStream [ StreamEventT ] ,
259+ stream : "AsyncStream[RawMessageStreamEvent] | AsyncMessageStream[ResponseFormatT]" ,
294260 handler : TelemetryHandler ,
295261 invocation : LLMInvocation ,
296262 capture_content : bool ,
297263 ):
298264 self .stream = stream
299265 self .handler = handler
300266 self .invocation = invocation
301- self ._message : "Message | None" = None
267+ self ._message : "Message | ParsedMessage[ResponseFormatT] | None" = None
302268 self ._capture_content = capture_content
303269 self ._finalized = False
304270
305- async def __aenter__ (self ) -> "AsyncMessagesStreamWrapper[StreamEventT]" :
271+ async def __aenter__ (
272+ self ,
273+ ) -> "AsyncMessagesStreamWrapper[ResponseFormatT]" :
306274 return self
307275
308276 async def __aexit__ (
@@ -326,16 +294,16 @@ async def close(self) -> None: # type: ignore[override]
326294 finally :
327295 self ._stop ()
328296
329- def __aiter__ (self ) -> "AsyncMessagesStreamWrapper[StreamEventT ]" :
297+ def __aiter__ (self ) -> "AsyncMessagesStreamWrapper[ResponseFormatT ]" :
330298 return self
331299
332300 @property
333- def response (
334- self ,
335- ) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None" :
301+ def response (self ) -> Any :
336302 return _AsyncResponseProxy (self .stream .response , self ._stop )
337303
338- async def __anext__ (self ) -> StreamEventT :
304+ async def __anext__ (
305+ self ,
306+ ) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" :
339307 try :
340308 chunk = await self .stream .__anext__ ()
341309 except StopAsyncIteration :
@@ -349,12 +317,12 @@ async def __anext__(self) -> StreamEventT:
349317 return chunk
350318
351319
352- class MessagesStreamManagerWrapper :
320+ class MessagesStreamManagerWrapper ( Generic [ ResponseFormatT ]) :
353321 """Wrapper for sync Anthropic stream managers."""
354322
355323 def __init__ (
356324 self ,
357- manager : "MessageStreamManager" ,
325+ manager : "MessageStreamManager[ResponseFormatT] " ,
358326 handler : TelemetryHandler ,
359327 invocation : LLMInvocation ,
360328 capture_content : bool ,
@@ -363,15 +331,12 @@ def __init__(
363331 self ._handler = handler
364332 self ._invocation = invocation
365333 self ._capture_content = capture_content
366- self ._stream_wrapper : (
367- MessagesStreamWrapper [MessageStreamEvent ] | None
368- ) = None
369-
370- def __enter__ (self ) -> MessagesStreamWrapper [MessageStreamEvent ]:
371- stream = cast (
372- "_SyncStream[MessageStreamEvent]" ,
373- self ._manager .__enter__ (),
334+ self ._stream_wrapper : MessagesStreamWrapper [ResponseFormatT ] | None = (
335+ None
374336 )
337+
338+ def __enter__ (self ) -> MessagesStreamWrapper [ResponseFormatT ]:
339+ stream = self ._manager .__enter__ ()
375340 self ._stream_wrapper = MessagesStreamWrapper (
376341 stream ,
377342 self ._handler ,
@@ -406,7 +371,7 @@ def __getattr__(self, name: str) -> object:
406371 return getattr (self ._manager , name )
407372
408373
409- class AsyncMessagesStreamManagerWrapper :
374+ class AsyncMessagesStreamManagerWrapper ( Generic [ ResponseFormatT ]) :
410375 """Wrapper for AsyncMessageStreamManager that handles telemetry.
411376
412377 Wraps AsyncMessageStreamManager from the Anthropic SDK:
@@ -415,7 +380,7 @@ class AsyncMessagesStreamManagerWrapper:
415380
416381 def __init__ (
417382 self ,
418- manager : "AsyncMessageStreamManager" ,
383+ manager : "AsyncMessageStreamManager[ResponseFormatT] " ,
419384 handler : TelemetryHandler ,
420385 invocation : LLMInvocation ,
421386 capture_content : bool ,
@@ -425,16 +390,13 @@ def __init__(
425390 self ._invocation = invocation
426391 self ._capture_content = capture_content
427392 self ._stream_wrapper : (
428- AsyncMessagesStreamWrapper [MessageStreamEvent ] | None
393+ AsyncMessagesStreamWrapper [ResponseFormatT ] | None
429394 ) = None
430395
431396 async def __aenter__ (
432397 self ,
433- ) -> AsyncMessagesStreamWrapper [MessageStreamEvent ]:
434- msg_stream = cast (
435- "_AsyncStream[MessageStreamEvent]" ,
436- await self ._manager .__aenter__ (),
437- )
398+ ) -> AsyncMessagesStreamWrapper [ResponseFormatT ]:
399+ msg_stream = await self ._manager .__aenter__ ()
438400 self ._stream_wrapper = AsyncMessagesStreamWrapper (
439401 msg_stream ,
440402 self ._handler ,
0 commit comments