11from __future__ import annotations
22
3+ from collections .abc import AsyncGenerator
4+
35import ssl
46import sys
57import types
1012from .._exceptions import ConnectionNotAvailable , UnsupportedProtocol
1113from .._models import Origin , Proxy , Request , Response
1214from .._synchronization import AsyncEvent , AsyncShieldCancellation , AsyncThreadLock
15+ from .._utils import aclosing
1316from .connection import AsyncHTTPConnection
1417from .interfaces import AsyncConnectionInterface , AsyncRequestInterface
1518
19+ if typing .TYPE_CHECKING :
20+ from .http11 import HTTP11ConnectionByteStream
21+ from .http2 import HTTP2ConnectionByteStream
1622
1723class AsyncPoolRequest :
1824 def __init__ (self , request : Request ) -> None :
@@ -389,7 +395,7 @@ def __repr__(self) -> str:
389395class PoolByteStream :
390396 def __init__ (
391397 self ,
392- stream : typing . AsyncIterable [ bytes ] ,
398+ stream : HTTP11ConnectionByteStream | HTTP2ConnectionByteStream ,
393399 pool_request : AsyncPoolRequest ,
394400 pool : AsyncConnectionPool ,
395401 ) -> None :
@@ -398,20 +404,16 @@ def __init__(
398404 self ._pool = pool
399405 self ._closed = False
400406
401- async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
402- try :
403- async for part in self ._stream :
404- yield part
405- except BaseException as exc :
406- await self .aclose ()
407- raise exc from None
407+ async def __aiter__ (self ) -> AsyncGenerator [bytes ]:
408+ async with aclosing (self ._stream .__aiter__ ()) as iterator :
409+ async for chunk in iterator :
410+ yield chunk
408411
409412 async def aclose (self ) -> None :
410413 if not self ._closed :
411414 self ._closed = True
412415 with AsyncShieldCancellation ():
413- if hasattr (self ._stream , "aclose" ):
414- await self ._stream .aclose ()
416+ await self ._stream .aclose ()
415417
416418 with self ._pool ._optional_thread_lock :
417419 self ._pool ._requests .remove (self ._pool_request )
0 commit comments