Skip to content

Commit d4651f3

Browse files
committed
Ensure that all async generators are explicitly closed
1 parent 9820975 commit d4651f3

9 files changed

Lines changed: 55 additions & 28 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

7+
## [Unreleased]
8+
9+
- Explicitly close all async generators to ensure predictable behavior
10+
711
## Version 1.0.9 (April 24th, 2025)
812

913
- Resolve https://github.com/advisories/GHSA-vqfr-h8mv-ghfj with h11 dependency update. (#1008)

httpcore/_async/connection_pool.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncGenerator
4+
35
import ssl
46
import sys
57
import types
@@ -10,9 +12,13 @@
1012
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
1113
from .._models import Origin, Proxy, Request, Response
1214
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
15+
from .._utils import aclosing
1316
from .connection import AsyncHTTPConnection
1417
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
1518

19+
if typing.TYPE_CHECKING:
20+
from .http11 import HTTP11ConnectionByteStream
21+
from .http2 import HTTP2ConnectionByteStream
1622

1723
class AsyncPoolRequest:
1824
def __init__(self, request: Request) -> None:
@@ -389,7 +395,7 @@ def __repr__(self) -> str:
389395
class PoolByteStream:
390396
def __init__(
391397
self,
392-
stream: typing.AsyncIterable[bytes],
398+
stream: HTTP11ConnectionByteStream | HTTP2ConnectionByteStream,
393399
pool_request: AsyncPoolRequest,
394400
pool: AsyncConnectionPool,
395401
) -> None:
@@ -398,20 +404,16 @@ def __init__(
398404
self._pool = pool
399405
self._closed = False
400406

401-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
402-
try:
403-
async for part in self._stream:
404-
yield part
405-
except BaseException as exc:
406-
await self.aclose()
407-
raise exc from None
407+
async def __aiter__(self) -> AsyncGenerator[bytes]:
408+
async with aclosing(self._stream.__aiter__()) as iterator:
409+
async for chunk in iterator:
410+
yield chunk
408411

409412
async def aclose(self) -> None:
410413
if not self._closed:
411414
self._closed = True
412415
with AsyncShieldCancellation():
413-
if hasattr(self._stream, "aclose"):
414-
await self._stream.aclose()
416+
await self._stream.aclose()
415417

416418
with self._pool._optional_thread_lock:
417419
self._pool._requests.remove(self._pool_request)

httpcore/_async/http11.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
import types
88
import typing
9+
from collections.abc import AsyncGenerator
910

1011
import h11
1112

@@ -21,6 +22,7 @@
2122
from .._synchronization import AsyncLock, AsyncShieldCancellation
2223
from .._trace import Trace
2324
from .interfaces import AsyncConnectionInterface
25+
from .._utils import aclosing
2426

2527
logger = logging.getLogger("httpcore.http11")
2628

@@ -193,9 +195,7 @@ async def _receive_response_headers(
193195

194196
return http_version, event.status_code, event.reason, headers, trailing_data
195197

196-
async def _receive_response_body(
197-
self, request: Request
198-
) -> typing.AsyncIterator[bytes]:
198+
async def _receive_response_body(self, request: Request) -> AsyncGenerator[bytes]:
199199
timeouts = request.extensions.get("timeout", {})
200200
timeout = timeouts.get("read", None)
201201

@@ -327,12 +327,13 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
327327
self._request = request
328328
self._closed = False
329329

330-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
330+
async def __aiter__(self) -> AsyncGenerator[bytes]:
331331
kwargs = {"request": self._request}
332332
try:
333333
async with Trace("receive_response_body", logger, self._request, kwargs):
334-
async for chunk in self._connection._receive_response_body(**kwargs):
335-
yield chunk
334+
async with aclosing(self._connection._receive_response_body(**kwargs)) as body:
335+
async for chunk in body:
336+
yield chunk
336337
except BaseException as exc:
337338
# If we get an exception while streaming the response,
338339
# we want to close the response (and possibly the connection)

httpcore/_async/http2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import time
66
import types
77
import typing
8+
from collections.abc import AsyncGenerator
89

910
import h2.config
1011
import h2.connection
@@ -308,7 +309,7 @@ async def _receive_response(
308309

309310
async def _receive_response_body(
310311
self, request: Request, stream_id: int
311-
) -> typing.AsyncIterator[bytes]:
312+
) -> AsyncGenerator[bytes]:
312313
"""
313314
Iterator that returns the bytes of the response body for a given stream ID.
314315
"""
@@ -568,7 +569,7 @@ def __init__(
568569
self._stream_id = stream_id
569570
self._closed = False
570571

571-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
572+
async def __aiter__(self) -> AsyncGenerator[bytes]:
572573
kwargs = {"request": self._request, "stream_id": self._stream_id}
573574
try:
574575
async with Trace("receive_response_body", logger, self._request, kwargs):

httpcore/_async/interfaces.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import typing
5+
from collections.abc import AsyncGenerator
56

67
from .._models import (
78
URL,
@@ -56,9 +57,9 @@ async def stream(
5657
url: URL | bytes | str,
5758
*,
5859
headers: HeaderTypes = None,
59-
content: bytes | typing.AsyncIterator[bytes] | None = None,
60+
content: bytes | AsyncGenerator[bytes] | None = None,
6061
extensions: Extensions | None = None,
61-
) -> typing.AsyncIterator[Response]:
62+
) -> AsyncGenerator[Response]:
6263
# Strict type checking on our parameters.
6364
method = enforce_bytes(method, name="method")
6465
url = enforce_url(url, name="url")

httpcore/_models.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import ssl
55
import typing
66
import urllib.parse
7+
from collections.abc import AsyncGenerator
8+
9+
from ._utils import aclosing
710

811
# Functions for typechecking...
912

@@ -151,7 +154,7 @@ def __init__(self, content: bytes) -> None:
151154
def __iter__(self) -> typing.Iterator[bytes]:
152155
yield self._content
153156

154-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
157+
async def __aiter__(self) -> AsyncGenerator[bytes]:
155158
yield self._content
156159

157160
def __repr__(self) -> str:
@@ -463,10 +466,11 @@ async def aread(self) -> bytes:
463466
"You should use 'response.read()' instead."
464467
)
465468
if not hasattr(self, "_content"):
466-
self._content = b"".join([part async for part in self.aiter_stream()])
469+
async with aclosing(self.aiter_stream()) as parts:
470+
self._content = b"".join([part async for part in parts])
467471
return self._content
468472

469-
async def aiter_stream(self) -> typing.AsyncIterator[bytes]:
473+
async def aiter_stream(self) -> AsyncGenerator[bytes]:
470474
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover
471475
raise RuntimeError(
472476
"Attempted to stream an synchronous response using 'async for ... in "
@@ -479,8 +483,9 @@ async def aiter_stream(self) -> typing.AsyncIterator[bytes]:
479483
"more than once."
480484
)
481485
self._stream_consumed = True
482-
async for chunk in self.stream:
483-
yield chunk
486+
async with aclosing(self.stream) as parts:
487+
async for chunk in parts:
488+
yield chunk
484489

485490
async def aclose(self) -> None:
486491
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover

httpcore/_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@
44
import socket
55
import sys
66

7+
if sys.version_info >= (3, 10):
8+
from contextlib import aclosing as aclosing
9+
else:
10+
class aclosing(AbstractAsyncContextManager):
11+
def __init__(self, thing):
12+
self.thing = thing
13+
14+
async def __aenter__(self):
15+
return self.thing
16+
17+
async def __aexit__(self, *exc_info):
18+
await self.thing.aclose()
19+
720

821
def is_socket_readable(sock: socket.socket | None) -> bool:
922
"""

tests/test_cancellations.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ async def test_h11_timeout_during_response():
171171
assert conn.is_closed()
172172

173173

174-
@pytest.mark.xfail
175174
@pytest.mark.anyio
176175
async def test_h2_timeout_during_handshake():
177176
"""
@@ -186,7 +185,6 @@ async def test_h2_timeout_during_handshake():
186185
assert conn.is_closed()
187186

188187

189-
@pytest.mark.xfail
190188
@pytest.mark.anyio
191189
async def test_h2_timeout_during_request():
192190
"""
@@ -207,7 +205,6 @@ async def test_h2_timeout_during_request():
207205
assert conn.is_idle()
208206

209207

210-
@pytest.mark.xfail
211208
@pytest.mark.anyio
212209
async def test_h2_timeout_during_response():
213210
"""

tests/test_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
163163
for chunk in self._chunks:
164164
yield chunk
165165

166+
async def aclose(self) -> None:
167+
pass
168+
166169

167170
@pytest.mark.trio
168171
async def test_response_async_read():

0 commit comments

Comments
 (0)