Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
=========

Changes in Version 4.18.0
-------------------------

- Improved TLS connection performance by reusing TLS sessions across connections
to the same server, avoiding a full handshake on each new connection.
Session reuse is active on the sync path unconditionally, and on the async
path on Python 3.11 or later.

Changes in Version 4.17.0 (2026/04/20)
--------------------------------------

Expand Down
8 changes: 7 additions & 1 deletion pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
_CancellationContext,
_configured_protocol_interface,
_raise_connection_failure,
_SSLSessionCache,
)
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
Expand Down Expand Up @@ -754,6 +755,9 @@ def __init__(
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
self._ssl_session_cache: Optional[_SSLSessionCache] = (
_SSLSessionCache() if self.opts._ssl_context is not None else None
)
# Log before publishing event to prevent potential listener preemption in tests
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
Expand Down Expand Up @@ -1040,7 +1044,9 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
)

try:
networking_interface = await _configured_protocol_interface(self.address, self.opts)
networking_interface = await _configured_protocol_interface(
self.address, self.opts, self._ssl_session_cache
)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
Expand Down
106 changes: 102 additions & 4 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import socket
import ssl
import sys
import threading
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -46,6 +47,60 @@
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address


class _SSLSessionCache:
"""Thread-safe cache for a single TLS session per pool, enabling session resumption."""

__slots__ = ("_session", "_lock")

def __init__(self) -> None:
self._session: Optional[Any] = None
self._lock = threading.Lock()

def get(self) -> Optional[Any]:
with self._lock:
return self._session

def set(self, session: Any) -> None:
with self._lock:
self._session = session


def _get_ssl_session(ssl_sock: Any) -> Optional[Any]:
"""Return the TLS session from an SSL socket, handling both PyOpenSSL and stdlib ssl."""
if hasattr(ssl_sock, "get_session"):
return ssl_sock.get_session()
return getattr(ssl_sock, "session", None)


# asyncio's create_connection does not support TLS session resumption natively.
# https://github.com/python/cpython/issues/79152 tracks this; a patch was submitted
# in 2018 but never merged, and the issue is now closed.
# On Python 3.11+, wrap_bio() is called in SSLProtocol.__init__ and the handshake
# starts later in connection_made(), so we can set sslobj.session between the two.
# On older Python, _SSLPipe.do_handshake calls wrap_bio and starts the handshake
# atomically; session injection there requires copying private internals, so we skip it.
_ASYNCIO_SSL_SESSION_SUPPORTED = sys.version_info >= (3, 11)
# Captured lazily on first SSL async connection; never reset thereafter so
# concurrent connections always restore to the true original, not a locally-
# captured (possibly stale) reference.
_ORIGINAL_SSL_PROTOCOL: Any = None


def _make_session_ssl_protocol(session: Any) -> Any:
"""Return an SSLProtocol subclass that injects *session* before the handshake."""
import asyncio.sslproto as _sslproto

class _SessionSSLProtocol(_sslproto.SSLProtocol):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
sslobj = getattr(self, "_sslobj", None)
if sslobj is not None:
sslobj.session = session

return _SessionSSLProtocol


try:
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl

Expand Down Expand Up @@ -298,7 +353,9 @@ async def _async_configured_socket(


async def _configured_protocol_interface(
address: _Address, options: PoolOptions
address: _Address,
options: PoolOptions,
ssl_session_cache: Optional[_SSLSessionCache] = None,
) -> AsyncNetworkingInterface:
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.

Expand All @@ -318,6 +375,26 @@ async def _configured_protocol_interface(
)

host = address[0]
# On Python 3.11+, temporarily patch asyncio's SSLProtocol to inject the
# cached session before the handshake. _make_ssl_transport (which
# instantiates SSLProtocol) is called synchronously inside
# create_connection before the first await, so the swap is race-free in a
# single-threaded event loop when the socket is pre-connected.
# Always restore to _ORIGINAL_SSL_PROTOCOL (not a locally captured value)
# so that concurrent connections can't leave a stale subclass in place.
session = (
ssl_session_cache.get()
if ssl_session_cache is not None and _ASYNCIO_SSL_SESSION_SUPPORTED
else None
)
if _ASYNCIO_SSL_SESSION_SUPPORTED:
import asyncio.sslproto as _asyncio_sslproto

global _ORIGINAL_SSL_PROTOCOL # noqa: PLW0603
if _ORIGINAL_SSL_PROTOCOL is None:
_ORIGINAL_SSL_PROTOCOL = _asyncio_sslproto.SSLProtocol
if session is not None:
_asyncio_sslproto.SSLProtocol = _make_session_ssl_protocol(session) # type: ignore[misc]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
Expand All @@ -337,6 +414,10 @@ async def _configured_protocol_interface(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
finally:
if _ASYNCIO_SSL_SESSION_SUPPORTED:
_asyncio_sslproto.SSLProtocol = _ORIGINAL_SSL_PROTOCOL # type: ignore[misc]

if (
ssl_context.verify_mode
and not ssl_context.check_hostname
Expand All @@ -348,6 +429,13 @@ async def _configured_protocol_interface(
transport.abort()
raise

if ssl_session_cache is not None and _ASYNCIO_SSL_SESSION_SUPPORTED:
ssl_obj = transport.get_extra_info("ssl_object")
if ssl_obj is not None:
new_session = ssl_obj.session
if new_session is not None:
ssl_session_cache.set(new_session)

return AsyncNetworkingInterface((transport, protocol))


Expand Down Expand Up @@ -470,7 +558,11 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
return ssl_sock


def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
def _configured_socket_interface(
address: _Address,
options: PoolOptions,
ssl_session_cache: Optional[_SSLSessionCache] = None,
) -> NetworkingInterface:
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.

Can raise socket.error, ConnectionFailure, or _CertificateError.
Expand All @@ -485,13 +577,14 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
return NetworkingInterface(sock)

host = address[0]
session = ssl_session_cache.get() if ssl_session_cache is not None else None
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if _has_sni(True):
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host, session=session)
else:
ssl_sock = ssl_context.wrap_socket(sock)
ssl_sock = ssl_context.wrap_socket(sock, session=session)
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
Expand All @@ -515,5 +608,10 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
ssl_sock.close()
raise

if ssl_session_cache is not None:
new_session = _get_ssl_session(ssl_sock)
if new_session is not None:
ssl_session_cache.set(new_session)

ssl_sock.settimeout(options.socket_timeout)
return NetworkingInterface(ssl_sock)
8 changes: 7 additions & 1 deletion pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
_CancellationContext,
_configured_socket_interface,
_raise_connection_failure,
_SSLSessionCache,
)
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
Expand Down Expand Up @@ -752,6 +753,9 @@ def __init__(
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
self._ssl_session_cache: Optional[_SSLSessionCache] = (
_SSLSessionCache() if self.opts._ssl_context is not None else None
)
# Log before publishing event to prevent potential listener preemption in tests
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
Expand Down Expand Up @@ -1036,7 +1040,9 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
)

try:
networking_interface = _configured_socket_interface(self.address, self.opts)
networking_interface = _configured_socket_interface(
self.address, self.opts, self._ssl_session_cache
)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
with self.lock:
Expand Down
82 changes: 82 additions & 0 deletions test/asynchronous/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,71 @@ def test_config_ssl(self):
def test_use_pyopenssl_when_available(self):
self.assertTrue(HAVE_PYSSL)

def test_ssl_session_cache(self):
from pymongo.pool_shared import _SSLSessionCache

cache = _SSLSessionCache()
self.assertIsNone(cache.get())
cache.set("session")
self.assertEqual(cache.get(), "session")
cache.set("new_session")
self.assertEqual(cache.get(), "new_session")

@unittest.skipUnless(_IS_SYNC, "Tests sync wrap_socket path only")
def test_tls_session_reused_on_second_connection(self):
"""Cached TLS session is passed to wrap_socket on subsequent connections."""
import unittest.mock as mock

from pymongo.pool_shared import _configured_socket_interface, _SSLSessionCache

fake_session = object()
cache = _SSLSessionCache()
cache.set(fake_session)

fake_ssl_sock = mock.MagicMock()
fake_ssl_sock.getpeercert.return_value = {}

mock_ssl_context = mock.MagicMock()
mock_ssl_context.wrap_socket.return_value = fake_ssl_sock
mock_ssl_context.verify_mode = False
mock_ssl_context.check_hostname = False

mock_opts = mock.MagicMock()
mock_opts._ssl_context = mock_ssl_context
mock_opts.socket_timeout = None
mock_opts.tls_allow_invalid_hostnames = True

with mock.patch("pymongo.pool_shared._create_connection") as mock_create:
mock_create.return_value = mock.MagicMock()
_configured_socket_interface(("localhost", 27017), mock_opts, cache)

mock_ssl_context.wrap_socket.assert_called_once()
_, kwargs = mock_ssl_context.wrap_socket.call_args
self.assertIs(kwargs.get("session"), fake_session)

@unittest.skipUnless(
not _IS_SYNC and sys.version_info >= (3, 11),
"Async session injection requires Python 3.11+",
)
def test_async_tls_session_injected_into_sslobj(self):
"""Cached TLS session is set on SSLObject before the handshake on Python 3.11+."""
import asyncio.sslproto as _sslproto
import unittest.mock as mock

from pymongo.pool_shared import _make_session_ssl_protocol, _SSLSessionCache

fake_session = mock.MagicMock()
patched_cls = _make_session_ssl_protocol(fake_session)

mock_sslobj = mock.MagicMock()
instance = patched_cls.__new__(patched_cls)
instance._sslobj = mock_sslobj
# Call __init__ via the patched class, bypassing the real SSLProtocol init.
with mock.patch.object(_sslproto.SSLProtocol, "__init__", lambda *a, **kw: None):
patched_cls.__init__(instance)

self.assertEqual(mock_sslobj.session, fake_session)


class TestSSL(AsyncIntegrationTest):
saved_port: int
Expand Down Expand Up @@ -673,6 +738,23 @@ async def test_pyopenssl_ignored_in_async(self):
await client.admin.command("ping") # command doesn't matter, just needs it to connect
await client.close()

@async_client_context.require_tls
async def test_pool_has_ssl_session_cache(self):
from pymongo.pool_shared import _SSLSessionCache

pool = list(self.client._topology._servers.values())[0].pool
self.assertIsInstance(pool._ssl_session_cache, _SSLSessionCache)

@async_client_context.require_tls
@unittest.skipUnless(
_IS_SYNC and _HAVE_PYOPENSSL, "Session caching only applies to PyOpenSSL sync path"
)
async def test_tls_session_cached_after_connect(self):
await self.client.admin.command("ping")
pool = list(self.client._topology._servers.values())[0].pool
self.assertIsNotNone(pool._ssl_session_cache)
self.assertIsNotNone(pool._ssl_session_cache.get())


if __name__ == "__main__":
unittest.main()
Loading
Loading