|
15 | 15 | import unittest |
16 | 16 |
|
17 | 17 | import logging |
| 18 | +import time |
18 | 19 | from unittest.mock import MagicMock |
19 | 20 | from concurrent.futures import ThreadPoolExecutor |
20 | 21 |
|
|
27 | 28 | LOGGER = logging.getLogger(__name__) |
28 | 29 |
|
29 | 30 |
|
| 31 | +class MockSession(MagicMock): |
| 32 | + is_shutdown = False |
| 33 | + keyspace = "ks1" |
| 34 | + |
| 35 | + def __init__(self, is_ssl=False, *args, **kwargs): |
| 36 | + super(MockSession, self).__init__(*args, **kwargs) |
| 37 | + self.cluster = MagicMock() |
| 38 | + if is_ssl: |
| 39 | + self.cluster.ssl_options = {'some_ssl_options': True} |
| 40 | + else: |
| 41 | + self.cluster.ssl_options = None |
| 42 | + self.cluster.shard_aware_options = ShardAwareOptions() |
| 43 | + self.cluster.executor = ThreadPoolExecutor(max_workers=2) |
| 44 | + self.cluster.signal_connection_failure = lambda *args, **kwargs: False |
| 45 | + self.cluster.connection_factory = self.mock_connection_factory |
| 46 | + self.connection_counter = 0 |
| 47 | + self.futures = [] |
| 48 | + |
| 49 | + def submit(self, fn, *args, **kwargs): |
| 50 | + logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs) |
| 51 | + if not self.is_shutdown: |
| 52 | + f = self.cluster.executor.submit(fn, *args, **kwargs) |
| 53 | + self.futures += [f] |
| 54 | + return f |
| 55 | + |
| 56 | + def mock_connection_factory(self, *args, **kwargs): |
| 57 | + connection = MagicMock() |
| 58 | + connection.is_shutdown = False |
| 59 | + connection.is_defunct = False |
| 60 | + connection.is_closed = False |
| 61 | + connection.orphaned_threshold_reached = False |
| 62 | + connection.endpoint = args[0] |
| 63 | + sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045) |
| 64 | + connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info) |
| 65 | + self.connection_counter += 1 |
| 66 | + |
| 67 | + return connection |
| 68 | + |
| 69 | + |
30 | 70 | class TestShardAware(unittest.TestCase): |
31 | 71 | def test_parsing_and_calculating_shard_id(self): |
32 | 72 | """ |
@@ -55,58 +95,58 @@ def test_advanced_shard_aware_port(self): |
55 | 95 | Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class) |
56 | 96 | the next connections would be open using this port |
57 | 97 | """ |
58 | | - class MockSession(MagicMock): |
59 | | - is_shutdown = False |
60 | | - keyspace = "ks1" |
61 | | - |
62 | | - def __init__(self, is_ssl=False, *args, **kwargs): |
63 | | - super(MockSession, self).__init__(*args, **kwargs) |
64 | | - self.cluster = MagicMock() |
65 | | - if is_ssl: |
66 | | - self.cluster.ssl_options = {'some_ssl_options': True} |
67 | | - else: |
68 | | - self.cluster.ssl_options = None |
69 | | - self.cluster.shard_aware_options = ShardAwareOptions() |
70 | | - self.cluster.executor = ThreadPoolExecutor(max_workers=2) |
71 | | - self.cluster.signal_connection_failure = lambda *args, **kwargs: False |
72 | | - self.cluster.connection_factory = self.mock_connection_factory |
73 | | - self.connection_counter = 0 |
74 | | - self.futures = [] |
75 | | - |
76 | | - def submit(self, fn, *args, **kwargs): |
77 | | - logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs) |
78 | | - if not self.is_shutdown: |
79 | | - f = self.cluster.executor.submit(fn, *args, **kwargs) |
80 | | - self.futures += [f] |
81 | | - return f |
82 | | - |
83 | | - def mock_connection_factory(self, *args, **kwargs): |
84 | | - connection = MagicMock() |
85 | | - connection.is_shutdown = False |
86 | | - connection.is_defunct = False |
87 | | - connection.is_closed = False |
88 | | - connection.orphaned_threshold_reached = False |
89 | | - connection.endpoint = args[0] |
90 | | - sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045) |
91 | | - connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info) |
92 | | - self.connection_counter += 1 |
93 | | - |
94 | | - return connection |
95 | | - |
96 | 98 | host = MagicMock() |
97 | 99 | host.endpoint = DefaultEndPoint("1.2.3.4") |
98 | 100 |
|
99 | 101 | for port, is_ssl in [(19042, False), (19045, True)]: |
100 | 102 | session = MockSession(is_ssl=is_ssl) |
101 | 103 | pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) |
102 | | - for f in session.futures: |
103 | | - f.result() |
104 | | - assert len(pool._connections) == 4 |
105 | | - for shard_id, connection in pool._connections.items(): |
106 | | - assert connection.features.shard_id == shard_id |
107 | | - if shard_id == 0: |
108 | | - assert connection.endpoint == DefaultEndPoint("1.2.3.4") |
109 | | - else: |
110 | | - assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) |
111 | | - |
112 | | - session.cluster.executor.shutdown(wait=True) |
| 104 | + try: |
| 105 | + for f in session.futures: |
| 106 | + f.result() |
| 107 | + assert len(pool._connections) == 4 |
| 108 | + for shard_id, connection in pool._connections.items(): |
| 109 | + assert connection.features.shard_id == shard_id |
| 110 | + if shard_id == 0: |
| 111 | + assert connection.endpoint == DefaultEndPoint("1.2.3.4") |
| 112 | + else: |
| 113 | + assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) |
| 114 | + finally: |
| 115 | + session.cluster.executor.shutdown(wait=True) |
| 116 | + |
| 117 | + def test_advanced_shard_aware_cooldown(self): |
| 118 | + """ |
| 119 | + `disable_advanced_shard_aware` must suppress the shard-aware endpoint for |
| 120 | + the duration of the cool-down window, then automatically restore it once |
| 121 | + the deadline has passed. The hard-disable flag must suppress the endpoint |
| 122 | + unconditionally. |
| 123 | + """ |
| 124 | + host = MagicMock() |
| 125 | + host.endpoint = DefaultEndPoint("1.2.3.4") |
| 126 | + session = MockSession(is_ssl=False) |
| 127 | + |
| 128 | + pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) |
| 129 | + for f in session.futures: |
| 130 | + f.result() |
| 131 | + |
| 132 | + try: |
| 133 | + # Baseline: shard-aware port is returned. |
| 134 | + endpoint = pool._get_shard_aware_endpoint() |
| 135 | + assert endpoint is not None |
| 136 | + assert endpoint.port == 19042 |
| 137 | + |
| 138 | + # During the cool-down window `_get_shard_aware_endpoint` must return None. |
| 139 | + pool.disable_advanced_shard_aware(600) |
| 140 | + assert pool._get_shard_aware_endpoint() is None |
| 141 | + |
| 142 | + # Once the deadline has passed, the shard-aware port must be used again. |
| 143 | + pool.advanced_shardaware_block_until = time.time() - 1 |
| 144 | + endpoint = pool._get_shard_aware_endpoint() |
| 145 | + assert endpoint is not None |
| 146 | + assert endpoint.port == 19042 |
| 147 | + |
| 148 | + # The hard-disable flag must suppress the endpoint regardless of the timer. |
| 149 | + session.cluster.shard_aware_options.disable_shardaware_port = True |
| 150 | + assert pool._get_shard_aware_endpoint() is None |
| 151 | + finally: |
| 152 | + session.cluster.executor.shutdown(wait=True) |
0 commit comments