Skip to content

Commit ca5b8c2

Browse files
nikagrasylwiaszunejko
authored andcommitted
pool: fix inverted cooldown check in _get_shard_aware_endpoint
The `block_until < time.time()` condition was true only *after* the NAT-detection cooldown had already expired, so the shard-aware port was never suppressed during the 10-minute window and was permanently disabled once that window closed. Fix: flip to `>` so the guard fires while the deadline is in the future. Add unit test covering the active-block, expired-block, and hard-disable paths to prevent regression.
1 parent 284bd90 commit ca5b8c2

2 files changed

Lines changed: 90 additions & 50 deletions

File tree

cassandra/pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def disable_advanced_shard_aware(self, secs):
677677
self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until)
678678

679679
def _get_shard_aware_endpoint(self):
680-
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time()) or \
680+
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \
681681
self._session.cluster.shard_aware_options.disable_shardaware_port:
682682
return None
683683

tests/unit/test_shard_aware.py

Lines changed: 89 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import unittest
1616

1717
import logging
18+
import time
1819
from unittest.mock import MagicMock
1920
from concurrent.futures import ThreadPoolExecutor
2021

@@ -27,6 +28,45 @@
2728
LOGGER = logging.getLogger(__name__)
2829

2930

31+
class MockSession(MagicMock):
32+
is_shutdown = False
33+
keyspace = "ks1"
34+
35+
def __init__(self, is_ssl=False, *args, **kwargs):
36+
super(MockSession, self).__init__(*args, **kwargs)
37+
self.cluster = MagicMock()
38+
if is_ssl:
39+
self.cluster.ssl_options = {'some_ssl_options': True}
40+
else:
41+
self.cluster.ssl_options = None
42+
self.cluster.shard_aware_options = ShardAwareOptions()
43+
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
44+
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
45+
self.cluster.connection_factory = self.mock_connection_factory
46+
self.connection_counter = 0
47+
self.futures = []
48+
49+
def submit(self, fn, *args, **kwargs):
50+
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
51+
if not self.is_shutdown:
52+
f = self.cluster.executor.submit(fn, *args, **kwargs)
53+
self.futures += [f]
54+
return f
55+
56+
def mock_connection_factory(self, *args, **kwargs):
57+
connection = MagicMock()
58+
connection.is_shutdown = False
59+
connection.is_defunct = False
60+
connection.is_closed = False
61+
connection.orphaned_threshold_reached = False
62+
connection.endpoint = args[0]
63+
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
64+
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
65+
self.connection_counter += 1
66+
67+
return connection
68+
69+
3070
class TestShardAware(unittest.TestCase):
3171
def test_parsing_and_calculating_shard_id(self):
3272
"""
@@ -55,58 +95,58 @@ def test_advanced_shard_aware_port(self):
5595
Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
5696
the next connections would be open using this port
5797
"""
58-
class MockSession(MagicMock):
59-
is_shutdown = False
60-
keyspace = "ks1"
61-
62-
def __init__(self, is_ssl=False, *args, **kwargs):
63-
super(MockSession, self).__init__(*args, **kwargs)
64-
self.cluster = MagicMock()
65-
if is_ssl:
66-
self.cluster.ssl_options = {'some_ssl_options': True}
67-
else:
68-
self.cluster.ssl_options = None
69-
self.cluster.shard_aware_options = ShardAwareOptions()
70-
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
71-
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
72-
self.cluster.connection_factory = self.mock_connection_factory
73-
self.connection_counter = 0
74-
self.futures = []
75-
76-
def submit(self, fn, *args, **kwargs):
77-
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
78-
if not self.is_shutdown:
79-
f = self.cluster.executor.submit(fn, *args, **kwargs)
80-
self.futures += [f]
81-
return f
82-
83-
def mock_connection_factory(self, *args, **kwargs):
84-
connection = MagicMock()
85-
connection.is_shutdown = False
86-
connection.is_defunct = False
87-
connection.is_closed = False
88-
connection.orphaned_threshold_reached = False
89-
connection.endpoint = args[0]
90-
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
91-
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
92-
self.connection_counter += 1
93-
94-
return connection
95-
9698
host = MagicMock()
9799
host.endpoint = DefaultEndPoint("1.2.3.4")
98100

99101
for port, is_ssl in [(19042, False), (19045, True)]:
100102
session = MockSession(is_ssl=is_ssl)
101103
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
102-
for f in session.futures:
103-
f.result()
104-
assert len(pool._connections) == 4
105-
for shard_id, connection in pool._connections.items():
106-
assert connection.features.shard_id == shard_id
107-
if shard_id == 0:
108-
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
109-
else:
110-
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
111-
112-
session.cluster.executor.shutdown(wait=True)
104+
try:
105+
for f in session.futures:
106+
f.result()
107+
assert len(pool._connections) == 4
108+
for shard_id, connection in pool._connections.items():
109+
assert connection.features.shard_id == shard_id
110+
if shard_id == 0:
111+
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
112+
else:
113+
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
114+
finally:
115+
session.cluster.executor.shutdown(wait=True)
116+
117+
def test_advanced_shard_aware_cooldown(self):
118+
"""
119+
`disable_advanced_shard_aware` must suppress the shard-aware endpoint for
120+
the duration of the cool-down window, then automatically restore it once
121+
the deadline has passed. The hard-disable flag must suppress the endpoint
122+
unconditionally.
123+
"""
124+
host = MagicMock()
125+
host.endpoint = DefaultEndPoint("1.2.3.4")
126+
session = MockSession(is_ssl=False)
127+
128+
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
129+
for f in session.futures:
130+
f.result()
131+
132+
try:
133+
# Baseline: shard-aware port is returned.
134+
endpoint = pool._get_shard_aware_endpoint()
135+
assert endpoint is not None
136+
assert endpoint.port == 19042
137+
138+
# During the cool-down window `_get_shard_aware_endpoint` must return None.
139+
pool.disable_advanced_shard_aware(600)
140+
assert pool._get_shard_aware_endpoint() is None
141+
142+
# Once the deadline has passed, the shard-aware port must be used again.
143+
pool.advanced_shardaware_block_until = time.time() - 1
144+
endpoint = pool._get_shard_aware_endpoint()
145+
assert endpoint is not None
146+
assert endpoint.port == 19042
147+
148+
# The hard-disable flag must suppress the endpoint regardless of the timer.
149+
session.cluster.shard_aware_options.disable_shardaware_port = True
150+
assert pool._get_shard_aware_endpoint() is None
151+
finally:
152+
session.cluster.executor.shutdown(wait=True)

0 commit comments

Comments
 (0)