Skip to content

Commit d30b6d4

Browse files
committed
client-routes: preserve partial route state
Partial CLIENT_ROUTES_CHANGE handling must not treat filtered event entries as affected route state. Limit merge invalidation to configured event pairs so unrelated connection IDs cannot drop cached proxy routes. For same-host partial updates, fetch all configured connection IDs for affected hosts. This lets the route store keep the currently preferred proxy route when it is still present instead of switching because the partial event omitted it. Also keep ClientRoutesEndPoint identity port-aware and sortable when original_port is missing. Fixes #846 Refs #813
1 parent 0842348 commit d30b6d4

3 files changed

Lines changed: 142 additions & 18 deletions

File tree

cassandra/client_routes.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def handle_client_routes_change(self, connection: 'Connection', timeout: float,
294294
return
295295

296296
routes = self._query_routes_for_change_event(connection, timeout, pairs)
297-
self._routes.merge(routes, affected_host_ids=set(host_uuids))
297+
self._routes.merge(routes, affected_host_ids={host_id for _, host_id in pairs})
298298

299299
def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float,
300300
connection_ids: Set[str]) -> List[_Route]:
@@ -322,27 +322,25 @@ def _query_all_routes_for_connections(self, connection: 'Connection', timeout: f
322322
def _query_routes_for_change_event(self, connection: 'Connection', timeout: float,
323323
route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]:
324324
"""
325-
Query specific routes affected by a CLIENT_ROUTES_CHANGE event.
325+
Query current routes for hosts affected by a CLIENT_ROUTES_CHANGE event.
326326
327-
Takes a list of (connection_id, host_id) pairs that represent the exact
328-
routes affected by an operation. This provides precise updates without
329-
fetching unrelated routes.
330-
331-
If the pairs list is empty or None, falls back to a complete refresh
332-
of all routes for safety.
327+
The in-memory route store keeps a single preferred route per host. When
328+
any configured connection_id changes for a host, fetch all configured
329+
connection_ids for that host so the existing preferred route can be
330+
retained if it is still present.
333331
334332
:param connection: Connection to execute query on
335333
:param timeout: Query timeout in seconds
336-
:param route_pairs: List of (connection_id, host_id) tuples
334+
:param route_pairs: List of affected (connection_id, host_id) tuples
337335
:return: List of _Route
338336
"""
339337
unique_pairs = list(dict.fromkeys(route_pairs))
340338

341-
conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs))
339+
conn_ids = sorted(self._connection_ids)
342340
host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs))
343341

344-
log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE "
345-
"(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5])
342+
log.debug("[client routes] Querying routes from CLIENT_ROUTES_CHANGE "
343+
"for host_ids (first 5 of %d): %s", len(host_ids), host_ids[:5])
346344

347345
conn_ph = ', '.join('?' for _ in conn_ids)
348346
host_ph = ', '.join('?' for _ in host_ids)

cassandra/connection.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,21 +468,25 @@ def resolve(self) -> Tuple[str, int]:
468468
def __eq__(self, other):
469469
return (isinstance(other, ClientRoutesEndPoint) and
470470
self._host_id == other._host_id and
471-
self._original_address == other._original_address)
471+
self._original_address == other._original_address and
472+
self._original_port == other._original_port)
472473

473474
def __hash__(self):
474-
return hash((self._host_id, self._original_address))
475+
return hash((self._host_id, self._original_address, self._original_port))
476+
477+
def _comparison_key(self):
478+
return (self._host_id, self._original_address,
479+
self._original_port is None, self._original_port)
475480

476481
def __lt__(self, other):
477-
return ((self._host_id, self._original_address) <
478-
(other._host_id, other._original_address))
482+
return self._comparison_key() < other._comparison_key()
479483

480484
def __str__(self):
481485
return str("%s (host_id=%s)" % (self._original_address, self._host_id))
482486

483487
def __repr__(self):
484-
return "<%s: host_id=%s, original_addr=%s>" % (
485-
self.__class__.__name__, self._host_id, self._original_address)
488+
return "<%s: host_id=%s, original_addr=%s, original_port=%s>" % (
489+
self.__class__.__name__, self._host_id, self._original_address, self._original_port)
486490

487491

488492
class _Frame(object):

tests/unit/test_client_routes.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,92 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query):
233233
self.assertIsNotNone(handler._routes.get_by_host_id(existing_host))
234234
self.assertIsNotNone(handler._routes.get_by_host_id(new_host))
235235

