Skip to content

Commit b1ae845

Browse files
test: cover pool replacement race from #317
Add a deterministic unit test for the case where another thread publishes a pool while a slower add attempt is still constructing its pool. This guards against closing in-flight connections by replacing the pool that should remain current. Refs: #317
1 parent d8530e8 commit b1ae845

1 file changed

Lines changed: 35 additions & 1 deletion

File tree

tests/unit/test_cluster.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
import socket
18+
import threading
1819

1920
from unittest.mock import patch, Mock
2021
import uuid
@@ -23,7 +24,7 @@
2324
InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
2425
from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
2526
ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
26-
from cassandra.pool import Host
27+
from cassandra.pool import Host, HostConnection
2728
from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
2829
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
2930
from tests.unit.utils import mock_session_pools
@@ -339,6 +340,39 @@ def test_set_keyspace_escapes_quotes(self, *_):
339340
assert query == 'USE simple_ks', (
340341
"Simple keyspace names should not be quoted, got: %r" % query)
341342

343+
344+
class SessionPoolRaceTest(unittest.TestCase):
345+
def test_concurrent_add_or_renew_pool_no_double_replace(self):
346+
"""Reproduces https://github.com/scylladb/python-driver/issues/317."""
347+
host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())
348+
349+
session = Session.__new__(Session)
350+
session.submit = lambda fn: Mock(result=lambda timeout=None: fn())
351+
session.keyspace = None
352+
session._lock = threading.RLock()
353+
session._pools = {}
354+
session._profile_manager = Mock()
355+
session._profile_manager.distance.return_value = HostDistance.LOCAL
356+
357+
winner_pool = Mock()
358+
created_pools = []
359+
360+
def fake_host_connection_init(pool, *_):
361+
pool._keyspace = session.keyspace
362+
pool.shutdown = Mock()
363+
created_pools.append(pool)
364+
log.info("Publishing competing pool while replacement pool is being created")
365+
with session._lock:
366+
session._pools[host] = winner_pool
367+
368+
with patch.object(HostConnection, '__init__', fake_host_connection_init):
369+
result = session.add_or_renew_pool(host, is_host_addition=True).result()
370+
371+
assert result is True
372+
assert session._pools[host] is winner_pool
373+
created_pools[0].shutdown.assert_called_once()
374+
winner_pool.shutdown.assert_not_called()
375+
342376
class ProtocolVersionTests(unittest.TestCase):
343377

344378
def test_protocol_downgrade_test(self):

0 commit comments

Comments
 (0)