|
15 | 15 |
|
16 | 16 | import logging |
17 | 17 | import socket |
| 18 | +import threading |
18 | 19 |
|
19 | 20 | from unittest.mock import patch, Mock |
20 | 21 | import uuid |
|
23 | 24 | InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion |
24 | 25 | from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ |
25 | 26 | ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT |
26 | | -from cassandra.pool import Host |
| 27 | +from cassandra.pool import Host, HostConnection |
27 | 28 | from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy |
28 | 29 | from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory |
29 | 30 | from tests.unit.utils import mock_session_pools |
@@ -339,6 +340,39 @@ def test_set_keyspace_escapes_quotes(self, *_): |
339 | 340 | assert query == 'USE simple_ks', ( |
340 | 341 | "Simple keyspace names should not be quoted, got: %r" % query) |
341 | 342 |
|
| 343 | + |
| 344 | +class SessionPoolRaceTest(unittest.TestCase): |
| 345 | + def test_concurrent_add_or_renew_pool_no_double_replace(self): |
| 346 | + """Reproduces https://github.com/scylladb/python-driver/issues/317.""" |
| 347 | + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) |
| 348 | + |
| 349 | + session = Session.__new__(Session) |
| 350 | + session.submit = lambda fn: Mock(result=lambda timeout=None: fn()) |
| 351 | + session.keyspace = None |
| 352 | + session._lock = threading.RLock() |
| 353 | + session._pools = {} |
| 354 | + session._profile_manager = Mock() |
| 355 | + session._profile_manager.distance.return_value = HostDistance.LOCAL |
| 356 | + |
| 357 | + winner_pool = Mock() |
| 358 | + created_pools = [] |
| 359 | + |
| 360 | + def fake_host_connection_init(pool, *_): |
| 361 | + pool._keyspace = session.keyspace |
| 362 | + pool.shutdown = Mock() |
| 363 | + created_pools.append(pool) |
| 364 | + log.info("Publishing competing pool while replacement pool is being created") |
| 365 | + with session._lock: |
| 366 | + session._pools[host] = winner_pool |
| 367 | + |
| 368 | + with patch.object(HostConnection, '__init__', fake_host_connection_init): |
| 369 | + result = session.add_or_renew_pool(host, is_host_addition=True).result() |
| 370 | + |
| 371 | + assert result is True |
| 372 | + assert session._pools[host] is winner_pool |
| 373 | + created_pools[0].shutdown.assert_called_once() |
| 374 | + winner_pool.shutdown.assert_not_called() |
| 375 | + |
342 | 376 | class ProtocolVersionTests(unittest.TestCase): |
343 | 377 |
|
344 | 378 | def test_protocol_downgrade_test(self): |
|
0 commit comments