33from __future__ import annotations
44
55import logging
6- from contextlib import ExitStack , contextmanager
6+ from contextlib import AsyncExitStack , ExitStack , contextmanager
77from types import TracebackType
88from typing import TYPE_CHECKING , Callable , Generator , Generic , TypeVar
99
5353 ResponseStreamEvent ,
5454 )
5555 from openai .lib .streaming .responses ._responses import (
56+ AsyncResponseStream ,
57+ AsyncResponseStreamManager ,
5658 ResponseStream ,
5759 ResponseStreamManager ,
5860 ) # pylint: disable=no-name-in-module
@@ -76,6 +78,16 @@ def _set_response_attributes(
7678 _set_invocation_response_attributes (invocation , result , capture_content )
7779
7880
81+ def _get_stream_response (stream ):
82+ try :
83+ return stream ._response
84+ except AttributeError :
85+ try :
86+ return stream .response
87+ except AttributeError :
88+ return None
89+
90+
7991class _ResponseProxy (Generic [ResponseT ]):
8092 def __init__ (self , response : ResponseT , finalize : Callable [[], None ]):
8193 self ._response = response
@@ -91,6 +103,21 @@ def __getattr__(self, name: str):
91103 return getattr (self ._response , name )
92104
93105
106+ class _AsyncResponseProxy (Generic [ResponseT ]):
107+ def __init__ (self , response : ResponseT , finalize : Callable [[], None ]):
108+ self ._response = response
109+ self ._finalize = finalize
110+
111+ async def aclose (self ) -> None :
112+ try :
113+ await self ._response .aclose ()
114+ finally :
115+ self ._finalize ()
116+
117+ def __getattr__ (self , name : str ):
118+ return getattr (self ._response , name )
119+
120+
94121class ResponseStreamWrapper (Generic [TextFormatT ]):
95122 """Wrapper for OpenAI Responses API stream objects.
96123
@@ -172,7 +199,7 @@ def __getattr__(self, name: str):
172199
173200 @property
174201 def response (self ):
175- response = self .stream . response
202+ response = _get_stream_response ( self .stream )
176203 if response is None :
177204 return None
178205 return _ResponseProxy (response , lambda : self ._stop (None ))
@@ -303,3 +330,135 @@ def parse(self) -> "ResponseStreamManagerWrapper[TextFormatT]":
303330 # cleanup once wrapt 2 typing support is available (wrapt PR #3903).
304331 def __getattr__ (self , name : str ):
305332 return getattr (self ._manager , name )
333+
334+
335+ class AsyncResponseStreamWrapper (ResponseStreamWrapper [TextFormatT ]):
336+ """Wrapper for async OpenAI Responses API stream objects."""
337+
338+ stream : "AsyncResponseStream[TextFormatT]"
339+
340+ async def __aenter__ (self ) -> "AsyncResponseStreamWrapper[TextFormatT]" :
341+ return self
342+
343+ async def __aexit__ (
344+ self ,
345+ exc_type : type [BaseException ] | None ,
346+ exc_val : BaseException | None ,
347+ exc_tb : TracebackType | None ,
348+ ) -> bool :
349+ try :
350+ if exc_type is not None :
351+ self ._fail (
352+ str (exc_val ), type (exc_val ) if exc_val else Exception
353+ )
354+ finally :
355+ await self .close ()
356+ return False
357+
358+ async def close (self ) -> None :
359+ try :
360+ await self .stream .close ()
361+ finally :
362+ self ._stop (None )
363+
364+ def __aiter__ (self ) -> "AsyncResponseStreamWrapper[TextFormatT]" :
365+ return self
366+
367+ async def __anext__ (self ) -> "ResponseStreamEvent[TextFormatT]" :
368+ try :
369+ event = await self .stream .__anext__ ()
370+ except StopAsyncIteration :
371+ self ._stop (None )
372+ raise
373+ except Exception as error :
374+ self ._fail (str (error ), type (error ))
375+ raise
376+ with self ._safe_instrumentation ("event processing" ):
377+ self .process_event (event )
378+ return event
379+
380+ async def get_final_response (self ) -> "ParsedResponse[TextFormatT]" :
381+ await self .until_done ()
382+ return await self .stream .get_final_response ()
383+
384+ async def until_done (self ) -> "AsyncResponseStreamWrapper[TextFormatT]" :
385+ async for _ in self :
386+ pass
387+ return self
388+
389+ def parse (self ) -> "AsyncResponseStreamWrapper[TextFormatT]" :
390+ raise NotImplementedError (
391+ "AsyncResponseStreamWrapper.parse() is not implemented"
392+ )
393+
394+ @property
395+ def response (self ):
396+ response = _get_stream_response (self .stream )
397+ if response is None :
398+ return None
399+ return _AsyncResponseProxy (response , lambda : self ._stop (None ))
400+
401+
402+ class AsyncResponseStreamManagerWrapper (Generic [TextFormatT ]):
403+ """Wrapper for async OpenAI Responses API stream managers."""
404+
405+ def __init__ (
406+ self ,
407+ manager : "AsyncResponseStreamManager[TextFormatT]" ,
408+ handler : TelemetryHandler ,
409+ invocation : "LLMInvocation" ,
410+ capture_content : bool ,
411+ ):
412+ self ._manager = manager
413+ self ._handler = handler
414+ self ._invocation = invocation
415+ self ._capture_content = capture_content
416+ self ._stream_wrapper : (
417+ AsyncResponseStreamWrapper [TextFormatT ] | None
418+ ) = None
419+
420+ async def __aenter__ (self ) -> AsyncResponseStreamWrapper [TextFormatT ]:
421+ stream = await self ._manager .__aenter__ ()
422+ self ._stream_wrapper = AsyncResponseStreamWrapper (
423+ stream ,
424+ self ._handler ,
425+ self ._invocation ,
426+ self ._capture_content ,
427+ )
428+ return self ._stream_wrapper
429+
430+ async def __aexit__ (
431+ self ,
432+ exc_type : type [BaseException ] | None ,
433+ exc_val : BaseException | None ,
434+ exc_tb : TracebackType | None ,
435+ ) -> bool :
436+ suppressed = False
437+ stream_wrapper = self ._stream_wrapper
438+ self ._stream_wrapper = None
439+ async with AsyncExitStack () as cleanup :
440+ if stream_wrapper is not None :
441+
442+ async def finalize_stream_wrapper () -> None :
443+ if suppressed :
444+ await stream_wrapper .__aexit__ (None , None , None )
445+ else :
446+ await stream_wrapper .__aexit__ (
447+ exc_type , exc_val , exc_tb
448+ )
449+
450+ cleanup .push_async_callback (finalize_stream_wrapper )
451+ suppressed = await self ._manager .__aexit__ (
452+ exc_type , exc_val , exc_tb
453+ )
454+ return suppressed
455+
456+ def parse (self ) -> "AsyncResponseStreamManagerWrapper[TextFormatT]" :
457+ raise NotImplementedError (
458+ "AsyncResponseStreamManagerWrapper.parse() is not implemented"
459+ )
460+
461+ # TODO: Replace __getattr__ passthrough with wrapt.ObjectProxy in a future
462+ # cleanup once wrapt 2 typing support is available (wrapt PR #3903).
463+ def __getattr__ (self , name : str ):
464+ return getattr (self ._manager , name )
0 commit comments