Skip to content

Commit 563991a

Browse files
committed
PYTHON-5272 Extend TLS session resumption to asyncio path (Python 3.11+)
On Python 3.11+, SSLProtocol.__init__ creates the ssl.SSLObject via wrap_bio before the handshake starts in connection_made. We temporarily replace asyncio.sslproto.SSLProtocol with a subclass that sets sslobj.session to the cached session immediately after super().__init__, then restore the original class. With a pre-connected sock= parameter, _make_ssl_transport is called synchronously inside create_connection before the first await, so the swap is race-free in a single-threaded event loop. After the handshake, the session is retrieved via transport.get_extra_info('ssl_object').session and stored in the pool's _SSLSessionCache for the next connection.
1 parent 92a8230 commit 563991a

3 files changed

Lines changed: 95 additions & 8 deletions

File tree

pymongo/pool_shared.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ def _get_ssl_session(ssl_sock: Any) -> Optional[Any]:
7373
return getattr(ssl_sock, "session", None)
7474

7575

76+
# asyncio's SSLProtocol does not expose a session= parameter in create_connection.
77+
# On Python 3.11+, wrap_bio() is called in SSLProtocol.__init__ and the handshake
78+
# starts later in connection_made(), so we can set sslobj.session between the two.
79+
# On older Python, _SSLPipe.do_handshake calls wrap_bio and starts the handshake
80+
# atomically; session injection there requires copying private internals, so we skip it.
81+
_ASYNCIO_SSL_SESSION_SUPPORTED = sys.version_info >= (3, 11)
82+
83+
84+
def _make_session_ssl_protocol(session: Any) -> Any:
85+
"""Return an SSLProtocol subclass that injects *session* before the handshake."""
86+
import asyncio.sslproto as _sslproto
87+
88+
class _SessionSSLProtocol(_sslproto.SSLProtocol):
89+
def __init__(self, *args: Any, **kwargs: Any) -> None:
90+
super().__init__(*args, **kwargs)
91+
sslobj = getattr(self, "_sslobj", None)
92+
if sslobj is not None:
93+
sslobj.session = session
94+
95+
return _SessionSSLProtocol
96+
97+
7698
try:
7799
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
78100