236+
@patch.object(_ClientRoutesHandler, '_query_routes_for_change_event')
237+
def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_query):
238+
"""Routes for unrelated connection_ids in mixed events should not be removed."""
239+
handler = _ClientRoutesHandler(self.config)
240+
mock_conn = Mock()
241+
242+
conn_id = str(self.conn_id)
243+
changed_host = uuid.uuid4()
244+
unrelated_host = uuid.uuid4()
245+
246+
handler._routes.update([
247+
_Route(connection_id=conn_id, host_id=changed_host, address="old.com", port=9042),
248+
_Route(connection_id=conn_id, host_id=unrelated_host, address="keep.com", port=9042),
249+
])
250+
251+
mock_query.return_value = [
252+
_Route(connection_id=conn_id, host_id=changed_host, address="new.com", port=9042),
253+
]
254+
255+
handler.handle_client_routes_change(
256+
mock_conn, 5.0,
257+
ClientRoutesChangeType.UPDATE_NODES,
258+
connection_ids=[conn_id, "unrelated-conn-id"],
259+
host_ids=[str(changed_host), str(unrelated_host)],
260+
)
261+
262+
self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com")
263+
self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com")
264+
265+
def test_handle_change_preserves_preferred_route_for_same_host(self):
266+
conn_a = str(uuid.uuid4())
267+
conn_b = str(uuid.uuid4())
268+
host_id = uuid.uuid4()
269+
config = ClientRoutesConfig([
270+
ClientRouteProxy(conn_a),
271+
ClientRouteProxy(conn_b),
272+
])
273+
handler = _ClientRoutesHandler(config)
274+
handler._routes.update([
275+
_Route(connection_id=conn_b, host_id=host_id,
276+
address="current.example.com", port=9042),
277+
])
278+
279+
table_routes = [
280+
_Route(connection_id=conn_a, host_id=host_id,
281+
address="changed.example.com", port=9042),
282+
_Route(connection_id=conn_b, host_id=host_id,
283+
address="current.example.com", port=9042),
284+
]
285+
286+
def wait_for_response(query_msg, timeout):
287+
conn_placeholders = query_msg.query.split(
288+
"connection_id IN (", 1)[1].split(")", 1)[0].count("?")
289+
conn_ids = {
290+
param.decode("utf-8")
291+
for param in query_msg.query_params[:conn_placeholders]
292+
}
293+
host_ids = {
294+
uuid.UUID(bytes=param)
295+
for param in query_msg.query_params[conn_placeholders:]
296+
}
297+
rows = [
298+
(route.connection_id, route.host_id, route.address,
299+
route.port, route.port)
300+
for route in table_routes
301+
if route.connection_id in conn_ids and route.host_id in host_ids
302+
]
303+
return Mock(
304+
column_names=["connection_id", "host_id", "address", "port", "tls_port"],
305+
parsed_rows=rows,
306+
)
307+
308+
mock_conn = Mock()
309+
mock_conn.wait_for_response.side_effect = wait_for_response
310+
311+
handler.handle_client_routes_change(
312+
mock_conn, 5.0,
313+
ClientRoutesChangeType.UPDATE_NODES,
314+
connection_ids=[conn_a],
315+
host_ids=[str(host_id)],
316+
)
317+
318+
route = handler._routes.get_by_host_id(host_id)
319+
self.assertEqual(route.connection_id, conn_b)
320+
self.assertEqual(route.address, "current.example.com")
321+
236322
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
237323
def test_handle_change_updates_when_no_host_ids(self, mock_query):
238324
"""When no host_ids are provided, routes should be fully replaced."""
@@ -388,6 +474,42 @@ def test_resolve_host_missing_port_raises(self):
388474
with self.assertRaises(ValueError):
389475
self.handler.resolve_host(host_id)
390476

477+
def test_endpoint_identity_includes_original_port(self):
478+
host_id = uuid.uuid4()
479+
first = ClientRoutesEndPoint(
480+
host_id=host_id,
481+
handler=self.handler,
482+
original_address="10.0.0.1",
483+
original_port=9042,
484+
)
485+
second = ClientRoutesEndPoint(
486+
host_id=host_id,
487+
handler=self.handler,
488+
original_address="10.0.0.1",
489+
original_port=9142,
490+
)
491+
492+
self.assertNotEqual(first, second)
493+
self.assertEqual(len({first, second}), 2)
494+
495+
def test_endpoint_ordering_handles_missing_original_port(self):
496+
host_id = uuid.uuid4()
497+
without_port = ClientRoutesEndPoint(
498+
host_id=host_id,
499+
handler=self.handler,
500+
original_address="10.0.0.1",
501+
original_port=None,
502+
)
503+
with_port = ClientRoutesEndPoint(
504+
host_id=host_id,
505+
handler=self.handler,
506+
original_address="10.0.0.1",
507+
original_port=9042,
508+
)
509+
510+
self.assertCountEqual(
511+
sorted([without_port, with_port]), [without_port, with_port])
512+
391513

392514
class TestClientRoutesEndPointFactory(unittest.TestCase):
393515

0 commit comments

Comments
 (0)