Skip to content

Commit 31a4a56

Browse files
Support async cancellations. (#726)
* Add 'AsyncShieldCancellation' context manager * Update _synchronization.py * Linting * Fix docstring wording * Add interim 'nocover' to show tests passing. * Add failing test case for HTTP/1.1 cancellations * Neat cleanup for HTTP/1.1 write cancellations * Drop 'nocover' for ShieldCancellation * Add failing test case for HTTP/1.1 cancellations during response reading * Resolve failing test case * Add failing test cases for cancellations on connection pools * Resolve failing test cases * Add failing test cases for cancellations on HTTP/2 connections * Resolve failing test cases * Add failing test cases for cancellations on HTTP/2 connections when reading response * Resolve failing test cases * Update CHANGELOG * Fix yield behaviour
1 parent 630e1e9 commit 31a4a56

File tree

9 files changed

+364
-32
lines changed

9 files changed

+364
-32
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
77
## unreleased
88

99
- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699)
10+
- Support async cancellations, ensuring that the connection pool is left in a clean state when cancellations occur. (#726)
1011
- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730)
1112
- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717)
1213

httpcore/_async/connection_pool.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
88
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
99
from .._models import Origin, Request, Response
10-
from .._synchronization import AsyncEvent, AsyncLock
10+
from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
1111
from .connection import AsyncHTTPConnection
1212
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
1313

@@ -257,7 +257,8 @@ async def handle_async_request(self, request: Request) -> Response:
257257
status.unset_connection()
258258
await self._attempt_to_acquire_connection(status)
259259
except BaseException as exc:
260-
await self.response_closed(status)
260+
with AsyncShieldCancellation():
261+
await self.response_closed(status)
261262
raise exc
262263
else:
263264
break
@@ -351,4 +352,5 @@ async def aclose(self) -> None:
351352
if hasattr(self._stream, "aclose"):
352353
await self._stream.aclose()
353354
finally:
354-
await self._pool.response_closed(self._status)
355+
with AsyncShieldCancellation():
356+
await self._pool.response_closed(self._status)

httpcore/_async/http11.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
map_exceptions,
2424
)
2525
from .._models import Origin, Request, Response
26-
from .._synchronization import AsyncLock
26+
from .._synchronization import AsyncLock, AsyncShieldCancellation
2727
from .._trace import Trace
2828
from .interfaces import AsyncConnectionInterface
2929

@@ -115,8 +115,9 @@ async def handle_async_request(self, request: Request) -> Response:
115115
},
116116
)
117117
except BaseException as exc:
118-
async with Trace("response_closed", logger, request) as trace:
119-
await self._response_closed()
118+
with AsyncShieldCancellation():
119+
async with Trace("response_closed", logger, request) as trace:
120+
await self._response_closed()
120121
raise exc
121122

122123
# Sending the request...
@@ -319,7 +320,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
319320
# If we get an exception while streaming the response,
320321
# we want to close the response (and possibly the connection)
321322
# before raising that exception.
322-
await self.aclose()
323+
with AsyncShieldCancellation():
324+
await self.aclose()
323325
raise exc
324326

325327
async def aclose(self) -> None:

httpcore/_async/http2.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
RemoteProtocolError,
1818
)
1919
from .._models import Origin, Request, Response
20-
from .._synchronization import AsyncLock, AsyncSemaphore
20+
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
2121
from .._trace import Trace
2222
from .interfaces import AsyncConnectionInterface
2323

@@ -103,9 +103,15 @@ async def handle_async_request(self, request: Request) -> Response:
103103

104104
async with self._init_lock:
105105
if not self._sent_connection_init:
106-
kwargs = {"request": request}
107-
async with Trace("send_connection_init", logger, request, kwargs):
108-
await self._send_connection_init(**kwargs)
106+
try:
107+
kwargs = {"request": request}
108+
async with Trace("send_connection_init", logger, request, kwargs):
109+
await self._send_connection_init(**kwargs)
110+
except BaseException as exc:
111+
with AsyncShieldCancellation():
112+
await self.aclose()
113+
raise exc
114+
109115
self._sent_connection_init = True
110116