@@ -327,7 +349,7 @@ async def _async_configured_socket(
327349
async def _configured_protocol_interface(
328350
address: _Address,
329351
options: PoolOptions,
330-
ssl_session_cache: Optional[_SSLSessionCache] = None, # noqa: ARG001
352+
ssl_session_cache: Optional[_SSLSessionCache] = None,
331353
) -> AsyncNetworkingInterface:
332354
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
333355
@@ -347,25 +369,37 @@ async def _configured_protocol_interface(
347369
)
348370

349371
host = address[0]
372+
# On Python 3.11+, temporarily patch asyncio's SSLProtocol to inject the
373+
# cached session before the handshake. _make_ssl_transport (which
374+
# instantiates SSLProtocol) is called synchronously inside
375+
# create_connection before the first await, so the swap is race-free in a
376+
# single-threaded event loop when the socket is pre-connected.
377+
import asyncio.sslproto as _sslproto
378+
379+
session = (
380+
ssl_session_cache.get()
381+
if ssl_session_cache is not None and _ASYNCIO_SSL_SESSION_SUPPORTED
382+
else None
383+
)
384+
original_ssl_protocol = _sslproto.SSLProtocol
385+
if session is not None:
386+
_sslproto.SSLProtocol = _make_session_ssl_protocol(session) # type: ignore[misc]
350387
try:
351-
# We have to pass hostname / ip address to wrap_socket
352-
# to use SSLContext.check_hostname.
353388
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
354389
lambda: PyMongoProtocol(timeout=timeout),
355390
sock=sock,
356391
server_hostname=host,
357392
ssl=ssl_context,
358393
)
359394
except _CertificateError:
360-
# Raise _CertificateError directly like we do after match_hostname
361-
# below.
362395
raise
363396
except (OSError, *SSLErrors) as exc:
364-
# We raise AutoReconnect for transient and permanent SSL handshake
365-
# failures alike. Permanent handshake failures, like protocol
366-
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
367397
details = _get_timeout_details(options)
368398
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
399+
finally:
400+
if session is not None:
401+
_sslproto.SSLProtocol = original_ssl_protocol # type: ignore[misc]
402+
369403
if (
370404
ssl_context.verify_mode
371405
and not ssl_context.check_hostname
@@ -377,6 +411,13 @@ async def _configured_protocol_interface(
377411
transport.abort()
378412
raise
379413

414+
if ssl_session_cache is not None and _ASYNCIO_SSL_SESSION_SUPPORTED:
415+
ssl_obj = transport.get_extra_info("ssl_object")
416+
if ssl_obj is not None:
417+
new_session = ssl_obj.session
418+
if new_session is not None:
419+
ssl_session_cache.set(new_session)
420+
380421
return AsyncNetworkingInterface((transport, protocol))
381422

382423

test/asynchronous/test_ssl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,29 @@ def test_tls_session_reused_on_second_connection(self):
170170
_, kwargs = mock_ssl_context.wrap_socket.call_args
171171
self.assertIs(kwargs.get("session"), fake_session)
172172

173+
@unittest.skipUnless(
174+
not _IS_SYNC and sys.version_info >= (3, 11),
175+
"Async session injection requires Python 3.11+",
176+
)
177+
def test_async_tls_session_injected_into_sslobj(self):
178+
"""Cached TLS session is set on SSLObject before the handshake on Python 3.11+."""
179+
import asyncio.sslproto as _sslproto
180+
import unittest.mock as mock
181+
182+
from pymongo.pool_shared import _make_session_ssl_protocol, _SSLSessionCache
183+
184+
fake_session = mock.MagicMock()
185+
patched_cls = _make_session_ssl_protocol(fake_session)
186+
187+
mock_sslobj = mock.MagicMock()
188+
instance = patched_cls.__new__(patched_cls)
189+
instance._sslobj = mock_sslobj
190+
# Call __init__ via the patched class, bypassing the real SSLProtocol init.
191+
with mock.patch.object(_sslproto.SSLProtocol, "__init__", lambda *a, **kw: None):
192+
patched_cls.__init__(instance)
193+
194+
self.assertEqual(mock_sslobj.session, fake_session)
195+
173196

174197
class TestSSL(AsyncIntegrationTest):
175198
saved_port: int

test/test_ssl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,29 @@ def test_tls_session_reused_on_second_connection(self):
170170
_, kwargs = mock_ssl_context.wrap_socket.call_args
171171
self.assertIs(kwargs.get("session"), fake_session)
172172

173+
@unittest.skipUnless(
174+
not _IS_SYNC and sys.version_info >= (3, 11),
175+
"Async session injection requires Python 3.11+",
176+
)
177+
def test_async_tls_session_injected_into_sslobj(self):
178+
"""Cached TLS session is set on SSLObject before the handshake on Python 3.11+."""
179+
import asyncio.sslproto as _sslproto
180+
import unittest.mock as mock
181+
182+
from pymongo.pool_shared import _make_session_ssl_protocol, _SSLSessionCache
183+
184+
fake_session = mock.MagicMock()
185+
patched_cls = _make_session_ssl_protocol(fake_session)
186+
187+
mock_sslobj = mock.MagicMock()
188+
instance = patched_cls.__new__(patched_cls)
189+
instance._sslobj = mock_sslobj
190+
# Call __init__ via the patched class, bypassing the real SSLProtocol init.
191+
with mock.patch.object(_sslproto.SSLProtocol, "__init__", lambda *a, **kw: None):
192+
patched_cls.__init__(instance)
193+
194+
self.assertEqual(mock_sslobj.session, fake_session)
195+
173196

174197
class TestSSL(IntegrationTest):
175198
saved_port: int

0 commit comments

Comments
 (0)