Skip to content

Commit c9f62fe

Browse files
committed
PYTHON-5272 Fix concurrency bug in async SSL protocol patch restore
Capture _ORIGINAL_SSL_PROTOCOL once at module load time and always restore to it unconditionally, so that concurrent connections from the same pool cannot leave a stale SSLProtocol subclass active in asyncio.sslproto.
1 parent 5958f59 commit c9f62fe

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

pymongo/pool_shared.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def _get_ssl_session(ssl_sock: Any) -> Optional[Any]:
8181
# On older Python, _SSLPipe.do_handshake calls wrap_bio and starts the handshake
8282
# atomically; session injection there requires copying private internals, so we skip it.
8383
_ASYNCIO_SSL_SESSION_SUPPORTED = sys.version_info >= (3, 11)
84+
if _ASYNCIO_SSL_SESSION_SUPPORTED:
85+
import asyncio.sslproto as _asyncio_sslproto
86+
87+
# Capture the true original once at import time so concurrent connections
88+
# always restore to it, not to a locally-captured (possibly stale) reference.
89+
_ORIGINAL_SSL_PROTOCOL = _asyncio_sslproto.SSLProtocol
8490

8591

8692
def _make_session_ssl_protocol(session: Any) -> Any:
@@ -376,16 +382,15 @@ async def _configured_protocol_interface(
376382
# instantiates SSLProtocol) is called synchronously inside
377383
# create_connection before the first await, so the swap is race-free in a
378384
# single-threaded event loop when the socket is pre-connected.
379-
import asyncio.sslproto as _sslproto
380-
385+
# Always restore to _ORIGINAL_SSL_PROTOCOL (not a locally captured value)
386+
# so that concurrent connections can't leave a stale subclass in place.
381387
session = (
382388
ssl_session_cache.get()
383389
if ssl_session_cache is not None and _ASYNCIO_SSL_SESSION_SUPPORTED
384390
else None
385391
)
386-
original_ssl_protocol = _sslproto.SSLProtocol
387392
if session is not None:
388-
_sslproto.SSLProtocol = _make_session_ssl_protocol(session) # type: ignore[misc]
393+
_asyncio_sslproto.SSLProtocol = _make_session_ssl_protocol(session) # type: ignore[misc]
389394
try:
390395
# We have to pass hostname / ip address to wrap_socket
391396
# to use SSLContext.check_hostname.
@@ -406,8 +411,8 @@ async def _configured_protocol_interface(
406411
details = _get_timeout_details(options)
407412
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
408413
finally:
409-
if session is not None:
410-
_sslproto.SSLProtocol = original_ssl_protocol # type: ignore[misc]
414+
if _ASYNCIO_SSL_SESSION_SUPPORTED:
415+
_asyncio_sslproto.SSLProtocol = _ORIGINAL_SSL_PROTOCOL # type: ignore[misc]
411416

412417
if (
413418
ssl_context.verify_mode

0 commit comments

Comments
 (0)