111117
# Initially start with just 1 until the remote server provides
@@ -154,10 +160,11 @@ async def handle_async_request(self, request: Request) -> Response:
154160
"stream_id": stream_id,
155161
},
156162
)
157-
except Exception as exc: # noqa: PIE786
158-
kwargs = {"stream_id": stream_id}
159-
async with Trace("response_closed", logger, request, kwargs):
160-
await self._response_closed(stream_id=stream_id)
163+
except BaseException as exc: # noqa: PIE786
164+
with AsyncShieldCancellation():
165+
kwargs = {"stream_id": stream_id}
166+
async with Trace("response_closed", logger, request, kwargs):
167+
await self._response_closed(stream_id=stream_id)
161168

162169
if isinstance(exc, h2.exceptions.ProtocolError):
163170
# One case where h2 can raise a protocol error is when a
@@ -570,7 +577,8 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
570577
# If we get an exception while streaming the response,
571578
# we want to close the response (and possibly the connection)
572579
# before raising that exception.
573-
await self.aclose()
580+
with AsyncShieldCancellation():
581+
await self.aclose()
574582
raise exc
575583

576584
async def aclose(self) -> None:

httpcore/_sync/connection_pool.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .._backends.base import SOCKET_OPTION, NetworkBackend
88
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
99
from .._models import Origin, Request, Response
10-
from .._synchronization import Event, Lock
10+
from .._synchronization import Event, Lock, ShieldCancellation
1111
from .connection import HTTPConnection
1212
from .interfaces import ConnectionInterface, RequestInterface
1313

@@ -257,7 +257,8 @@ def handle_request(self, request: Request) -> Response:
257257
status.unset_connection()
258258
self._attempt_to_acquire_connection(status)
259259
except BaseException as exc:
260-
self.response_closed(status)
260+
with ShieldCancellation():
261+
self.response_closed(status)
261262
raise exc
262263
else:
263264
break
@@ -351,4 +352,5 @@ def close(self) -> None:
351352
if hasattr(self._stream, "close"):
352353
self._stream.close()
353354
finally:
354-
self._pool.response_closed(self._status)
355+
with ShieldCancellation():
356+
self._pool.response_closed(self._status)

httpcore/_sync/http11.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
map_exceptions,
2424
)
2525
from .._models import Origin, Request, Response
26-
from .._synchronization import Lock
26+
from .._synchronization import Lock, ShieldCancellation
2727
from .._trace import Trace
2828
from .interfaces import ConnectionInterface
2929

@@ -115,8 +115,9 @@ def handle_request(self, request: Request) -> Response:
115115
},
116116
)
117117
except BaseException as exc:
118-
with Trace("response_closed", logger, request) as trace:
119-
self._response_closed()
118+
with ShieldCancellation():
119+
with Trace("response_closed", logger, request) as trace:
120+
self._response_closed()
120121
raise exc
121122

122123
# Sending the request...
@@ -319,7 +320,8 @@ def __iter__(self) -> Iterator[bytes]:
319320
# If we get an exception while streaming the response,
320321
# we want to close the response (and possibly the connection)
321322
# before raising that exception.
322-
self.close()
323+
with ShieldCancellation():
324+
self.close()
323325
raise exc
324326

325327
def close(self) -> None:

httpcore/_sync/http2.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
RemoteProtocolError,
1818
)
1919
from .._models import Origin, Request, Response
20-
from .._synchronization import Lock, Semaphore
20+
from .._synchronization import Lock, Semaphore, ShieldCancellation
2121
from .._trace import Trace
2222
from .interfaces import ConnectionInterface
2323

@@ -103,9 +103,15 @@ def handle_request(self, request: Request) -> Response:
103103

104104
with self._init_lock:
105105
if not self._sent_connection_init:
106-
kwargs = {"request": request}
107-
with Trace("send_connection_init", logger, request, kwargs):
108-
self._send_connection_init(**kwargs)
106+
try:
107+
kwargs = {"request": request}
108+
with Trace("send_connection_init", logger, request, kwargs):
109+
self._send_connection_init(**kwargs)
110+
except BaseException as exc:
111+
with ShieldCancellation():
112+
self.close()
113+
raise exc
114+
109115
self._sent_connection_init = True
110116

