Skip to content

Commit 150eb37

Browse files
committed
PYTHON-5272 Implement TLS session resumption for sync pool
Add _SSLSessionCache to cache TLS sessions per pool, enabling session resumption on subsequent connections to the same server. This avoids full asymmetric-key handshakes on every new connection, addressing the OpenSSL 3.0 performance overhead seen in BF-36991.
1 parent e9058e3 commit 150eb37

5 files changed

Lines changed: 111 additions & 6 deletions

File tree

pymongo/asynchronous/pool.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
_CancellationContext,
8686
_configured_protocol_interface,
8787
_raise_connection_failure,
88+
_SSLSessionCache,
8889
)
8990
from pymongo.read_preferences import ReadPreference
9091
from pymongo.server_api import _add_to_command
@@ -754,6 +755,9 @@ def __init__(
754755
self._pending = 0
755756
self._max_connecting = self.opts.max_connecting
756757
self._client_id = client_id
758+
self._ssl_session_cache: Optional[_SSLSessionCache] = (
759+
_SSLSessionCache() if self.opts._ssl_context is not None else None
760+
)
757761
# Log before publishing event to prevent potential listener preemption in tests
758762
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
759763
_debug_log(
@@ -1040,7 +1044,9 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
10401044
)
10411045

10421046
try:
1043-
networking_interface = await _configured_protocol_interface(self.address, self.opts)
1047+
networking_interface = await _configured_protocol_interface(
1048+
self.address, self.opts, self._ssl_session_cache
1049+
)
10441050
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
10451051
except BaseException as error:
10461052
async with self.lock:

pymongo/pool_shared.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import socket
2121
import ssl
2222
import sys
23+
import threading
2324
from typing import (
2425
TYPE_CHECKING,
2526
Any,
@@ -46,6 +47,32 @@
4647
from pymongo.pyopenssl_context import _sslConn
4748
from pymongo.typings import _Address
4849

50+
51+
class _SSLSessionCache:
52+
"""Thread-safe cache for a single TLS session per pool, enabling session resumption."""
53+
54+
__slots__ = ("_session", "_lock")
55+
56+
def __init__(self) -> None:
57+
self._session: Optional[Any] = None
58+
self._lock = threading.Lock()
59+
60+
def get(self) -> Optional[Any]:
61+
with self._lock:
62+
return self._session
63+
64+
def set(self, session: Any) -> None:
65+
with self._lock:
66+
self._session = session
67+
68+
69+
def _get_ssl_session(ssl_sock: Any) -> Optional[Any]:
70+
"""Return the TLS session from an SSL socket, handling both PyOpenSSL and stdlib ssl."""
71+
if hasattr(ssl_sock, "get_session"):
72+
return ssl_sock.get_session()
73+
return getattr(ssl_sock, "session", None)
74+
75+
4976
try:
5077
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
5178

@@ -298,7 +325,9 @@ async def _async_configured_socket(
298325

299326

300327
async def _configured_protocol_interface(
301-
address: _Address, options: PoolOptions
328+
address: _Address,
329+
options: PoolOptions,
330+
ssl_session_cache: Optional[_SSLSessionCache] = None, # noqa: ARG001
302331
) -> AsyncNetworkingInterface:
303332
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
304333
@@ -470,7 +499,11 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
470499
return ssl_sock
471500

472501

473-
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
502+
def _configured_socket_interface(
503+
address: _Address,
504+
options: PoolOptions,
505+
ssl_session_cache: Optional[_SSLSessionCache] = None,
506+
) -> NetworkingInterface:
474507
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
475508
476509
Can raise socket.error, ConnectionFailure, or _CertificateError.
@@ -485,13 +518,14 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
485518
return NetworkingInterface(sock)
486519

487520
host = address[0]
521+
session = ssl_session_cache.get() if ssl_session_cache is not None else None
488522
try:
489523
# We have to pass hostname / ip address to wrap_socket
490524
# to use SSLContext.check_hostname.
491525
if _has_sni(True):
492-
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
526+
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host, session=session)
493527
else:
494-
ssl_sock = ssl_context.wrap_socket(sock)
528+
ssl_sock = ssl_context.wrap_socket(sock, session=session)
495529
except _CertificateError:
496530
sock.close()
497531
# Raise _CertificateError directly like we do after match_hostname
@@ -515,5 +549,10 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
515549
ssl_sock.close()
516550
raise
517551

552+
if ssl_session_cache is not None:
553+
new_session = _get_ssl_session(ssl_sock)
554+
if new_session is not None:
555+
ssl_session_cache.set(new_session)
556+
518557
ssl_sock.settimeout(options.socket_timeout)
519558
return NetworkingInterface(ssl_sock)

pymongo/synchronous/pool.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
_CancellationContext,
8383
_configured_socket_interface,
8484
_raise_connection_failure,
85+
_SSLSessionCache,
8586
)
8687
from pymongo.read_preferences import ReadPreference
8788
from pymongo.server_api import _add_to_command
@@ -752,6 +753,9 @@ def __init__(
752753
self._pending = 0
753754
self._max_connecting = self.opts.max_connecting
754755
self._client_id = client_id
756+
self._ssl_session_cache: Optional[_SSLSessionCache] = (
757+
_SSLSessionCache() if self.opts._ssl_context is not None else None
758+
)
755759
# Log before publishing event to prevent potential listener preemption in tests
756760
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
757761
_debug_log(
@@ -1036,7 +1040,9 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
10361040
)
10371041

10381042
try:
1039-
networking_interface = _configured_socket_interface(self.address, self.opts)
1043+
networking_interface = _configured_socket_interface(
1044+
self.address, self.opts, self._ssl_session_cache
1045+
)
10401046
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
10411047
except BaseException as error:
10421048
with self.lock:

test/asynchronous/test_ssl.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ def test_config_ssl(self):
128128
def test_use_pyopenssl_when_available(self):
129129
self.assertTrue(HAVE_PYSSL)
130130

131+
def test_ssl_session_cache(self):
132+
from pymongo.pool_shared import _SSLSessionCache
133+
134+
cache = _SSLSessionCache()
135+
self.assertIsNone(cache.get())
136+
cache.set("session")
137+
self.assertEqual(cache.get(), "session")
138+
cache.set("new_session")
139+
self.assertEqual(cache.get(), "new_session")
140+
131141

132142
class TestSSL(AsyncIntegrationTest):
133143
saved_port: int
@@ -673,6 +683,23 @@ async def test_pyopenssl_ignored_in_async(self):
673683
await client.admin.command("ping") # command doesn't matter, just needs it to connect
674684
await client.close()
675685

686+
@async_client_context.require_tls
687+
async def test_pool_has_ssl_session_cache(self):
688+
from pymongo.pool_shared import _SSLSessionCache
689+
690+
pool = list(self.client._topology._servers.values())[0].pool
691+
self.assertIsInstance(pool._ssl_session_cache, _SSLSessionCache)
692+
693+
@async_client_context.require_tls
694+
@unittest.skipUnless(
695+
_IS_SYNC and _HAVE_PYOPENSSL, "Session caching only applies to PyOpenSSL sync path"
696+
)
697+
async def test_tls_session_cached_after_connect(self):
698+
await self.client.admin.command("ping")
699+
pool = list(self.client._topology._servers.values())[0].pool
700+
self.assertIsNotNone(pool._ssl_session_cache)
701+
self.assertIsNotNone(pool._ssl_session_cache.get())
702+
676703

677704
if __name__ == "__main__":
678705
unittest.main()

test/test_ssl.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ def test_config_ssl(self):
128128
def test_use_pyopenssl_when_available(self):
129129
self.assertTrue(HAVE_PYSSL)
130130

131+
def test_ssl_session_cache(self):
132+
from pymongo.pool_shared import _SSLSessionCache
133+
134+
cache = _SSLSessionCache()
135+
self.assertIsNone(cache.get())
136+
cache.set("session")
137+
self.assertEqual(cache.get(), "session")
138+
cache.set("new_session")
139+
self.assertEqual(cache.get(), "new_session")
140+
131141

132142
class TestSSL(IntegrationTest):
133143
saved_port: int
@@ -671,6 +681,23 @@ def test_pyopenssl_ignored_in_async(self):
671681
client.admin.command("ping") # command doesn't matter, just needs it to connect
672682
client.close()
673683

684+
@client_context.require_tls
685+
def test_pool_has_ssl_session_cache(self):
686+
from pymongo.pool_shared import _SSLSessionCache
687+
688+
pool = list(self.client._topology._servers.values())[0].pool
689+
self.assertIsInstance(pool._ssl_session_cache, _SSLSessionCache)
690+
691+
@client_context.require_tls
692+
@unittest.skipUnless(
693+
_IS_SYNC and _HAVE_PYOPENSSL, "Session caching only applies to PyOpenSSL sync path"
694+
)
695+
def test_tls_session_cached_after_connect(self):
696+
self.client.admin.command("ping")
697+
pool = list(self.client._topology._servers.values())[0].pool
698+
self.assertIsNotNone(pool._ssl_session_cache)
699+
self.assertIsNotNone(pool._ssl_session_cache.get())
700+
674701

675702
if __name__ == "__main__":
676703
unittest.main()

0 commit comments

Comments
 (0)