Skip to content

Commit c2a47bb

Browse files
committed
pool: discard stale replacement shard connections
1 parent d940e17 commit c2a47bb

2 files changed

Lines changed: 123 additions & 3 deletions

File tree

cassandra/pool.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,13 +785,30 @@ def _replace(self, connection):
785785
self._session.submit(self._replace, connection)
786786
return
787787

788+
stale_endpoint = False
788789
with self._lock:
789790
if self.is_shutdown:
790791
replacement_connection.close()
791792
self._is_replacing = False
792793
return
793-
self._connections[replacement_connection.features.shard_id] = replacement_connection
794+
with self.host.lock:
795+
stale_endpoint = not (
796+
_endpoints_match(
797+
self._session.cluster, self.host.endpoint,
798+
expected_endpoint) and
799+
_host_is_current_for_endpoint(
800+
self._session.cluster, self.host, expected_endpoint))
801+
if not stale_endpoint:
802+
self._connections[replacement_connection.features.shard_id] = replacement_connection
794803
self._is_replacing = False
804+
if stale_endpoint:
805+
log.debug("Ignoring stale connection replacement for host %s; endpoint changed from %s",
806+
self.host, expected_endpoint)
807+
replacement_connection.close()
808+
self._remove_stale_pool(expected_endpoint)
809+
with self._stream_available_condition:
810+
self._stream_available_condition.notify()
811+
return
795812
with self._stream_available_condition:
796813
self._stream_available_condition.notify()
797814

@@ -962,11 +979,28 @@ def _open_connection_to_missing_shard(self, shard_id):
962979
)
963980
if self._keyspace:
964981
conn.set_keyspace_blocking(self._keyspace)
965-
self._connections[conn.features.shard_id] = conn
982+
with self.host.lock:
983+
stale_endpoint = not (
984+
_endpoints_match(
985+
self._session.cluster, self.host.endpoint,
986+
expected_endpoint) and
987+
_host_is_current_for_endpoint(
988+
self._session.cluster, self.host,
989+
expected_endpoint))
990+
if not stale_endpoint:
991+
self._connections[conn.features.shard_id] = conn
966992

967993
if is_shutdown:
968994
conn.close()
969995
return
996+
if stale_endpoint:
997+
log.debug("Ignoring stale shard connection replacement for host %s; endpoint changed from %s",
998+
self.host, expected_endpoint)
999+
conn.close()
1000+
self._remove_stale_pool(expected_endpoint)
1001+
with self._stream_available_condition:
1002+
self._stream_available_condition.notify()
1003+
return
9701004

9711005
if old_conn is not None:
9721006
remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids)

tests/unit/test_host_connection_pool.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from cassandra.shard_info import _ShardingInfo
2121

2222
import unittest
23-
from threading import Thread, Event, Lock
23+
from threading import Thread, Event, Lock, Condition
2424
from unittest.mock import Mock, NonCallableMagicMock, MagicMock
2525

2626
from cassandra.cluster import Cluster, Session, ShardAwareOptions
@@ -454,3 +454,89 @@ def test_replace_retries_when_replacement_keyspace_set_fails(self):
454454
submitted_fn, submitted_connection = session.submit.call_args.args
455455
assert submitted_fn == pool._replace
456456
assert submitted_connection is initial_connection
457+
458+
def test_replace_discards_replacement_when_endpoint_changes_during_keyspace_set(self):
459+
old_endpoint = DefaultEndPoint('127.0.0.1')
460+
new_endpoint = DefaultEndPoint('127.0.0.2')
461+
host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4())
462+
session = NonCallableMagicMock(spec=Session, keyspace='ks')
463+
session.cluster = MagicMock()
464+
session.cluster.shard_aware_options = ShardAwareOptions()
465+
session.cluster._endpoints_match.side_effect = Cluster._endpoints_match
466+
session.remove_pool.return_value = None
467+
initial_connection = HashableMock(
468+
spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
469+
max_request_id=100, signaled_error=False,
470+
orphaned_threshold_reached=False,
471+
features=ProtocolFeatures(shard_id=0))
472+
replacement_connection = HashableMock(
473+
spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
474+
max_request_id=100, signaled_error=False,
475+
orphaned_threshold_reached=False,
476+
features=ProtocolFeatures(shard_id=0))
477+
replacement_connection.set_keyspace_blocking.side_effect = (
478+
lambda keyspace: setattr(host, 'endpoint', new_endpoint))
479+
session.cluster.connection_factory.side_effect = [
480+
initial_connection, replacement_connection]
481+
482+
pool = HostConnection(host, HostDistance.LOCAL, session)
483+
pool._is_replacing = True
484+
485+
pool._replace(initial_connection)
486+
487+
replacement_connection.close.assert_called_once_with()
488+
session.remove_pool.assert_called_once_with(
489+
host, expected_host=host, expected_endpoint=old_endpoint,
490+
expected_pool=pool)
491+
assert pool._connections == {}
492+
assert not pool._is_replacing
493+
494+
def test_missing_shard_discards_connection_when_endpoint_changes_during_keyspace_set(self):
495+
old_endpoint = DefaultEndPoint('127.0.0.1')
496+
new_endpoint = DefaultEndPoint('127.0.0.2')
497+
host = Host(old_endpoint, SimpleConvictionPolicy, host_id=uuid.uuid4())
498+
host.sharding_info = _ShardingInfo(
499+
shard_id=0, shards_count=1, partitioner='',
500+
sharding_algorithm='', sharding_ignore_msb=0,
501+
shard_aware_port='', shard_aware_port_ssl='')
502+
session = NonCallableMagicMock(spec=Session, keyspace='ks')
503+
session.cluster = MagicMock()
504+
session.cluster.shard_aware_options = ShardAwareOptions()
505+
session.cluster.ssl_options = None
506+
session.cluster._endpoints_match.side_effect = Cluster._endpoints_match
507+
session.remove_pool.return_value = None
508+
connection = HashableMock(
509+
spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
510+
max_request_id=100, signaled_error=False,
511+
orphaned_threshold_reached=False,
512+
features=ProtocolFeatures(shard_id=0))
513+
connection.set_keyspace_blocking.side_effect = (
514+
lambda keyspace: setattr(host, 'endpoint', new_endpoint))
515+
session.cluster.connection_factory.return_value = connection
516+
517+
pool = HostConnection.__new__(HostConnection)
518+
pool.host = host
519+
pool.endpoint = old_endpoint
520+
pool.host_distance = HostDistance.LOCAL
521+
pool.is_shutdown = False
522+
pool._session = session
523+
pool._lock = Lock()
524+
pool._stream_available_condition = Condition(Lock())
525+
pool._connections = {}
526+
pool._pending_connections = []
527+
pool._connecting = {0}
528+
pool._excess_connections = set()
529+
pool._trash = set()
530+
pool._shard_connections_futures = []
531+
pool._keyspace = 'ks'
532+
pool.advanced_shardaware_block_until = 0
533+
pool.tablets_routing_v1 = False
534+
535+
pool._open_connection_to_missing_shard(0)
536+
537+
connection.close.assert_called_once_with()
538+
session.remove_pool.assert_called_once_with(
539+
host, expected_host=host, expected_endpoint=old_endpoint,
540+
expected_pool=pool)
541+
assert pool._connections == {}
542+
assert pool._connecting == set()

0 commit comments

Comments
 (0)