111117
# Initially start with just 1 until the remote server provides
@@ -154,10 +160,11 @@ def handle_request(self, request: Request) -> Response:
154160
"stream_id": stream_id,
155161
},
156162
)
157-
except Exception as exc: # noqa: PIE786
158-
kwargs = {"stream_id": stream_id}
159-
with Trace("response_closed", logger, request, kwargs):
160-
self._response_closed(stream_id=stream_id)
163+
except BaseException as exc: # noqa: PIE786
164+
with ShieldCancellation():
165+
kwargs = {"stream_id": stream_id}
166+
with Trace("response_closed", logger, request, kwargs):
167+
self._response_closed(stream_id=stream_id)
161168

162169
if isinstance(exc, h2.exceptions.ProtocolError):
163170
# One case where h2 can raise a protocol error is when a
@@ -570,7 +577,8 @@ def __iter__(self) -> typing.Iterator[bytes]:
570577
# If we get an exception while streaming the response,
571578
# we want to close the response (and possibly the connection)
572579
# before raising that exception.
573-
self.close()
580+
with ShieldCancellation():
581+
self.close()
574582
raise exc
575583

576584
def close(self) -> None:

httpcore/_synchronization.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,55 @@ async def release(self) -> None:
171171
self._anyio_semaphore.release()
172172

173173

174+
class AsyncShieldCancellation:
175+
# For certain portions of our codebase where we're dealing with
176+
# closing connections during exception handling we want to shield
177+
# the operation from being cancelled.
178+
#
179+
# with AsyncShieldCancellation():
180+
# ... # clean-up operations, shielded from cancellation.
181+
182+
def __init__(self) -> None:
183+
"""
184+
Detect if we're running under 'asyncio' or 'trio' and create
185+
a shielded scope with the correct implementation.
186+
"""
187+
self._backend = sniffio.current_async_library()
188+
189+
if self._backend == "trio":
190+
if trio is None: # pragma: nocover
191+
raise RuntimeError(
192+
"Running under trio requires the 'trio' package to be installed."
193+
)
194+
195+
self._trio_shield = trio.CancelScope(shield=True)
196+
else:
197+
if anyio is None: # pragma: nocover
198+
raise RuntimeError(
199+
"Running under asyncio requires the 'anyio' package to be installed."
200+
)
201+
202+
self._anyio_shield = anyio.CancelScope(shield=True)
203+
204+
def __enter__(self) -> "AsyncShieldCancellation":
205+
if self._backend == "trio":
206+
self._trio_shield.__enter__()
207+
else:
208+
self._anyio_shield.__enter__()
209+
return self
210+
211+
def __exit__(
212+
self,
213+
exc_type: Optional[Type[BaseException]] = None,
214+
exc_value: Optional[BaseException] = None,
215+
traceback: Optional[TracebackType] = None,
216+
) -> None:
217+
if self._backend == "trio":
218+
self._trio_shield.__exit__(exc_type, exc_value, traceback)
219+
else:
220+
self._anyio_shield.__exit__(exc_type, exc_value, traceback)
221+
222+
174223
# Our thread-based synchronization primitives...
175224

176225

@@ -212,3 +261,19 @@ def acquire(self) -> None:
212261

213262
def release(self) -> None:
214263
self._semaphore.release()
264+
265+
266+
class ShieldCancellation:
267+
# Thread-synchronous codebases don't support cancellation semantics.
268+
# We have this class because we need to mirror the async and sync
269+
# cases within our package, but it's just a no-op.
270+
def __enter__(self) -> "ShieldCancellation":
271+
return self
272+
273+
def __exit__(
274+
self,
275+
exc_type: Optional[Type[BaseException]] = None,
276+
exc_value: Optional[BaseException] = None,
277+
traceback: Optional[TracebackType] = None,
278+
) -> None:
279+
pass

0 commit comments

Comments
 (0)