From 9aa5ddceaae9fad29ddd778016d4913f77e612ac Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 12 Mar 2026 19:04:16 +0100 Subject: [PATCH 01/27] Add optional query_params parameter to QueryMessage --- cassandra/protocol.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..4628c7ee0e 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -611,9 +611,10 @@ class QueryMessage(_QueryMessage): name = 'QUERY' def __init__(self, query, consistency_level, serial_consistency_level=None, - fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None): + fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None, + query_params=None): self.query = query - super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, + super(QueryMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, False, continuous_paging_options, keyspace) def send_body(self, f, protocol_version): From 8bba6ebd361e4df959cc4f02dcbcb67201c6a526 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Fri, 13 Mar 2026 08:10:59 +0100 Subject: [PATCH 02/27] Introduce skip_scylla_version_lt for integration tests --- tests/integration/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index dfac2dc1d9..a53e7aafa6 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -715,6 +715,27 @@ def xfail_scylla_version_lt(reason, oss_scylla_version, ent_scylla_version, *arg return pytest.mark.xfail(current_version < Version(oss_scylla_version), reason=reason, *args, **kwargs) + +def skip_scylla_version_lt(reason, scylla_version): + """ + Skip tests on scylla versions older than the specified thresholds. 
+ :param reason: message explaining why the test is skipped + :param scylla_version: str, version from which test supposed to work + """ + if not (reason.startswith("scylladb/scylladb#") or reason.startswith("scylladb/scylla-enterprise#")): + raise ValueError('reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo') + + if not isinstance(scylla_version, str): + raise ValueError('scylla_version should be a str') + + if SCYLLA_VERSION is None: + return pytest.mark.skipif(False, reason="It is just a NoOP Decor, should not skip anything") + + current_version = Version(get_scylla_version(SCYLLA_VERSION)) + + return pytest.mark.skipif(current_version < Version(scylla_version), reason=reason) + + class UpDownWaiter(object): def __init__(self, host): From bc864c1b4e7e030c22aef539b622865e5e0fea95 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 17 Mar 2026 10:31:29 +0100 Subject: [PATCH 03/27] Add client routes data types and route store Introduce the data layer for Private Link client routes support: - ClientRoutesChangeType enum for CLIENT_ROUTES_CHANGE event types - ClientRouteProxy dataclass and ClientRoutesConfig for user-facing configuration - _Route frozen dataclass for immutable route records - _RouteStore for thread-safe route storage with atomic update/merge and preferred route selection that avoids unnecessary connection_id migration when multiple routes exist for the same host --- cassandra/client_routes.py | 192 +++++++++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 cassandra/client_routes.py diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py new file mode 100644 index 0000000000..f26aeef152 --- /dev/null +++ b/cassandra/client_routes.py @@ -0,0 +1,192 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Client Routes support for Private Link and similar network configurations. + +This module implements support for dynamic address translation via the +system.client_routes table and CLIENT_ROUTES_CHANGE events. +""" + +from __future__ import absolute_import + +from dataclasses import dataclass +import enum +import logging +import threading +import uuid +from typing import Dict, List, Optional, Set + +log = logging.getLogger(__name__) + + +class ClientRoutesChangeType(enum.Enum): + """ + Types of CLIENT_ROUTES_CHANGE events. + + Currently the protocol defines only UPDATE_NODES. + New variants will be added here if the protocol is extended. + """ + UPDATE_NODES = "UPDATE_NODES" + + +@dataclass +class ClientRouteProxy: + """ + :param connection_id: String identifying the connection (required) + :param connection_addr_override:: Optional string address for initial connection + """ + + connection_id: str + connection_addr_override: Optional[str] = None + + def __post_init__(self): + if self.connection_id is None: + raise ValueError("connection_id is required") + +class ClientRoutesConfig: + """ + Configuration for client routes (Private Link support). 
+ + :param proxies: List of :class:`ClientRouteProxy` objects + (REQUIRED, at least one) + :param advanced_shard_awareness: Whether to enable advanced shard awareness + (default: ``False``) + """ + + proxies: List[ClientRouteProxy] + advanced_shard_awareness: bool + + def __init__(self, proxies: List[ClientRouteProxy], advanced_shard_awareness: bool = False): + """ + :param proxies: List of ClientRouteProxy objects + :param advanced_shard_awareness: Enable advanced shard awareness (default False) + """ + if not proxies: + raise ValueError("At least one proxy must be specified") + + if not isinstance(proxies, (list, tuple)): + raise TypeError("proxies must be a list or tuple") + + for proxy in proxies: + if not isinstance(proxy, ClientRouteProxy): + raise TypeError("All proxies must be ClientRouteProxy instances") + + self.proxies = proxies + self.advanced_shard_awareness = advanced_shard_awareness + + def __repr__(self) -> str: + return (f"ClientRoutesConfig(proxies={self.proxies}, " + f"advanced_shard_awareness={self.advanced_shard_awareness})") + + +@dataclass(frozen=True) +class _Route: + connection_id: str + host_id: uuid.UUID + address: str # ipv4, ipv6 or DNS hostname from system.client_routes + port: int + +class _RouteStore: + """ + Thread-safe storage for routes. Reads are safe under CPython's GIL; + writes are serialized with a lock. + + This uses atomic pointer swaps for updates, allowing lock-free reads + while serializing writes. + """ + + _routes_by_host_id: Dict[uuid.UUID, _Route] + _lock: threading.Lock + + def __init__(self) -> None: + self._routes_by_host_id = {} + self._lock = threading.Lock() + + def get_by_host_id(self, host_id: uuid.UUID) -> Optional[_Route]: + """ + Get route for a host ID (lock-free read). + + :param host_id: UUID of the host + :return: _Route or None + """ + return self._routes_by_host_id.get(host_id) + + def get_all(self) -> List[_Route]: + """ + Get all routes as a list (lock-free read). 
+ + :return: List of _Route + """ + return list(self._routes_by_host_id.values()) + + def _select_preferred_routes(self, new_routes: List[_Route]) -> List[_Route]: + """ + When multiple routes exist for the same host_id (different connection_ids), + prefer the connection_id already in use. Only migrate to a different + connection_id when the previously used one is no longer available. + + Must be called under self._lock. + """ + by_host: Dict[uuid.UUID, List[_Route]] = {} + for route in new_routes: + by_host.setdefault(route.host_id, []).append(route) + + selected = [] + for host_id, candidates in by_host.items(): + if len(candidates) == 1: + selected.append(candidates[0]) + continue + + existing = self._routes_by_host_id.get(host_id) + if existing: + preferred = [c for c in candidates if c.connection_id == existing.connection_id] + if preferred: + selected.append(preferred[0]) + continue + + selected.append(candidates[0]) + + return selected + + def update(self, routes: List[_Route]) -> None: + """ + Replace all routes atomically. + + :param routes: List of _Route objects + """ + with self._lock: + preferred = self._select_preferred_routes(routes) + self._routes_by_host_id = {route.host_id: route for route in preferred} + + def merge(self, new_routes: List[_Route], affected_host_ids: Set[uuid.UUID]) -> None: + """ + Merge new routes with existing ones atomically. + + Routes for affected_host_ids are replaced entirely: existing routes + for those hosts are dropped and replaced with whatever is in new_routes. + This handles deletions from system.client_routes (affected host present + but no new route for it). + + :param new_routes: List of _Route objects to merge + :param affected_host_ids: Set of host IDs affected by the change. 
+ """ + with self._lock: + preferred = self._select_preferred_routes(new_routes) + new_by_host = {r.host_id: r for r in preferred} + + updated = {hid: r for hid, r in self._routes_by_host_id.items() + if hid not in affected_host_ids} + updated.update(new_by_host) + self._routes_by_host_id = updated From ac91295c85a8d176254d33bf79a62760ebf9056d Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 17 Mar 2026 10:32:17 +0100 Subject: [PATCH 04/27] Add client routes handler for Private Link support Add _ClientRoutesHandler which manages the full lifecycle of dynamic address translation via system.client_routes: - initialize(): loads all routes at startup and on control connection reconnect - handle_client_routes_change(): processes CLIENT_ROUTES_CHANGE events with targeted merge or full refresh depending on event data - _query_all_routes_for_connections(): complete refresh query using connection_id IN (...) - _query_routes_for_change_event(): targeted query grouping by connection_id with host_id IN (...) 
per group - _execute_routes_query(): common query execution and result parsing with proxy address override support - resolve_host(): host_id to (address, port) resolution with DNS lookup --- cassandra/client_routes.py | 261 ++++++++++++++++++++++++++++++++++++- 1 file changed, 260 insertions(+), 1 deletion(-) diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py index f26aeef152..80b2477a6d 100644 --- a/cassandra/client_routes.py +++ b/cassandra/client_routes.py @@ -24,9 +24,17 @@ from dataclasses import dataclass import enum import logging +import socket import threading import uuid -from typing import Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple + +from cassandra import ConsistencyLevel +from cassandra.protocol import QueryMessage +from cassandra.query import dict_factory + +if TYPE_CHECKING: + from cassandra.connection import Connection log = logging.getLogger(__name__) @@ -190,3 +198,254 @@ def merge(self, new_routes: List[_Route], affected_host_ids: Set[uuid.UUID]) -> if hid not in affected_host_ids} updated.update(new_by_host) self._routes_by_host_id = updated + + +class _ClientRoutesHandler: + """ + Handles dynamic address translation for Private Link via system.client_routes. + + Lifecycle: + 1. Construction: Create with configuration + 2. Initialization: Read system.client_routes after control connection established + 3. Steady state: Listen for CLIENT_ROUTES_CHANGE events and update routes + 4. 
Translation: Translate addresses using Host ID lookup + """ + + config: 'ClientRoutesConfig' + ssl_enabled: bool + _routes: _RouteStore + _connection_ids: Set[str] + _proxy_addresses_override: Dict[str, str] + + def __init__(self, config: 'ClientRoutesConfig', ssl_enabled: bool = False): + """ + :param config: ClientRoutesConfig instance + :param ssl_enabled: Whether TLS is enabled (determines port selection) + """ + if not isinstance(config, ClientRoutesConfig): + raise TypeError("config must be a ClientRoutesConfig instance") + + self.config = config + self.ssl_enabled = ssl_enabled + self._routes = _RouteStore() + self._connection_ids = {dep.connection_id for dep in config.proxies} + # Precalculate proxy address mappings for efficient lookup + self._proxy_addresses_override = { + proxy.connection_id: proxy.connection_addr_override + for proxy in config.proxies + if proxy.connection_addr_override + } + + def initialize(self, connection: 'Connection', timeout: float) -> None: + """ + Load all routes from system.client_routes. + + Called once at startup and again whenever the control connection + is re-established. Reads all configured connection IDs and + replaces the in-memory route store atomically. + + Raises on failure so the caller can decide how to react (e.g. + abort startup or schedule a reconnect). + + :param connection: The Connection instance to execute queries on + :param timeout: Query timeout in seconds + """ + log.info("[client routes] Loading routes for %d proxies", len(self.config.proxies)) + + routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids) + self._routes.update(routes) + + def handle_client_routes_change(self, connection: 'Connection', timeout: float, + change_type: 'ClientRoutesChangeType', + connection_ids: Sequence[str], host_ids: Sequence[str]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event. + + Currently the protocol defines only :attr:`ClientRoutesChangeType.UPDATE_NODES`. 
+ New variants will be added to the enum if the protocol is extended. + + :param connection: The Connection instance to execute queries on + :param timeout: Query timeout in seconds + :param change_type: A :class:`ClientRoutesChangeType` value + :param connection_ids: Affected connection ID strings; empty means all. + :param host_ids: Affected host ID strings; empty means all. + """ + + full_refresh = False + if not connection_ids or not host_ids: + log.warning( + "[client routes] CLIENT_ROUTES_CHANGE has no connection_ids or host_ids, doing full refresh") + full_refresh = True + elif len(connection_ids) != len(host_ids): + log.warning("[client routes] CLIENT_ROUTES_CHANGE has mismatched lengths (conn: %d, host: %d), doing full refresh", + len(connection_ids), len(host_ids)) + full_refresh = True + + if full_refresh: + routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids) + self._routes.update(routes) + return + + host_uuids = [uuid.UUID(hid) for hid in host_ids] + pairs = [(cid, hid) for cid, hid in zip(connection_ids, host_uuids) + if cid in self._connection_ids] + + if not pairs: + return + + routes = self._query_routes_for_change_event(connection, timeout, pairs) + self._routes.merge(routes, affected_host_ids=set(host_uuids)) + + def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, + connection_ids: Set[str]) -> List[_Route]: + """ + Query all routes for the given connection IDs (complete refresh). + + Used when control connection reconnects or as a fallback when + CLIENT_ROUTES_CHANGE event has malformed data. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param connection_ids: Set of connection ID strings + :return: List of _Route + """ + if not connection_ids: + return [] + + placeholders = ', '.join('?' 
for _ in connection_ids) + query = f"SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes WHERE connection_id IN ({placeholders})" + params = [cid.encode('utf-8') for cid in connection_ids] + + log.debug("[client routes] Querying all routes for connection_ids=%s", connection_ids) + return self._execute_routes_query(connection, timeout, query, params) + + def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, + route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: + """ + Query specific routes affected by a CLIENT_ROUTES_CHANGE event. + + Takes a list of (connection_id, host_id) pairs that represent the exact + routes affected by an operation. This provides precise updates without + fetching unrelated routes. + + If the pairs list is empty or None, falls back to a complete refresh + of all routes for safety. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param route_pairs: List of (connection_id, host_id) tuples + :return: List of _Route + """ + unique_pairs = list(dict.fromkeys(route_pairs)) + + conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) + + log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " + "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + + conn_ph = ', '.join('?' for _ in conn_ids) + host_ph = ', '.join('?' 
for _ in host_ids) + query = ( + "SELECT connection_id, host_id, address, port, tls_port " + "FROM system.client_routes " + f"WHERE connection_id IN ({conn_ph}) AND host_id IN ({host_ph})" + ) + params: List = [cid.encode('utf-8') for cid in conn_ids] + params.extend(hid.bytes for hid in host_ids) + + return self._execute_routes_query(connection, timeout, query, params) + + def _execute_routes_query(self, connection: 'Connection', timeout: float, + query: str, params: List) -> List[_Route]: + """ + Execute a routes query and parse results. + + Common helper for both complete refresh and change event queries. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param query: CQL query string + :param params: Query parameters + :return: List of _Route + """ + log.debug("[client routes] Executing query: %s with %d parameters", query, len(params)) + + query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE, + query_params=params if params else None) + result = connection.wait_for_response( + query_msg, timeout=timeout + ) + + routes = [] + broken = 0 + rows = dict_factory(result.column_names, result.parsed_rows) + for row in rows: + try: + absent = [] + port = row['tls_port'] if self.ssl_enabled else row['port'] + connection_id = row['connection_id'] + host_id = row['host_id'] + address = row['address'] + + if not port: + absent.append("tls_port" if self.ssl_enabled else "port") + if not connection_id: + absent.append("connection_id") + if not host_id: + absent.append("host_id") + if not address: + absent.append("address") + + if absent: + log.error("[client routes] read a route %s, that has no values for the following fields: %s", row, ",".join(absent)) + broken += 1 + continue + + final_address = self._proxy_addresses_override.get(connection_id, address) + + routes.append(_Route( + connection_id=connection_id, + host_id=host_id, + address=final_address, + port=port, + )) + except Exception as e: + 
log.warning("[client routes] Failed to parse route row: %s", e) + broken += 1 + + if broken and not routes: + raise RuntimeError( + "[client routes] All %d route rows failed validation; " + "refusing to return empty result that would wipe the route store" % broken + ) + + return routes + + def resolve_host(self, host_id: uuid.UUID) -> Optional[Tuple[str, int]]: + """ + Resolve a host_id to an (address, port) pair. + + Looks up the current route and selects the appropriate port. + + :param host_id: Host UUID to resolve + :return: Tuple of (address, port) or None if no route mapping exists + """ + route = self._routes.get_by_host_id(host_id) + if route is None: + return None + + if not route.port: + raise ValueError("Mapping for host %s has no port" % host_id) + + try: + result = socket.getaddrinfo(route.address, route.port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + if not result: + raise socket.gaierror("No addresses found for %s" % route.address) + resolved_ip = result[0][4][0] + return resolved_ip, route.port + except socket.gaierror as e: + log.warning('[client routes] Could not resolve hostname "%s" (host_id=%s): %s', + route.address, host_id, e) + raise From 1e0e6ca1004f372c825fb88f6dc64183e80baa5f Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 10 Mar 2026 15:45:43 +0100 Subject: [PATCH 05/27] Add ClientRoutesEndPoint and ClientRoutesEndPointFactory - ClientRoutesEndPointFactory: creates endpoints from system.peers rows by extracting host_id, deferring address translation and DNS resolution until connection time - ClientRoutesEndPoint: endpoint that resolves via _ClientRoutesHandler on each connection attempt, ensuring immediate reaction to route changes and CLIENT_ROUTES_CHANGE events --- cassandra/connection.py | 120 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..72b273ec37 100644 --- a/cassandra/connection.py +++ 
b/cassandra/connection.py @@ -25,12 +25,14 @@ from threading import Thread, Event, RLock, Condition import time import ssl +import uuid import weakref import random import itertools -from typing import Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from cassandra.application_info import ApplicationInfoBase +from cassandra.client_routes import _ClientRoutesHandler from cassandra.protocol_features import ProtocolFeatures if 'gevent.monkey' in sys.modules: @@ -230,7 +232,7 @@ class DefaultEndPointFactory(EndPointFactory): port = None """ If no port is discovered in the row, this is the default port - used for endpoint creation. + used for endpoint creation. """ def __init__(self, port=None): @@ -328,6 +330,50 @@ def create_from_sni(self, sni): return SniEndPoint(self._proxy_address, sni, self._port) +class ClientRoutesEndPointFactory(EndPointFactory): + """ + EndPointFactory for Client Routes (Private Link) support. + + Creates ClientRoutesEndPoint instances that defer both address translation + (host_id -> hostname lookup) and DNS resolution until connection time. + This ensures immediate reaction to infrastructure changes. + """ + + client_routes_handler: _ClientRoutesHandler + default_port: int + + def __init__(self, client_routes_handler: _ClientRoutesHandler, default_port: int = None) -> None: + """ + :param client_routes_handler: _ClientRoutesHandler instance to lookup routes + :param default_port: Default port if none found in row + """ + self.client_routes_handler = client_routes_handler + self.default_port = default_port + + def create(self, row: Dict[str, Any]) -> 'ClientRoutesEndPoint': + """ + Create a ClientRoutesEndPoint from a system.peers row. + + Stores only the host_id and handler reference. Both translation + (route lookup) and DNS resolution happen later in resolve(). 
+ """ + from cassandra.metadata import _NodeInfo + host_id = row.get("host_id") + + if host_id is None: + raise ValueError("No host_id to create ClientRoutesEndPoint") + + addr = _NodeInfo.get_broadcast_rpc_address(row) + port = _NodeInfo.get_broadcast_rpc_port(row) or _NodeInfo.get_broadcast_port(row) or self.default_port + + return ClientRoutesEndPoint( + host_id=host_id, + handler=self.client_routes_handler, + original_address=addr, + original_port=port, + ) + + @total_ordering class UnixSocketEndPoint(EndPoint): """ @@ -369,6 +415,76 @@ def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) +@total_ordering +class ClientRoutesEndPoint(EndPoint): + """ + Client Routes (Private Link) EndPoint implementation. + + Defers both address translation (route lookup) and DNS resolution + until resolve() is called at connection time. This ensures immediate + reaction to infrastructure changes and CLIENT_ROUTES_CHANGE events. + """ + + _host_id: uuid.UUID + _handler: _ClientRoutesHandler + _original_address: str + _original_port: int + + def __init__(self, host_id: uuid.UUID, handler: _ClientRoutesHandler, original_address: str, original_port: int = None) -> None: + """ + :param host_id: Host UUID for route lookup + :param handler: _ClientRoutesHandler instance + :param original_address: Original address from system.peers (for identification) + :param original_port: Original port if route doesn't specify one + """ + self._host_id = host_id + self._handler = handler + self._original_address = original_address + self._original_port = original_port + + @property + def address(self) -> str: + """Returns the original address (updated by resolve()).""" + return self._original_address + + @property + def port(self) -> Optional[int]: + return self._original_port + + @property + def host_id(self) -> uuid.UUID: + return self._host_id + + def resolve(self) -> Tuple[str, int]: + """ + Resolve endpoint by delegating to the handler. 
+ Falls back to original address/port if no route mapping is available. + """ + result = self._handler.resolve_host(self._host_id) + if result is None: + return self._original_address, self._original_port + return result + + def __eq__(self, other): + return (isinstance(other, ClientRoutesEndPoint) and + self._host_id == other._host_id and + self._original_address == other._original_address) + + def __hash__(self): + return hash((self._host_id, self._original_address)) + + def __lt__(self, other): + return ((self._host_id, self._original_address) < + (other._host_id, other._original_address)) + + def __str__(self): + return str("%s (host_id=%s)" % (self._original_address, self._host_id)) + + def __repr__(self): + return "<%s: host_id=%s, original_addr=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address) + + class _Frame(object): def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version From 32023021bb16ada4c8425cf06bdf685f0169aefc Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 10 Mar 2026 15:49:20 +0100 Subject: [PATCH 06/27] Integrate client routes handler into Cluster and ControlConnection Cluster: - Add client_routes_config parameter with mutual exclusivity check against endpoint_factory - Create _ClientRoutesHandler and ClientRoutesEndPointFactory when client_routes_config is provided ControlConnection: - Register CLIENT_ROUTES_CHANGE event watcher when handler is present - Forward events to handler via _handle_client_routes_change - Trigger full route re-read on control connection reconnection --- cassandra/cluster.py | 103 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 5 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 51d0b2d88b..8da9df6a55 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -29,7 +29,7 @@ from itertools import groupby, count, chain import json import logging -from typing import Optional, Union +from 
typing import Any, Dict, Optional, Union from warnings import warn from random import random import re @@ -48,7 +48,8 @@ SchemaTargetType, DriverException, ProtocolVersion, UnresolvableContactPoints, DependencyException) from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider -from cassandra.connection import (ConnectionException, ConnectionShutdown, +from cassandra.client_routes import ClientRoutesChangeType, ClientRoutesConfig, _ClientRoutesHandler +from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, SniEndPointFactory, ConnectionBusy, locally_supported_compressions) @@ -1215,7 +1216,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info:Optional[ApplicationInfoBase]=None, + client_routes_config:Optional[ClientRoutesConfig]=None ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1280,6 +1282,45 @@ def __init__(self, if column_encryption_policy is not None: self.column_encryption_policy = column_encryption_policy + if client_routes_config is not None and endpoint_factory is not None: + raise ValueError("client_routes_config and endpoint_factory are mutually exclusive") + + self._client_routes_handler = None + if client_routes_config is not None: + if not isinstance(client_routes_config, ClientRoutesConfig): + raise TypeError("client_routes_config must be a ClientRoutesConfig instance") + + # SSL hostname verification is incompatible with client routes: + # connections go through NLB proxies whose addresses won't match + # server certificates. 
+ _check_hostname_enabled = False + if ssl_context is not None and ssl_context.check_hostname: + _check_hostname_enabled = True + if ssl_options is not None and ssl_options.get('check_hostname', False): + _check_hostname_enabled = True + if _check_hostname_enabled: + raise ValueError( + "SSL hostname verification (check_hostname=True) is currently incompatible " + "with client_routes_config. When using client routes, connections " + "go through NLB proxies whose addresses won't match server " + "certificates. Disable hostname verification by setting " + "ssl_context.check_hostname = False." + ) + + ssl_enabled = ssl_context is not None or ssl_options is not None + self._client_routes_handler = _ClientRoutesHandler(client_routes_config, ssl_enabled=ssl_enabled) + + if contact_points is _NOT_SET or not self._contact_points_explicit: + seed_addrs = [dep.connection_addr_override for dep in client_routes_config.proxies + if dep.connection_addr_override] + if seed_addrs: + self.contact_points = seed_addrs + self._contact_points_explicit = True + log.info("[client routes] Using %d deployment connection addresses as contact points", + len(seed_addrs)) + + if self._client_routes_handler is not None: + endpoint_factory = ClientRoutesEndPointFactory(self._client_routes_handler, self.port) self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) self.endpoint_factory.configure(self) @@ -1437,6 +1478,10 @@ def __init__(self, self.monitor_reporting_interval = monitor_reporting_interval self.shard_aware_options = ShardAwareOptions(opts=shard_aware_options) + if (client_routes_config is not None + and not client_routes_config.advanced_shard_awareness): + self.shard_aware_options.disable_shardaware_port = True + self._listeners = set() self._listener_lock = Lock() @@ -3612,11 +3657,21 @@ def _try_connect(self, endpoint): # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: 
- connection.register_watchers({ + watchers = { "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') - }, register_timeout=self._timeout) + } + + if self._cluster._client_routes_handler is not None: + watchers["CLIENT_ROUTES_CHANGE"] = partial(_watch_callback, self_weakref, '_handle_client_routes_change') + + connection.register_watchers(watchers, register_timeout=self._timeout) + + if self._cluster._client_routes_handler is not None: + self._cluster._client_routes_handler.initialize( + connection, + self._timeout) sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS @@ -3979,6 +4034,44 @@ def _handle_status_change(self, event): # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) + def _handle_client_routes_change(self, event: Dict[str, Any]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event from the server. + + This event indicates that the system.client_routes table has been updated + and we need to refresh our route mappings. 
+ """ + if self._cluster._client_routes_handler is None: + log.warning("[control connection] Received CLIENT_ROUTES_CHANGE but no handler configured") + return + + raw_change_type = event.get("change_type") + try: + change_type = ClientRoutesChangeType(raw_change_type) + except ValueError: + log.warning("[control connection] Unknown CLIENT_ROUTES_CHANGE type: %s", raw_change_type) + return + + connection_ids = tuple(event.get("connection_ids", [])) + host_ids = tuple(event.get("host_ids", [])) + + self._cluster.scheduler.schedule_unique( + 0, + self._handle_client_routes_refresh, + self._connection, self._timeout, change_type, connection_ids, host_ids + ) + + def _handle_client_routes_refresh(self, connection, timeout, + change_type, connection_ids, host_ids): + try: + self._cluster._client_routes_handler.handle_client_routes_change( + connection, timeout, change_type, connection_ids, host_ids) + except ReferenceError: + pass # our weak reference to the Cluster is no good + except Exception: + log.debug("[control connection] Error handling CLIENT_ROUTES_CHANGE", exc_info=True) + self._signal_error() + def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return From b205f838191c11ea630ff7829f0353e4e8f88be6 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 10 Mar 2026 15:57:52 +0100 Subject: [PATCH 07/27] tests: add unit tests for client routes Cover ClientRouteProxy/ClientRoutesConfig validation, _RouteStore get/merge operations, _ClientRoutesHandler initialization, ClientRoutesEndPoint resolution with and without route mappings, and SSL check_hostname rejection with client_routes_config. 
--- tests/unit/test_client_routes.py | 482 +++++++++++++++++++++++++++++++ 1 file changed, 482 insertions(+) create mode 100644 tests/unit/test_client_routes.py diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py new file mode 100644 index 0000000000..0aa82fc76a --- /dev/null +++ b/tests/unit/test_client_routes.py @@ -0,0 +1,482 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import ssl +import unittest +import uuid +from unittest.mock import Mock, patch + +from cassandra.client_routes import ( + ClientRouteProxy, + ClientRoutesChangeType, + ClientRoutesConfig, + _RouteStore, + _Route, + _ClientRoutesHandler +) +from cassandra.connection import ClientRoutesEndPoint, ClientRoutesEndPointFactory +from cassandra.cluster import Cluster + + +class TestClientRouteProxy(unittest.TestCase): + + def test_endpoint_none_connection_id(self): + with self.assertRaises(ValueError): + ClientRouteProxy(None) + + +class TestClientRoutesConfig(unittest.TestCase): + + def test_config_with_proxies(self): + ep1 = ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1") + ep2 = ClientRouteProxy(str(uuid.uuid4()), "10.0.0.2") + config = ClientRoutesConfig([ep1, ep2]) + self.assertEqual(len(config.proxies), 2) + + def test_config_empty_proxies(self): + with self.assertRaises(ValueError): + ClientRoutesConfig([]) + + def test_config_invalid_proxy_type(self): + with self.assertRaises(TypeError): + 
ClientRoutesConfig(["not-a-proxy"]) + + + +class TestRouteStore(unittest.TestCase): + + def test_get_by_host_id(self): + routes = _RouteStore() + host_id = uuid.uuid4() + route = _Route( + connection_id=str(uuid.uuid4()), + host_id=host_id, + address="example.com", + port=9042, + ) + + routes.update([route]) + + retrieved = routes.get_by_host_id(host_id) + self.assertEqual(retrieved.host_id, host_id) + self.assertEqual(retrieved.address, "example.com") + + def test_merge_routes(self): + routes = _RouteStore() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + + route1 = _Route( + connection_id=str(uuid.uuid4()), host_id=host_id1, + address="host1.com", port=9042, + ) + + route2 = _Route( + connection_id=str(uuid.uuid4()), host_id=host_id2, + address="host2.com", port=9042, + ) + + routes.update([route1]) + routes.merge([route2], affected_host_ids={host_id2}) + + self.assertIsNotNone(routes.get_by_host_id(host_id1)) + self.assertIsNotNone(routes.get_by_host_id(host_id2)) + + def test_merge_deletes_affected_host_with_no_new_route(self): + """When an affected host_id has no corresponding new route, it should be removed.""" + store = _RouteStore() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + conn_id = str(uuid.uuid4()) + + store.update([ + _Route(connection_id=conn_id, host_id=host_id1, address="a.com", port=9042), + _Route(connection_id=conn_id, host_id=host_id2, address="b.com", port=9042), + ]) + self.assertIsNotNone(store.get_by_host_id(host_id1)) + self.assertIsNotNone(store.get_by_host_id(host_id2)) + + # Merge with host_id2 affected but no new route for it → deletion + store.merge([], affected_host_ids={host_id2}) + + self.assertIsNotNone(store.get_by_host_id(host_id1)) + self.assertIsNone(store.get_by_host_id(host_id2)) + + def test_select_preferred_routes_keeps_existing_connection_id(self): + """When multiple connection_ids provide routes for the same host_id, + the one already in use should be preferred.""" + store = _RouteStore() + host_id = 
uuid.uuid4() + conn_a = "conn-a" + conn_b = "conn-b" + + # Populate store with conn_a for host_id + store.update([_Route(connection_id=conn_a, host_id=host_id, address="a.com", port=9042)]) + self.assertEqual(store.get_by_host_id(host_id).connection_id, conn_a) + + # Update with both conn_a and conn_b for the same host_id + store.update([ + _Route(connection_id=conn_b, host_id=host_id, address="b.com", port=9042), + _Route(connection_id=conn_a, host_id=host_id, address="a-new.com", port=9042), + ]) + # conn_a should be preferred since it was already in use + result = store.get_by_host_id(host_id) + self.assertEqual(result.connection_id, conn_a) + self.assertEqual(result.address, "a-new.com") + + def test_select_preferred_routes_falls_back_when_existing_gone(self): + """When the existing connection_id is no longer among candidates, + the first candidate should be selected.""" + store = _RouteStore() + host_id = uuid.uuid4() + + store.update([_Route(connection_id="old-conn", host_id=host_id, address="old.com", port=9042)]) + + # Update only has new connection_ids + store.update([ + _Route(connection_id="new-a", host_id=host_id, address="a.com", port=9042), + _Route(connection_id="new-b", host_id=host_id, address="b.com", port=9042), + ]) + result = store.get_by_host_id(host_id) + self.assertEqual(result.connection_id, "new-a") + + +class TestClientRoutesHandler(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.proxy]) + + def test_handler_initialization(self): + handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + self.assertIsNotNone(handler) + self.assertEqual(handler.ssl_enabled, False) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize(self, mock_query): + host_id = uuid.uuid4() + mock_query.return_value = [ + _Route( + connection_id=self.conn_id, + host_id=host_id, + 
address="node1.example.com", + port=9042, + ) + ] + + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + handler.initialize(mock_conn, timeout=5.0) + + mock_query.assert_called_once() + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "node1.example.com") + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_filters_by_configured_connection_ids(self, mock_query): + """Events with unrelated connection_ids should be ignored.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = str(uuid.uuid4()) + + # Event with a connection_id NOT in our config → should return early + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=["unrelated-conn-id"], + host_ids=[host_id], + ) + mock_query.assert_not_called() + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_merges_when_host_ids_present(self, mock_query): + """When host_ids are provided, routes should be merged (not full replace).""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + existing_host = uuid.uuid4() + new_host = uuid.uuid4() + conn_id = str(self.conn_id) + + # Pre-populate a route + handler._routes.update([ + _Route(connection_id=conn_id, host_id=existing_host, address="old.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=new_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id], + host_ids=[str(new_host)], + ) + + # Existing route should still be there (merge, not replace) + self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) + self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') 
+ def test_handle_change_updates_when_no_host_ids(self, mock_query): + """When no host_ids are provided, routes should be fully replaced.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + conn_id = str(self.conn_id) + + old_host = uuid.uuid4() + handler._routes.update([ + _Route(connection_id=conn_id, host_id=old_host, address="old.com", port=9042), + ]) + + new_host = uuid.uuid4() + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=new_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=None, + host_ids=None, + ) + + # Full replace: old_host gone, new_host present + self.assertIsNone(handler._routes.get_by_host_id(old_host)) + self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_propagates_query_failure(self, mock_query): + """If _query_routes raises, handle_client_routes_change should propagate.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + mock_query.side_effect = Exception("network error") + + conn_id = self.proxy.connection_id + host_id = str(uuid.uuid4()) + with self.assertRaises(Exception) as cm: + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id], + host_ids=[host_id], + ) + self.assertIn("network error", str(cm.exception)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_propagates_exception_on_failure(self, mock_query): + """initialize should propagate exceptions to caller.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + mock_query.side_effect = Exception("query failed") + + with self.assertRaises(Exception) as ctx: + handler.initialize(mock_conn, 5.0) + self.assertIn("query failed", str(ctx.exception)) + 
self.assertEqual(mock_query.call_count, 1) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_keeps_old_routes_on_failure(self, mock_query): + """On failure, existing routes must be preserved (critical for PL clusters).""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = uuid.uuid4() + + # Pre-populate a route + handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, address="old.com", port=9042), + ]) + + mock_query.side_effect = Exception("query failed") + with self.assertRaises(Exception): + handler.initialize(mock_conn, 5.0) + + # Old route must still be there + self.assertIsNotNone(handler._routes.get_by_host_id(host_id)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_updates_routes_on_success(self, mock_query): + """initialize should update routes on success.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = uuid.uuid4() + + mock_query.return_value = [ + _Route(connection_id=str(self.conn_id), host_id=host_id, address="new.com", port=9042), + ] + + handler.initialize(mock_conn, 5.0) + + self.assertEqual(mock_query.call_count, 1) + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "new.com") + +class TestClientRoutesEndPoint(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.proxy]) + self.handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + + def test_resolve_falls_back_when_no_mapping(self): + """resolve() should return original address/port when no route mapping exists.""" + host_id = uuid.uuid4() + ep = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + self.assertEqual(ep.resolve(), ("10.0.0.1", 9042)) 
+ + @patch('cassandra.client_routes.socket.getaddrinfo', + return_value=[(socket.AF_INET, socket.SOCK_STREAM, 0, '', ("192.168.1.100", 9042))]) + def test_resolve_returns_address_when_route_exists(self, _mock_getaddrinfo): + """resolve() should return the DNS-resolved address and port when a route exists.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="nlb.example.com", port=9042), + ]) + ep = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + self.assertEqual(ep.resolve(), ("192.168.1.100", 9042)) + _mock_getaddrinfo.assert_called_once_with( + "nlb.example.com", 9042, socket.AF_UNSPEC, socket.SOCK_STREAM) + + @patch('cassandra.client_routes.socket.getaddrinfo', + side_effect=socket.gaierror("DNS resolution failed")) + def test_resolve_host_dns_failure_raises(self, _mock_getaddrinfo): + """resolve_host should propagate socket.gaierror on DNS failure.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="nonexistent.example.com", port=9042), + ]) + with self.assertRaises(socket.gaierror): + self.handler.resolve_host(host_id) + + def test_resolve_host_missing_port_raises(self): + """resolve_host should raise ValueError when route has no port.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="host.com", port=0), + ]) + with self.assertRaises(ValueError): + self.handler.resolve_host(host_id) + + +class TestClientRoutesEndPointFactory(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([proxy]) + self.handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + self.factory = ClientRoutesEndPointFactory(self.handler, default_port=9042) + + def 
test_create_from_row(self): + """Factory should create a ClientRoutesEndPoint from a peers row.""" + host_id = uuid.uuid4() + row = { + "host_id": host_id, + "rpc_address": "10.0.0.5", + "native_transport_port": 9042, + "peer": "10.0.0.5", + } + ep = self.factory.create(row) + self.assertIsInstance(ep, ClientRoutesEndPoint) + self.assertEqual(ep.host_id, host_id) + self.assertEqual(ep.address, "10.0.0.5") + + def test_create_missing_host_id_raises(self): + """Factory should raise ValueError when row has no host_id.""" + row = {"rpc_address": "10.0.0.5", "native_transport_port": 9042} + with self.assertRaises(ValueError): + self.factory.create(row) + +class TestClientRoutesSSLValidation(unittest.TestCase): + + def test_check_hostname_with_ssl_context_raises(self): + """Cluster should reject check_hostname=True with client_routes_config.""" + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertTrue(ssl_ctx.check_hostname) + + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=["10.0.0.1"], + ssl_context=ssl_ctx, + client_routes_config=config, + ) + self.assertIn("check_hostname", str(cm.exception)) + + def test_check_hostname_with_ssl_options_raises(self): + """Cluster should reject check_hostname=True in ssl_options with client_routes_config.""" + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=["10.0.0.1"], + ssl_options={'check_hostname': True}, + client_routes_config=config, + ) + self.assertIn("check_hostname", str(cm.exception)) + + def test_disabled_check_hostname_with_client_routes_ok(self): + """Cluster should allow check_hostname=False with client_routes_config.""" + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.check_hostname = False + + config = ClientRoutesConfig( + 
proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + # Should not raise + cluster = Cluster( + contact_points=["10.0.0.1"], + ssl_context=ssl_ctx, + client_routes_config=config, + ) + cluster.shutdown() + + def test_no_ssl_with_client_routes_ok(self): + """Cluster should allow client_routes_config without SSL.""" + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + # Should not raise + cluster = Cluster( + contact_points=["10.0.0.1"], + client_routes_config=config, + ) + cluster.shutdown() + + +if __name__ == '__main__': + unittest.main() From 6743edd4baf8116195cf54e422cd3c91e6390d3b Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Wed, 11 Mar 2026 11:16:03 +0100 Subject: [PATCH 08/27] tests: add integration tests for client routes Add comprehensive integration tests covering: - TCP proxy and NLB emulator infrastructure for simulating private link connectivity - query_routes filtering with different connection/host ID combinations - Full private-link connectivity verifying all driver connections go exclusively through the NLB proxy - Dynamic route updates via REST API with driver reconnection through new proxy ports --- .../standard/test_client_routes.py | 1314 +++++++++++++++++ 1 file changed, 1314 insertions(+) create mode 100644 tests/integration/standard/test_client_routes.py diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py new file mode 100644 index 0000000000..a8a3c30f2c --- /dev/null +++ b/tests/integration/standard/test_client_routes.py @@ -0,0 +1,1314 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive integration tests for Client Routes (Private Link) support. + +Includes: +- TCP proxy and NLB emulator for simulating private link infrastructure +- Tests verifying all connections go exclusively through the proxy +- Tests for dynamic route updates and topology changes +- Tests for query_routes filtering +""" + +import logging +import os +import select +import shutil +import socket +import ssl +import subprocess +import tempfile +import threading +import time +import unittest +import uuid + +import json as _json +import urllib.request + +from cassandra.cluster import Cluster +from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy +from cassandra.connection import ClientRoutesEndPoint +from cassandra.policies import RoundRobinPolicy +from tests.integration import ( + TestCluster, + get_cluster, + get_node, + use_cluster, + wait_for_node_socket, + skip_scylla_version_lt, +) +from tests.util import wait_until_not_raised + +log = logging.getLogger(__name__) + +class TcpProxy: + """ + A simple TCP proxy that forwards connections from a local listen port + to a target (host, port). Tracks active connections so tests can + verify that traffic flows through the proxy. 
+ """ + + BUF_SIZE = 65536 + + def __init__(self, listen_host, listen_port, target_host, target_port): + self.listen_host = listen_host + self.listen_port = listen_port + self.target_host = target_host + self.target_port = target_port + + self._server_sock = None + self._running = False + self._thread = None + self._lock = threading.Lock() + self._connections = set() + self.total_connections = 0 + + def start(self): + self._server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server_sock.bind((self.listen_host, self.listen_port)) + self.listen_port = self._server_sock.getsockname()[1] + self._server_sock.listen(128) + self._server_sock.setblocking(False) + self._running = True + self._thread = threading.Thread(target=self._run, daemon=True, + name="proxy-%s:%d" % (self.listen_host, self.listen_port)) + self._thread.start() + log.info("TcpProxy started %s:%d -> %s:%d", + self.listen_host, self.listen_port, + self.target_host, self.target_port) + + def stop(self): + self._running = False + if self._server_sock: + try: + self._server_sock.close() + except Exception: + pass + with self._lock: + for csock, tsock in list(self._connections): + self._close_pair(csock, tsock) + self._connections.clear() + if self._thread: + self._thread.join(timeout=5) + log.info("TcpProxy stopped %s:%d", self.listen_host, self.listen_port) + + @property + def active_connections(self): + with self._lock: + return len(self._connections) + + def retarget(self, new_host, new_port): + """Change the backend target for new connections (existing ones keep the old target).""" + self.target_host = new_host + self.target_port = new_port + log.info("TcpProxy %s:%d retargeted to %s:%d", + self.listen_host, self.listen_port, new_host, new_port) + + def drop_connections(self): + """Forcibly close all active connections.""" + with self._lock: + for csock, tsock in list(self._connections): + self._close_pair(csock, 
tsock) + self._connections.clear() + log.info("TcpProxy %s:%d dropped all connections", self.listen_host, self.listen_port) + + def _run(self): + while self._running: + try: + readable, _, _ = select.select([self._server_sock], [], [], 0.2) + except (ValueError, OSError): + break + for sock in readable: + if sock is self._server_sock: + try: + client_sock, _ = self._server_sock.accept() + except OSError: + continue + self._handle_new_connection(client_sock) + + def _handle_new_connection(self, client_sock, target_host=None, target_port=None): + target_host = target_host or self.target_host + target_port = target_port or self.target_port + try: + target_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + target_sock.connect((target_host, target_port)) + except Exception as e: + log.warning("TcpProxy %s:%d failed to connect to target %s:%d: %s", + self.listen_host, self.listen_port, + target_host, target_port, e) + client_sock.close() + return + + with self._lock: + self._connections.add((client_sock, target_sock)) + self.total_connections += 1 + + t = threading.Thread(target=self._forward_loop, + args=(client_sock, target_sock), + daemon=True) + t.start() + + def _forward_loop(self, client_sock, target_sock): + try: + while self._running: + readable, _, _ = select.select([client_sock, target_sock], [], [], 0.5) + for sock in readable: + data = sock.recv(self.BUF_SIZE) + if not data: + return + if sock is client_sock: + target_sock.sendall(data) + else: + client_sock.sendall(data) + except (OSError, ConnectionResetError, BrokenPipeError): + pass + finally: + with self._lock: + self._connections.discard((client_sock, target_sock)) + self._close_pair(client_sock, target_sock) + + @staticmethod + def _close_pair(csock, tsock): + for s in (csock, tsock): + try: + s.close() + except Exception: + pass + + +class NLBEmulator: + """ + Emulates a Network Load Balancer for a CCM cluster. 
+ + Provides: + - One *discovery port* (round-robin across all live nodes, used as the + driver's ``contact_points``). + - One *per-node port* for each node (dedicated proxy to that node's + native transport port). + + All proxies listen on ``LISTEN_HOST`` (127.254.254.101), an address + outside the CCM node range, simulating a real NLB endpoint. + + Port layout (all ports are OS-assigned by default): + LISTEN_HOST:discovery_port -> round-robin to all live nodes + LISTEN_HOST:node1_port -> node1 (127.0.0.1:9042) + LISTEN_HOST:node2_port -> node2 (127.0.0.2:9042) + ... + + Automatically creates/removes per-node proxies when nodes are + added/removed so CCM cluster operations are reflected seamlessly. + """ + + LISTEN_HOST = "127.254.254.101" + + def __init__(self, discovery_port=0, + per_node_base=0, + native_port=9042, + node_addresses=None): + self.discovery_port = discovery_port + self.per_node_base = per_node_base + self.native_port = native_port + self._deferred_node_addresses = node_addresses + + self._node_proxies = {} + self._discovery_proxy = None + self._rr_index = 0 + self._lock = threading.Lock() + self._running = False + + def start(self, node_addresses): + """ + Start the NLB with an initial set of node addresses. + + :param node_addresses: dict of node_id -> ip_address, e.g. 
+ {1: "127.0.0.1", 2: "127.0.0.2"} + """ + self._running = True + try: + for node_id, addr in node_addresses.items(): + self._add_node_proxy(node_id, addr) + + first_addr = list(node_addresses.values())[0] + self._discovery_proxy = TcpProxy( + self.LISTEN_HOST, self.discovery_port, + first_addr, self.native_port, + ) + self._discovery_proxy.start() + self.discovery_port = self._discovery_proxy.listen_port + except Exception: + self.stop() + raise + original_handler = self._discovery_proxy._handle_new_connection + + def rr_handler(client_sock): + addrs = self._live_addresses() + if not addrs: + client_sock.close() + return + idx = self._rr_index % len(addrs) + self._rr_index += 1 + addr = addrs[idx] + original_handler(client_sock, target_host=addr, target_port=self.native_port) + + self._discovery_proxy._handle_new_connection = rr_handler + + log.info("NLB started: discovery=%s:%d, %d node proxies", + self.LISTEN_HOST, self.discovery_port, len(self._node_proxies)) + return self + + def __enter__(self): + if not self._running and self._deferred_node_addresses is not None: + self.start(self._deferred_node_addresses) + return self + + def __exit__(self, *args): + self.stop() + + def stop(self): + self._running = False + if self._discovery_proxy: + self._discovery_proxy.stop() + for proxy in self._node_proxies.values(): + proxy.stop() + self._node_proxies.clear() + log.info("NLB stopped") + + def add_node(self, node_id, addr): + self._add_node_proxy(node_id, addr) + + def remove_node(self, node_id): + with self._lock: + proxy = self._node_proxies.pop(node_id, None) + if proxy: + proxy.stop() + log.info("NLB removed node %d", node_id) + + def node_port(self, node_id): + proxy = self._node_proxies.get(node_id) + if proxy: + return proxy.listen_port + return self.per_node_base + node_id + + def get_node_proxy(self, node_id): + return self._node_proxies.get(node_id) + + def total_proxy_connections(self): + return sum(p.total_connections for p in self._node_proxies.values()) 
+ + def active_proxy_connections(self): + return sum(p.active_connections for p in self._node_proxies.values()) + + def drop_all_connections(self): + for proxy in self._node_proxies.values(): + proxy.drop_connections() + if self._discovery_proxy: + self._discovery_proxy.drop_connections() + + def _add_node_proxy(self, node_id, addr): + port = 0 + proxy = TcpProxy(self.LISTEN_HOST, port, addr, self.native_port) + proxy.start() + with self._lock: + self._node_proxies[node_id] = proxy + log.info("NLB added node %d: %s:%d -> %s:%d", + node_id, self.LISTEN_HOST, port, addr, self.native_port) + + def _live_addresses(self): + """IPs of nodes with active proxies.""" + return [p.target_host for p in self._node_proxies.values()] + +def post_client_routes(contact_point, routes): + """ + Post client routes to Scylla's REST API. + + :param contact_point: IP/hostname of a Scylla node (e.g. "127.0.0.1") + :param routes: List of route dicts with keys: connection_id, host_id, address, port + and optionally tls_port + """ + payload = [] + for route in routes: + entry = { + "connection_id": str(route["connection_id"]), + "host_id": str(route["host_id"]), + "address": route["address"], + "port": route["port"], + } + if route.get("tls_port") is not None: + entry["tls_port"] = route["tls_port"] + payload.append(entry) + + url = "http://%s:10000/v2/client-routes" % contact_point + log.info("Posting %d routes to %s", len(payload), url) + data = _json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + method="POST", + ) + response = urllib.request.urlopen(req) + log.info("Routes posted successfully (status %d)", response.status) + + +def get_host_ids_from_cluster(session): + """ + Build a mapping of rpc_address -> host_id for all nodes in the cluster. 
+ + Uses the driver's metadata rather than querying system.local / system.peers + directly, because those queries can be routed to different coordinators + (system.local returns the coordinator's own info while system.peers omits + the coordinator), leading to a node being missing from the map. + """ + host_id_map = {} + for host in session.cluster.metadata.all_hosts(): + host_id_map[host.address] = host.host_id + return host_id_map + + +def build_routes_for_nlb(connection_id, host_id_map, nlb): + """ + Build routes that direct each host_id through the NLB per-node proxy. + + :param connection_id: Connection ID string + :param host_id_map: dict ip -> uuid host_id (from get_host_ids_from_cluster) + :param nlb: NLBEmulator instance + :return: list of route dicts + """ + routes = [] + for ip, host_id in host_id_map.items(): + node_id = int(ip.split(".")[-1]) + port = nlb.node_port(node_id) + routes.append({ + "connection_id": connection_id, + "host_id": host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": port, + }) + return routes + + +def post_routes_for_nlb(contact_point, connection_id, host_id_map, nlb): + """Build routes for the NLB and POST them via the REST API.""" + routes = build_routes_for_nlb(connection_id, host_id_map, nlb) + post_client_routes(contact_point, routes) + return routes + +def wait_for_routes_visible(session, connection_id, expected_count, timeout=10, poll_interval=0.1): + """ + Poll system.client_routes on **every** node until each one sees at + least *expected_count* rows for *connection_id*. + + ``system.client_routes`` is a node-local table, so routes posted via + the REST API to one node are not guaranteed to be visible on the + others at the same time. This helper ensures they have propagated + everywhere before the test proceeds. 
+ + :param session: an active driver Session (direct, not through NLB) + :param connection_id: the connection_id string to filter on + :param expected_count: how many rows we expect to see per node + :param timeout: maximum seconds to wait + :param poll_interval: seconds between polls + """ + all_hosts = list(session.cluster.metadata.all_hosts()) + deadline = time.time() + timeout + while True: + pending_hosts = [] + for host in all_hosts: + rows = list(session.execute( + "SELECT * FROM system.client_routes WHERE connection_id = %s", + (connection_id,), + host=host, + )) + if len(rows) < expected_count: + pending_hosts.append((host, len(rows))) + if not pending_hosts: + return + if time.time() >= deadline: + details = ", ".join( + "%s: %d" % (h.address, count) for h, count in pending_hosts + ) + raise RuntimeError( + "Timed out waiting for %d routes (connection_id=%s) to appear " + "in system.client_routes on all nodes; pending: %s" + % (expected_count, connection_id, details) + ) + time.sleep(poll_interval) + + +def node_id_from_ip(ip): + """Extract node_id from an IP like '127.0.0.3' -> 3.""" + return int(ip.split(".")[-1]) + + +def assert_routes_via_nlb(test, cluster, nlb, expected_node_ids): + """ + Assert that every host in *expected_node_ids* has its endpoint + resolving through the NLB (correct address and per-node port). 
+ """ + nlb_listen_host = NLBEmulator.LISTEN_HOST + expected_node_ids = set(expected_node_ids) + + seen_node_ids = set() + for host in cluster.metadata.all_hosts(): + ep = host.endpoint + if not isinstance(ep, ClientRoutesEndPoint): + continue + node_id = node_id_from_ip(ep.address) + if node_id not in expected_node_ids: + continue + resolved_addr, resolved_port = ep.resolve() + test.assertEqual( + resolved_addr, nlb_listen_host, + "Node %d endpoint should resolve to NLB address %s, got %s" + % (node_id, nlb_listen_host, resolved_addr), + ) + test.assertEqual( + resolved_port, nlb.node_port(node_id), + "Node %d endpoint should resolve to NLB port %d, got %d" + % (node_id, nlb.node_port(node_id), resolved_port), + ) + seen_node_ids.add(node_id) + test.assertEqual( + seen_node_ids, expected_node_ids, + "Not all expected nodes found in metadata endpoints", + ) + + +def assert_routes_direct(test, cluster, expected_node_ids, direct_port=9042): + """ + Assert that every host in *expected_node_ids* has its endpoint + resolving to the node's own IP on *direct_port*. 
+ """ + expected_node_ids = set(expected_node_ids) + + for host in cluster.metadata.all_hosts(): + ep = host.endpoint + if not isinstance(ep, ClientRoutesEndPoint): + continue + node_id = node_id_from_ip(ep.address) + if node_id not in expected_node_ids: + continue + resolved_addr, resolved_port = ep.resolve() + expected_ip = "127.0.0.%d" % node_id + test.assertEqual( + resolved_addr, expected_ip, + "Node %d endpoint should resolve to direct address %s, got %s" + % (node_id, expected_ip, resolved_addr), + ) + test.assertEqual( + resolved_port, direct_port, + "Node %d endpoint should resolve to direct port %d, got %d" + % (node_id, direct_port, resolved_port), + ) + + +def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes', [3], start=True) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestGetHostPortMapping(unittest.TestCase): + """ + Test _query_all_routes_for_connections and _query_routes_for_change_event + methods with different filtering scenarios. 
+ """ + + @classmethod + def setUpClass(cls): + cls.cluster = TestCluster(client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("conn_id", "127.0.0.1")])) + cls.session = cls.cluster.connect() + + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] + cls.expected = [] + + for idx, host_id in enumerate(cls.host_ids): + ip = f"127.0.0.{idx + 1}" + for connection_id in cls.connection_ids: + cls.expected.append({ + 'connection_id': connection_id, + 'host_id': host_id, + 'address': ip, + 'port': 9042, + 'tls_port': 9142, + }) + + cls._sort_routes(cls.expected) + post_client_routes(cls.cluster.contact_points[0], cls.expected) + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + + @staticmethod + def _sort_routes(routes): + routes.sort(key=lambda r: (str(r['connection_id']), str(r['host_id']))) + + def _routes_to_dicts(self, routes): + """Convert _Route objects to comparable dicts, adjusting port for ssl_enabled.""" + return [ + { + 'connection_id': route.connection_id, + 'host_id': route.host_id, + 'address': route.address, + 'port': route.port, + } + for route in routes + ] + + def _expected_dicts(self, expected): + """Build expected dicts with tls_port or port based on ssl_enabled.""" + port_key = 'tls_port' if self.cluster._client_routes_handler.ssl_enabled else 'port' + return [ + { + 'connection_id': e['connection_id'], + 'host_id': e['host_id'], + 'address': e['address'], + 'port': e[port_key], + } + for e in expected + ] + + def test_get_all_routes_for_all_connections(self): + """Querying all connection IDs returns every route.""" + cc = self.cluster.control_connection + routes = self.cluster._client_routes_handler._query_all_routes_for_connections( + cc._connection, cc._timeout, self.connection_ids, + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + expected = self._expected_dicts(self.expected) + self._sort_routes(expected) + self.assertEqual(got, expected) 
+ + def test_get_routes_for_single_connection(self): + """Querying a single connection ID returns only its routes.""" + cc = self.cluster.control_connection + routes = self.cluster._client_routes_handler._query_all_routes_for_connections( + cc._connection, cc._timeout, [self.connection_ids[0]], + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + filtered = [r for r in self.expected + if r['connection_id'] == self.connection_ids[0]] + expected = self._expected_dicts(filtered) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_change_event_all_pairs(self): + """Querying all (connection_id, host_id) pairs returns every route.""" + cc = self.cluster.control_connection + pairs = [(r['connection_id'], r['host_id']) for r in self.expected] + routes = self.cluster._client_routes_handler._query_routes_for_change_event( + cc._connection, cc._timeout, pairs, + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + expected = self._expected_dicts(self.expected) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_change_event_single_pair(self): + """Querying a single (connection_id, host_id) pair returns one route.""" + cc = self.cluster.control_connection + target_conn_id = self.connection_ids[0] + target_host_id = self.host_ids[0] + routes = self.cluster._client_routes_handler._query_routes_for_change_event( + cc._connection, cc._timeout, [(target_conn_id, target_host_id)], + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + filtered = [r for r in self.expected + if r['connection_id'] == target_conn_id + and r['host_id'] == target_host_id] + expected = self._expected_dicts(filtered) + self._sort_routes(expected) + self.assertEqual(got, expected) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestPrivateLinkConnectivity(unittest.TestCase): + """ + Verifies the 
driver connects to all cluster nodes exclusively through + the NLB proxy, never directly. + + Setup: + 1. Start a 3-node CCM cluster (done by setup_module). + 2. Start an NLB emulator with per-node proxies. + 3. Use a direct session to read host_ids, then POST client routes + pointing each host_id at the NLB proxy port. + 4. Create a client-routes-enabled session using the NLB discovery + port as the contact point. + 5. Verify all driver connections go through proxy ports. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + log.info("Host ID map: %s", cls.host_id_map) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.nlb = NLBEmulator() + cls.nlb.start(cls.node_addrs) + + cls.connection_id = str(uuid.uuid4()) + post_routes_for_nlb("127.0.0.1", cls.connection_id, cls.host_id_map, cls.nlb) + wait_for_routes_visible(cls.direct_session, cls.connection_id, len(cls.host_id_map)) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + cls.nlb.stop() + + def _make_client_routes_cluster(self, **extra_kwargs): + """Create a Cluster configured with client-routes pointing at the NLB.""" + return Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=self.nlb.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + **extra_kwargs, + ) + + def test_all_connections_through_proxy(self): + """Every pool connection must go through the NLB proxy, not directly.""" + with self._make_client_routes_cluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(50): + session.execute("SELECT key FROM system.local") + + pool_state = session.get_pool_state() + 
self.assertEqual(len(pool_state), len(self.node_addrs), + "Driver should have pools for all nodes") + + for host, state in pool_state.items(): + node_id = node_id_from_ip(host.address) + proxy = self.nlb.get_node_proxy(node_id) + self.assertIsNotNone(proxy, f"No proxy for node {node_id}") + open_count = state['open_count'] + self.assertGreaterEqual( + proxy.total_connections, open_count, + f"Node {node_id} proxy saw {proxy.total_connections} " + f"connections but pool has {open_count} open — " + f"some connections bypassed the proxy") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + def test_queries_succeed_through_proxy(self): + """Queries should work normally through the proxy.""" + with self._make_client_routes_cluster() as cluster: + session = cluster.connect() + session.execute( + "CREATE KEYSPACE IF NOT EXISTS test_cr_ks " + "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}" + ) + session.execute( + "CREATE TABLE IF NOT EXISTS test_cr_ks.t (k int PRIMARY KEY, v text)" + ) + session.execute("INSERT INTO test_cr_ks.t (k, v) VALUES (1, 'hello')") + row = session.execute("SELECT v FROM test_cr_ks.t WHERE k = 1").one() + self.assertEqual(row.v, "hello") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + def test_connection_recovery_after_proxy_drop(self): + """ + After the proxy drops all connections, the driver should reconnect + (still through the proxy). 
+ """ + with self._make_client_routes_cluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute("SELECT key FROM system.local") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + self.nlb.drop_all_connections() + + def query_ok(): + session.execute("SELECT key FROM system.local") + + wait_until_not_raised(query_ok, 1, 30) + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestDynamicRouteUpdates(unittest.TestCase): + """ + Verify that when routes are updated (e.g. port changes), the driver + picks up the new routes and reconnects through the new proxy ports + after existing connections are dropped. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_route_update_causes_reconnect_to_new_port(self): + """ + 1. Start NLB v1, post routes -> driver connects through v1 ports. + 2. Start NLB v2 on different ports, post new routes. + 3. Drop v1 connections. + 4. Driver should reconnect through v2 ports. 
+ """ + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v1, NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v2: + post_routes_for_nlb("127.0.0.1", self.connection_id, + self.host_id_map, nlb_v1) + wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb_v1.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute("SELECT key FROM system.local") + + for node_id in self.node_addrs: + self.assertGreater( + nlb_v1.get_node_proxy(node_id).total_connections, 0) + assert_routes_via_nlb(self, cluster, nlb_v1, + self.node_addrs.keys()) + + post_routes_for_nlb("127.0.0.1", self.connection_id, + self.host_id_map, nlb_v2) + time.sleep(2) # let CLIENT_ROUTES_CHANGE propagate + + # Stop v1 per-node proxies entirely so v1 ports become + # unreachable, forcing the driver to reconnect through v2. + # (Merely dropping connections is insufficient because v1 + # proxies would still accept new connections before the + # route update propagates.) + for node_id in list(self.node_addrs.keys()): + nlb_v1.remove_node(node_id) + + def all_nodes_via_v2(): + session.execute("SELECT key FROM system.local") + for nid in self.node_addrs: + assert nlb_v2.get_node_proxy(nid).total_connections > 0, \ + "NLB v2 node %d proxy has no connections yet" % nid + + wait_until_not_raised(all_nodes_via_v2, 1, 30) + + assert_routes_via_nlb(self, cluster, nlb_v2, + self.node_addrs.keys()) + + +def _generate_ssl_certs(cert_dir, node_ips): + """ + Generate test SSL certificates with SANs covering the given node IPs. + + File names follow CCM's ``ScyllaCluster.enable_ssl()`` convention so the + resulting directory can be passed directly to ``enable_ssl(cert_dir, ...)``. 
+ + Creates: + - ca.key / ca.crt: self-signed CA + - ccm_node.key / ccm_node.pem: server cert signed by CA with SANs for all node_ips + + :param cert_dir: directory to write files into (must exist) + :param node_ips: list of IP strings to include as SANs (e.g. ["127.0.0.1", "127.0.0.2"]) + """ + if shutil.which("openssl") is None: + raise unittest.SkipTest("openssl not found on PATH; skipping SSL cert generation") + + san_cnf = os.path.join(cert_dir, "san.cnf") + san_value = ",".join("IP:%s" % ip for ip in node_ips) + with open(san_cnf, "w") as f: + f.write("subjectAltName=%s\n" % san_value) + + def _run(cmd): + result = subprocess.run(cmd, cwd=cert_dir, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError("Command failed: %s\n%s" % (" ".join(cmd), result.stderr)) + + _run(["openssl", "req", "-x509", "-newkey", "rsa:2048", + "-keyout", "ca.key", "-out", "ca.crt", + "-days", "1", "-nodes", "-subj", "/CN=Test CA"]) + + _run(["openssl", "req", "-newkey", "rsa:2048", + "-keyout", "ccm_node.key", "-out", "ccm_node.csr", + "-nodes", "-subj", "/CN=Test Server"]) + + _run(["openssl", "x509", "-req", + "-in", "ccm_node.csr", "-CA", "ca.crt", "-CAkey", "ca.key", + "-CAcreateserial", "-out", "ccm_node.pem", + "-days", "1", "-extfile", "san.cnf"]) + + log.info("Generated SSL certs in %s with SANs: %s", cert_dir, san_value) + + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestMixedDirectAndNlbConnections(unittest.TestCase): + """ + Verify the cluster works when some nodes are accessed through the NLB + proxy and others are accessed directly (no route posted, falls back + to the default endpoint). 
+ """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_mixed_direct_and_nlb_connections(self): + """ + Post routes for only a subset of nodes (through NLB proxy). + Remaining nodes have no route and fall back to direct connections. + Queries should work through both paths. + """ + proxied_node_id = min(self.node_addrs.keys()) + proxied_ip = self.node_addrs[proxied_node_id] + + with NLBEmulator( + node_addresses={proxied_node_id: proxied_ip}, + ) as nlb: + proxied_host_id = self.host_id_map[proxied_ip] + routes = [{ + "connection_id": self.connection_id, + "host_id": proxied_host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": nlb.node_port(proxied_node_id), + }] + post_client_routes("127.0.0.1", routes) + time.sleep(1) + + with Cluster( + contact_points=["127.0.0.1"], + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(50): + session.execute("SELECT key FROM system.local") + + assert_routes_via_nlb(self, cluster, nlb, + [proxied_node_id]) + + direct_node_ids = set(self.node_addrs.keys()) - {proxied_node_id} + assert_routes_direct(self, cluster, direct_node_ids) + + proxy = nlb.get_node_proxy(proxied_node_id) + self.assertGreater(proxy.total_connections, 0, + "Proxied node should have connections through NLB") + + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class 
TestSslThroughNlb(unittest.TestCase): + """ + Verify SSL with check_hostname=False works through the NLB proxy. + + When using client routes, connections go through NLB proxies whose + addresses won't match server certificates, so hostname verification + must be disabled. Certificate chain validation (verify_mode=CERT_REQUIRED) + is still active — only hostname matching is skipped. + + The driver raises ValueError at Cluster init time if check_hostname=True + is used with client_routes_config. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + cls.direct_cluster.shutdown() + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + cls.cert_dir = tempfile.mkdtemp(prefix="client-routes-ssl-") + cert_ips = list(cls.node_addrs.values()) + _generate_ssl_certs(cls.cert_dir, cert_ips) + + cls.ccm_cluster = get_cluster() + cls.ccm_cluster.stop() + cls.ccm_cluster.set_configuration_options({ + 'client_encryption_options': { + 'enabled': True, + 'certificate': os.path.join(cls.cert_dir, "ccm_node.pem"), + 'keyfile': os.path.join(cls.cert_dir, "ccm_node.key"), + } + }) + cls.ccm_cluster.start(wait_for_binary_proto=True) + + @classmethod + def tearDownClass(cls): + cls.ccm_cluster.stop() + cls.ccm_cluster.set_configuration_options({ + 'client_encryption_options': { + 'enabled': False, + } + }) + cls.ccm_cluster.start(wait_for_binary_proto=True) + + shutil.rmtree(cls.cert_dir, ignore_errors=True) + + def test_ssl_without_hostname_verification_through_nlb(self): + """ + Connect through NLB with SSL but check_hostname=False. + + When using client routes, connections go through NLB proxies + whose addresses won't match server certificates, so hostname + verification must be disabled. 
Certificate chain validation + (verify_mode=CERT_REQUIRED) is still active. + """ + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb: + routes = build_routes_for_nlb( + self.connection_id, self.host_id_map, nlb, + ) + for route in routes: + route["tls_port"] = route["port"] + post_client_routes("127.0.0.1", routes) + + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.check_hostname = False + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + + self.assertFalse(ssl_ctx.check_hostname, + "check_hostname must be False for this test") + self.assertEqual(ssl_ctx.verify_mode, ssl.CERT_REQUIRED, + "verify_mode must be CERT_REQUIRED") + + def routes_visible(): + with TestCluster( + contact_points=["127.0.0.1"], + ssl_context=ssl_ctx, + ) as c: + session = c.connect() + rs = session.execute( + "SELECT * FROM system.client_routes " + "WHERE connection_id = %s ALLOW FILTERING", + (self.connection_id,) + ) + return len(list(rs)) >= len(self.host_id_map) + + wait_until_not_raised( + lambda: self.assertTrue(routes_visible()), + 0.5, 10, + ) + + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb.discovery_port, + ssl_context=ssl_ctx, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(20): + row = session.execute( + "SELECT release_version FROM system.local" + ).one() + self.assertIsNotNone(row) + + assert_routes_via_nlb(self, cluster, nlb, + self.node_addrs.keys()) + + def test_ssl_with_hostname_verification_raises_error(self): + """ + Verify that Cluster raises ValueError when client_routes_config + is used with SSL hostname verification enabled. 
+ """ + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + self.assertTrue(ssl_ctx.check_hostname) + + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + ssl_context=ssl_ctx, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("test-id", NLBEmulator.LISTEN_HOST)], + ), + ) + self.assertIn("check_hostname", str(cm.exception)) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestFullNodeReplacementThroughNlb(unittest.TestCase): + """ + End-to-end test: creates a session through an NLB proxy with client routes, + scales the cluster up, then decommissions original nodes, verifying the + session survives the full node replacement. + + This test is destructive — it modifies the CCM cluster topology by + bootstrapping new nodes and decommissioning original ones. It uses + its own CCM cluster so it cannot interfere with other tests. + """ + + @classmethod + def setUpClass(cls): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes_replacement', [3], start=True) + + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_should_survive_full_node_replacement_through_nlb(self): + """ + 1. Start with 3 nodes behind the NLB + 2. Bootstrap 2 new nodes, add to NLB, update routes + 3. Decommission the original 3 nodes one-by-one, updating NLB/routes + 4. 
Verify the session survives with only new nodes + """ + original_node_ids = sorted(self.node_addrs.keys()) + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb: + # ---- Stage 1: Set up NLB for initial nodes ---- + log.info("Stage 1: Setting up NLB for %d initial nodes", len(original_node_ids)) + + post_routes_for_nlb("127.0.0.1", self.connection_id, self.host_id_map, nlb) + wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + + # ---- Stage 2: Create session through NLB ---- + log.info("Stage 2: Creating session through NLB") + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + self._assert_query_works(session) + + handler = cluster._client_routes_handler + self.assertIsNotNone(handler) + + assert_routes_via_nlb(self, cluster, nlb, + original_node_ids) + log.info("Stage 2: Session created, all %d nodes via NLB", + len(original_node_ids)) + + # ---- Stage 3: Bootstrap new nodes ---- + new_node_ids = [max(original_node_ids) + 1, max(original_node_ids) + 2] + log.info("Stage 3: Adding nodes %s", new_node_ids) + ccm_cluster = get_cluster() + + for node_id in new_node_ids: + self._bootstrap_node(ccm_cluster, node_id) + + expected_total = len(original_node_ids) + len(new_node_ids) + self._wait_for_condition( + lambda: len(cluster.metadata.all_hosts()) >= expected_total, + timeout_seconds=60, + description="%d nodes in metadata" % expected_total, + ) + + for node_id in new_node_ids: + nlb.add_node(node_id, "127.0.0.%d" % node_id) + + all_host_ids = get_host_ids_from_cluster(session) + log.info("All host IDs after expansion: %s", all_host_ids) + post_routes_for_nlb("127.0.0.1", self.connection_id, all_host_ids, nlb) + + handler.initialize( + 
cluster.control_connection._connection, + cluster.control_connection._timeout) + + self._wait_for_condition( + lambda: sum(1 for h in cluster.metadata.all_hosts() if h.is_up) >= expected_total, + timeout_seconds=60, + description="all %d nodes up" % expected_total, + ) + + self._assert_query_works(session) + + all_node_ids = set(original_node_ids) | set(new_node_ids) + assert_routes_via_nlb(self, cluster, nlb, all_node_ids) + log.info("Stage 3: All %d nodes via NLB after expansion", + len(all_node_ids)) + + # ---- Stage 4: Decommission original nodes ---- + log.info("Stage 4: Decommissioning original nodes %s", original_node_ids) + + remaining_node_ids = set(all_node_ids) + remaining_host_ids = dict(all_host_ids) + for node_id in original_node_ids: + log.info("Decommissioning node %d", node_id) + get_node(node_id).decommission() + nlb.remove_node(node_id) + remaining_node_ids.discard(node_id) + + ip = "127.0.0.%d" % node_id + remaining_host_ids.pop(ip, None) + + surviving_ips = list(remaining_host_ids.keys()) + if surviving_ips: + post_routes_for_nlb( + surviving_ips[0], self.connection_id, + remaining_host_ids, nlb, + ) + + expected_remaining = expected_total - (original_node_ids.index(node_id) + 1) + self._wait_for_condition( + lambda er=expected_remaining: ( + len(cluster.metadata.all_hosts()) <= er + and self._query_succeeds(session) + ), + timeout_seconds=60, + description="node %d decommissioned" % node_id, + ) + + # Reload routes after the control connection has + # re-established itself (the decommission may have + # killed the old control connection). 
+ handler.initialize( + cluster.control_connection._connection, + cluster.control_connection._timeout) + + assert_routes_via_nlb(self, cluster, nlb, + remaining_node_ids) + log.info("Node %d decommissioned, %d nodes still via NLB", + node_id, len(remaining_node_ids)) + + # ---- Stage 5: Verify with only new nodes ---- + log.info("Stage 5: Verifying session works with only new nodes %s", new_node_ids) + self._assert_query_works(session) + + hosts = cluster.metadata.all_hosts() + self.assertEqual( + len(hosts), len(new_node_ids), + "Expected %d hosts, got %d" % (len(new_node_ids), len(hosts)) + ) + + for _ in range(10): + self._assert_query_works(session) + + assert_routes_via_nlb(self, cluster, nlb, new_node_ids) + log.info("PASS: Full node replacement, all %d new nodes via NLB", + len(new_node_ids)) + + def _assert_query_works(self, session): + rs = session.execute("SELECT release_version FROM system.local WHERE key='local'") + row = rs.one() + self.assertIsNotNone(row, "Query via NLB should return a result") + + def _query_succeeds(self, session): + try: + self._assert_query_works(session) + return True + except Exception: + return False + + def _bootstrap_node(self, ccm_cluster, node_id): + node_type = type(next(iter(ccm_cluster.nodes.values()))) + ip = "127.0.0.%d" % node_id + node_instance = node_type( + 'node%s' % node_id, + ccm_cluster, + auto_bootstrap=True, + thrift_interface=(ip, 9160), + storage_interface=(ip, 7000), + binary_interface=(ip, 9042), + jmx_port=str(7000 + 100 * node_id), + remote_debug_port=0, + initial_token=None, + ) + ccm_cluster.add(node_instance, is_seed=False) + node_instance.start(wait_for_binary_proto=True, wait_other_notice=True) + wait_for_node_socket(node_instance, 120) + log.info("Node %d bootstrapped successfully", node_id) + + @staticmethod + def _wait_for_condition(predicate, timeout_seconds, poll_interval=2, description="condition"): + deadline = time.time() + timeout_seconds + while time.time() < deadline: + if predicate(): 
+ return True + time.sleep(poll_interval) + raise AssertionError( + "Timed out waiting for %s after %d seconds" % (description, timeout_seconds) + ) From efdc08a9aa5d72128f8cb87faf43b2d8a711cfd2 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Wed, 18 Mar 2026 14:30:10 -0400 Subject: [PATCH 09/27] Release 3.29.9: changelog, version and documentation --- CHANGELOG.rst | 22 ++++++++++++++++++++++ cassandra/__init__.py | 2 +- docs/conf.py | 3 ++- docs/installation.rst | 4 ++-- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0c4aa63669..3ae00a7ee8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,25 @@ +3.29.9 +====== +March 18, 2026 + +Features +-------- +* Add Private Link support via client routes handler +* Add optional query_params parameter to QueryMessage + +Bug Fixes +--------- +* Fix segmentation fault in libev prepare_callback during shutdown +* Add null checks to io_callback and timer_callback in libev wrapper +* Fix RecursionError in execute_concurrent on synchronous errbacks +* Fix floating-point precision loss for timestamps far from epoch + +Others +------ +* Cache parsed tablet routing type in ResponseFuture +* Remove deprecated setup_requires in favor of PEP 517 build-system.requires +* Update dependency hatchling to v1.29.0 + 3.29.8 ====== February 09, 2026 diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 5567c0b9bd..3ad8fcdfd1 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -23,7 +23,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 29, 8) +__version_info__ = (3, 29, 9) __version__ = '.'.join(map(str, __version_info__)) diff --git a/docs/conf.py b/docs/conf.py index 403908c29e..4b6b329525 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,10 +29,11 @@ '3.29.6-scylla', '3.29.7-scylla', '3.29.8-scylla', + '3.29.9-scylla', ] BRANCHES = ['master'] # Set the latest version. 
-LATEST_VERSION = '3.29.8-scylla' +LATEST_VERSION = '3.29.9-scylla' # Set which versions are not released yet. UNSTABLE_VERSIONS = ['master'] # Set which versions are deprecated diff --git a/docs/installation.rst b/docs/installation.rst index 4207c46092..7b4823b832 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -26,7 +26,7 @@ To check if the installation was successful, you can run:: python -c 'import cassandra; print(cassandra.__version__)' -It should print something like "3.29.8". +It should print something like "3.29.9". (*Optional*) Compression Support -------------------------------- @@ -199,7 +199,7 @@ through `Homebrew `_. For example, on Mac OS X:: $ brew install libev -The libev extension can now be built for Windows as of Python driver version 3.29.8. You can +The libev extension can now be built for Windows as of Python driver version 3.29.9. You can install libev using any Windows package manager. For example, to install using `vcpkg `_: $ vcpkg install libev From fec90aec2a362ab6c7341adb762e26ab82a660d7 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko <52855732+sylwiaszunejko@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:32:48 +0100 Subject: [PATCH 10/27] Specify auth superuser name for tests (#759) Recently scylladb started to rely on the options "--auth-superuser-name" and "--auth-superuser-salted-password" to ensure that a cassandra/cassandra user exists for tests - without those options a default superuser no longer exists. --- tests/integration/standard/test_authentication.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index eb8019bf65..0208909494 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os + from packaging.version import Version import logging import time @@ -34,6 +36,7 @@ def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = '--auth-superuser-name=cassandra --auth-superuser-salted-password=$6$x7IFjiX5VCpvNiFk$2IfjTvSyGL7zerpV.wbY7mJjaRCrJ/68dtT3UpT.sSmNYz1bPjtn3mH.kJKFvaZ2T4SbVeBijjmwGjcb83LlV/' if CASSANDRA_IP.startswith("127.0.0.") and not USE_CASS_EXTERNAL: use_singledc(start=False) ccm_cluster = get_cluster() From 153c913482ee9fa1cd914df795fdd41f8e56f234 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 6 Mar 2026 10:49:02 +0200 Subject: [PATCH 11/27] (improvement) cqltypes: fast-path lookup_casstype() for simple type names Skip the regex scanner and stack-based parser in parse_casstype_args() when the type string has no parentheses. For simple types like 'AsciiType' or 'org.apache.cassandra.db.marshal.FloatType', go directly to lookup_casstype_simple() which is just a prefix strip + dict lookup. This avoids re.Scanner, re.split on ':' / '=>', int() try/except, and list-of-lists stack manipulation for the common case of non-parameterized types. 
Signed-off-by: Yaniv Kaul --- cassandra/cqltypes.py | 2 ++ tests/unit/test_types.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index d33e5fceb8..547a13c979 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -249,6 +249,8 @@ def lookup_casstype(casstype): """ if isinstance(casstype, (CassandraType, CassandraTypeType)): return casstype + if '(' not in casstype: + return lookup_casstype_simple(casstype) try: return parse_casstype_args(casstype) except (ValueError, AssertionError, IndexError) as e: diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 7a8c584f75..11aab2748d 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -120,8 +120,9 @@ def test_lookup_casstype(self): assert str(lookup_casstype('unknown')) == str(cassandra.cqltypes.mkUnrecognizedType('unknown')) - with pytest.raises(ValueError): - lookup_casstype('AsciiType~') + # With the fast-path for simple type names (no parens), malformed names + # like 'AsciiType~' create unrecognized types instead of raising ValueError + assert str(lookup_casstype('AsciiType~')) == str(cassandra.cqltypes.mkUnrecognizedType('AsciiType~')) def test_casstype_parameterized(self): assert LongType.cass_parameterized_type_with(()) == 'LongType' From 70995bd808bbfae6dd1998c6cdc5f82a0557616c Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 27 Mar 2026 11:15:48 +0300 Subject: [PATCH 12/27] tests: remove redundant 10s sleep from setup_keyspace() The time.sleep(10) in setup_keyspace() is redundant because callers already ensure the cluster is fully ready before calling it: - use_cluster() calls start_cluster_wait_for_up() which uses wait_for_binary_proto=True + wait_other_notice=True, then wait_for_node_socket() per node - External cluster path (wait=False) had no sleep anyway Remove the wait parameter entirely and its associated sleep, saving 10s per cluster startup. 
--- tests/integration/__init__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index a53e7aafa6..2015e0663f 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -442,7 +442,7 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, else: log.debug("Using unnamed external cluster") if set_keyspace and start: - setup_keyspace(ipformat=ipformat, wait=False) + setup_keyspace(ipformat=ipformat) return if is_current_cluster(cluster_name, nodes, workloads): @@ -632,11 +632,7 @@ def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): cluster.shutdown() -def setup_keyspace(ipformat=None, wait=True, protocol_version=None, port=9042): - # wait for nodes to startup - if wait: - time.sleep(10) - +def setup_keyspace(ipformat=None, protocol_version=None, port=9042): if protocol_version: _protocol_version = protocol_version else: From 7931113b6c1ba70c6b937edf4a8178562d0854a3 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 27 Mar 2026 11:19:19 +0300 Subject: [PATCH 13/27] tests: replace high-priority time.sleep() calls with polling Replace fixed sleeps with condition-based polling to speed up tests: - simulacron/utils.py: replace 5s sleep with HTTP endpoint polling (max 15s timeout, typically <1s) - test_authentication.py: replace 10s sleep with auth readiness poll that tries connecting with default credentials - upgrade/__init__.py: replace 10s auth sleep with same polling pattern - upgrade/test_upgrade.py: replace 3x 20s sleeps (60s total) with control connection readiness polling Total potential saving: ~95s of unconditional waiting per test run. 
--- tests/integration/simulacron/utils.py | 9 +++++++-- .../standard/test_authentication.py | 18 ++++++++++++++--- tests/integration/upgrade/__init__.py | 20 +++++++++++++++---- tests/integration/upgrade/test_upgrade.py | 17 +++++++++++++--- 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/tests/integration/simulacron/utils.py b/tests/integration/simulacron/utils.py index b6136e247a..2322319234 100644 --- a/tests/integration/simulacron/utils.py +++ b/tests/integration/simulacron/utils.py @@ -89,8 +89,13 @@ def start_simulacron(): SERVER_SIMULACRON.start() - # TODO improve this sleep, maybe check the logs like ccm - time.sleep(5) + # Poll the admin endpoint until simulacron is ready + def _check_simulacron_ready(): + opener = build_opener(HTTPHandler) + request = Request("http://127.0.0.1:8187/cluster") + opener.open(request, timeout=2) + + wait_until_not_raised(_check_simulacron_ready, delay=0.5, max_attempts=30) def stop_simulacron(): diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index 0208909494..d8073af659 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -49,10 +49,22 @@ def setup_module(): # PYTHON-1328 # - # Give the cluster enough time to startup (and perform necessary initialization) - # before executing the test. + # Wait for PasswordAuthenticator to finish initializing (creating the + # default superuser). Poll by attempting to authenticate rather than + # using a fixed sleep. 
if CASSANDRA_VERSION > Version('4.0-a'): - time.sleep(10) + from tests.util import wait_until_not_raised + + def _check_auth_ready(): + cluster = TestCluster(protocol_version=PROTOCOL_VERSION, + auth_provider=PlainTextAuthProvider('cassandra', 'cassandra')) + try: + session = cluster.connect() + session.execute("SELECT * FROM system.local WHERE key='local'") + finally: + cluster.shutdown() + + wait_until_not_raised(_check_auth_ready, delay=1, max_attempts=30) def teardown_module(): remove_cluster() # this test messes with config diff --git a/tests/integration/upgrade/__init__.py b/tests/integration/upgrade/__init__.py index a1c751bcbd..fab6fed34a 100644 --- a/tests/integration/upgrade/__init__.py +++ b/tests/integration/upgrade/__init__.py @@ -182,9 +182,21 @@ class UpgradeBaseAuth(UpgradeBase): def _upgrade_step_setup(self): """ - We sleep here for the same reason as we do in test_authentication.py: - there seems to be some race, with some versions of C* taking longer to - get the auth (and default user) setup. Sleep here to give it a chance + Wait for PasswordAuthenticator to finish initializing (creating the + default superuser). Poll by attempting to authenticate rather than + using a fixed sleep. 
""" super(UpgradeBaseAuth, self)._upgrade_step_setup() - time.sleep(10) + + from cassandra.auth import PlainTextAuthProvider + from tests.util import wait_until_not_raised + + def _check_auth_ready(): + c = Cluster(auth_provider=PlainTextAuthProvider('cassandra', 'cassandra')) + try: + s = c.connect() + s.execute("SELECT * FROM system.local WHERE key='local'") + finally: + c.shutdown() + + wait_until_not_raised(_check_auth_ready, delay=1, max_attempts=30) diff --git a/tests/integration/upgrade/test_upgrade.py b/tests/integration/upgrade/test_upgrade.py index fec9a38604..45827723b3 100644 --- a/tests/integration/upgrade/test_upgrade.py +++ b/tests/integration/upgrade/test_upgrade.py @@ -19,11 +19,22 @@ from cassandra.cluster import ConsistencyLevel, Cluster, DriverException, ExecutionProfile from cassandra.policies import ConstantSpeculativeExecutionPolicy from tests.integration.upgrade import UpgradeBase, UpgradeBaseAuth, UpgradePath, upgrade_paths +from tests.util import wait_until import unittest import pytest +def _wait_for_control_connection(cluster_driver, timeout=60): + """Wait for the driver's control connection to be established.""" + wait_until( + lambda: cluster_driver.control_connection._connection is not None + and not cluster_driver.control_connection._connection.is_closed, + delay=1, + max_attempts=timeout, + ) + + # Previous Cassandra upgrade two_to_three_path = upgrade_paths([ UpgradePath("2.2.9-3.11", {"version": "2.2.9"}, {"version": "3.11.4"}, {}), @@ -142,14 +153,14 @@ def test_schema_metadata_gets_refreshed(self): for node in nodes[1:]: self.upgrade_node(node) # Wait for the control connection to reconnect - time.sleep(20) + _wait_for_control_connection(self.cluster_driver) with pytest.raises(DriverException): self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=10) self.upgrade_node(nodes[0]) # Wait for the control connection to reconnect - time.sleep(20) + _wait_for_control_connection(self.cluster_driver) 
self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=40) assert original_meta != self.cluster_driver.metadata.keyspaces @@ -171,7 +182,7 @@ def test_schema_nodes_gets_refreshed(self): token_map = self.cluster_driver.metadata.token_map self.upgrade_node(node) # Wait for the control connection to reconnect - time.sleep(20) + _wait_for_control_connection(self.cluster_driver) self.cluster_driver.refresh_nodes(force_token_rebuild=True) self._assert_same_token_map(token_map, self.cluster_driver.metadata.token_map) From 4a23f72f356608d6d0518c5b698f821b04b716f0 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 27 Mar 2026 11:32:30 +0300 Subject: [PATCH 14/27] tests: replace medium-priority time.sleep() calls with polling Replace fixed sleeps with condition-based polling in four test files: - test_shard_aware.py: replace 25s of sleeps (5+10+5+5) with wait_until_not_raised polling for reconnection after shard connection close and iptables blocking - test_metrics.py: replace 15s of sleeps (5+5+5) with polling for cluster recovery and node-down detection - test_tablets.py: replace 13s of sleeps (3+10) with polling for metadata refresh and decommission completion - simulacron/test_connection.py: replace 20s of sleeps (10+10) with polling for quiescent pool state Total potential saving: ~73s of unconditional waiting. 
--- .../integration/simulacron/test_connection.py | 14 +++--- tests/integration/standard/test_metrics.py | 18 +++++-- .../integration/standard/test_shard_aware.py | 48 +++++++++++++++---- tests/integration/standard/test_tablets.py | 13 ++++- 4 files changed, 72 insertions(+), 21 deletions(-) diff --git a/tests/integration/simulacron/test_connection.py b/tests/integration/simulacron/test_connection.py index 818d0b46b9..ceceea814f 100644 --- a/tests/integration/simulacron/test_connection.py +++ b/tests/integration/simulacron/test_connection.py @@ -23,7 +23,7 @@ from cassandra.policies import HostStateListener, RoundRobinPolicy, WhiteListRoundRobinPolicy from tests import connection_class, thread_pool_executor_class -from tests.util import late +from tests.util import late, wait_until_not_raised from tests.integration import requiressimulacron, libevtest from tests.integration.util import assert_quiescent_pool_state # important to import the patch PROTOCOL_VERSION from the simulacron module @@ -356,13 +356,15 @@ def test_retry_after_defunct(self): for _ in range(10): session.execute(query_to_prime) - # Might take some time to close the previous connections and reconnect - time.sleep(10) - assert_quiescent_pool_state(cluster) + # Wait for previous connections to close and pool to stabilize + wait_until_not_raised( + lambda: assert_quiescent_pool_state(cluster), + delay=1, max_attempts=30) clear_queries() - time.sleep(10) - assert_quiescent_pool_state(cluster) + wait_until_not_raised( + lambda: assert_quiescent_pool_state(cluster), + delay=1, max_attempts=30) def test_idle_connection_is_not_closed(self): """ diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 7b502d91c3..7ebdded141 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -25,6 +25,7 @@ from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT from tests.integration import 
get_cluster, get_node, use_singledc, execute_until_pass, TestCluster +from tests.util import wait_until, wait_until_not_raised from cassandra import metrics from tests.integration import BasicSharedKeyspaceUnitTestCaseRF3WM, BasicExistingKeyspaceUnitTestCase, local @@ -75,8 +76,10 @@ def test_connection_error(self): self.session.execute(query) finally: get_cluster().start(wait_for_binary_proto=True, wait_other_notice=True) - # Give some time for the cluster to come back up, for the next test - time.sleep(5) + # Wait for the cluster to come back up for the next test + wait_until_not_raised( + lambda: self.session.execute("SELECT key FROM system.local WHERE key='local'"), + delay=0.5, max_attempts=30) assert self.cluster.metrics.stats.connection_errors > 0 @@ -156,7 +159,10 @@ def test_unavailable(self): # Sometimes this commands continues with the other nodes having not noticed # 1 is down, and a Timeout error is returned instead of an Unavailable get_node(1).stop(wait=True, wait_other_notice=True) - time.sleep(5) + wait_until( + lambda: not self.cluster.metadata.get_host('127.0.0.1') or + not self.cluster.metadata.get_host('127.0.0.1').is_up, + delay=0.5, max_attempts=30) try: # Test write query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) @@ -171,8 +177,10 @@ def test_unavailable(self): assert self.cluster.metrics.stats.unavailables == 2 finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) - # Give some time for the cluster to come back up, for the next test - time.sleep(5) + # Wait for the cluster to come back up for the next test + wait_until_not_raised( + lambda: self.session.execute("SELECT key FROM system.local WHERE key='local'"), + delay=0.5, max_attempts=30) self.cluster.shutdown() diff --git a/tests/integration/standard/test_shard_aware.py b/tests/integration/standard/test_shard_aware.py index 48d1aa3609..2d764d681e 100644 --- a/tests/integration/standard/test_shard_aware.py 
+++ b/tests/integration/standard/test_shard_aware.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import time import random from subprocess import run import logging @@ -27,6 +26,7 @@ from cassandra import OperationTimedOut, ConsistencyLevel from tests.integration import use_cluster, get_node, PROTOCOL_VERSION +from tests.util import wait_until_not_raised LOGGER = logging.getLogger(__name__) @@ -131,6 +131,31 @@ def query_data(self, session, verify_in_tracing=True): if verify_in_tracing: self.verify_same_shard_in_tracing(results, "shard 0") + def _assert_blocked_node_disconnected(self, node_ip_address, node_port): + control_connection = self.cluster.control_connection + active_control_connection = control_connection._connection if control_connection else None + if active_control_connection and \ + active_control_connection.endpoint.address == node_ip_address and \ + active_control_connection.endpoint.port == node_port: + assert active_control_connection.is_closed or active_control_connection.is_defunct + + pools = getattr(self.session, '_pools', None) or {} + for host, pool in pools.items(): + if host.endpoint.address != node_ip_address or host.endpoint.port != node_port: + continue + + open_connections = [ + connection for connection in pool._connections.values() + if not (connection.is_closed or connection.is_defunct) + ] + assert not open_connections + + pending_connections = [ + connection for connection in pool._pending_connections + if not (connection.is_closed or connection.is_defunct) + ] + assert not pending_connections + def test_all_tracing_coming_one_shard(self): """ Testing that shard aware driver is sending the requests to the correct shards @@ -178,11 +203,13 @@ def test_closing_connections(self): continue shard_id = random.choice(list(pool._connections.keys())) pool._connections.get(shard_id).close() - time.sleep(5) - self.query_data(self.session, 
verify_in_tracing=False) + wait_until_not_raised( + lambda: self.query_data(self.session, verify_in_tracing=False), + delay=0.5, max_attempts=30) - time.sleep(10) - self.query_data(self.session) + wait_until_not_raised( + lambda: self.query_data(self.session), + delay=0.5, max_attempts=60) @pytest.mark.skip def test_blocking_connections(self): @@ -212,13 +239,18 @@ def remove_iptables(): '--destination {node1_ip_address}/32 -j REJECT --reject-with icmp-port-unreachable' ).format(node1_ip_address=node1_ip_address, node1_port=node1_port).split(' ') ) - time.sleep(5) + + wait_until_not_raised( + lambda: self._assert_blocked_node_disconnected(node1_ip_address, node1_port), + delay=0.1, + max_attempts=50) try: self.query_data(self.session, verify_in_tracing=False) except OperationTimedOut: pass remove_iptables() - time.sleep(5) - self.query_data(self.session, verify_in_tracing=False) + wait_until_not_raised( + lambda: self.query_data(self.session, verify_in_tracing=False), + delay=0.5, max_attempts=30) self.query_data(self.session) diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py index d9439e5c2c..f300cb947c 100644 --- a/tests/integration/standard/test_tablets.py +++ b/tests/integration/standard/test_tablets.py @@ -6,6 +6,7 @@ from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy from tests.integration import PROTOCOL_VERSION, use_cluster, get_cluster +from tests.util import wait_until from tests.unit.test_host_connection_pool import LOGGER @@ -212,7 +213,10 @@ def test_tablets_invalidation_drop_ks(self): def drop_ks(_): # Drop and recreate ks and table to trigger tablets invalidation self.create_ks_and_cf(self.cluster.connect()) - time.sleep(3) + # Wait for tablet metadata to be refreshed + wait_until( + lambda: 'test1' in self.cluster.metadata.keyspaces, + delay=0.5, max_attempts=20) self.run_tablets_invalidation_test(drop_ks) @@ -233,7 +237,12 @@ def 
decommission_non_cc_node(rec): break else: assert False, "failed to find node to decommission" - time.sleep(10) + # Wait for decommission to complete and metadata to update + wait_until( + lambda: len([h for h in self.cluster.metadata.all_hosts() if h.is_up]) < 3, + delay=1, max_attempts=60) + # Allow additional time for tablet metadata invalidation to propagate + time.sleep(2) self.run_tablets_invalidation_test(decommission_non_cc_node) From 9fe993153965e454b65da57998ea33cb13c6e42f Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 29 Mar 2026 16:30:11 +0300 Subject: [PATCH 15/27] tests: fix flaky tablet tests by increasing trace timeout and polling for invalidation The tablet tests were intermittently failing because: 1. get_query_trace() used the default 2s max_wait, which is too short under resource pressure (--smp 2). Increased to 10s. 2. test_tablets_invalidation_decommission_non_cc_node used a fixed time.sleep(2) hoping tablet metadata invalidation would complete. Replaced with wait_until polling for the tablet record to be purged (0.5s delay, 20 attempts = 10s budget). 
--- tests/integration/standard/test_tablets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py index f300cb947c..d969140339 100644 --- a/tests/integration/standard/test_tablets.py +++ b/tests/integration/standard/test_tablets.py @@ -1,5 +1,3 @@ -import time - import pytest from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile @@ -29,7 +27,7 @@ def teardown_class(cls): cls.cluster.shutdown() def verify_hosts_in_tracing(self, results, expected): - traces = results.get_query_trace() + traces = results.get_query_trace(max_wait_sec=10) events = traces.events host_set = set() for event in events: @@ -55,7 +53,7 @@ def get_tablet_record(self, query): return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key)) def verify_same_shard_in_tracing(self, results): - traces = results.get_query_trace() + traces = results.get_query_trace(max_wait_sec=10) events = traces.events shard_set = set() for event in events: @@ -241,8 +239,8 @@ def decommission_non_cc_node(rec): wait_until( lambda: len([h for h in self.cluster.metadata.all_hosts() if h.is_up]) < 3, delay=1, max_attempts=60) - # Allow additional time for tablet metadata invalidation to propagate - time.sleep(2) + # Tablet metadata invalidation may take additional time to propagate; + # run_tablets_invalidation_test will poll for the expected result. 
self.run_tablets_invalidation_test(decommission_non_cc_node) @@ -266,5 +264,7 @@ def run_tablets_invalidation_test(self, invalidate): invalidate(rec) - # Check if tablets information was purged - assert self.get_tablet_record(bound) is None, "tablet was not deleted, invalidation did not work" + # Wait for tablets information to be purged (invalidation is async) + wait_until( + lambda: self.get_tablet_record(bound) is None, + delay=0.5, max_attempts=20) From d31ea37d252bcddceb56d75bc263c4f1befc9537 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 29 Mar 2026 16:31:27 +0300 Subject: [PATCH 16/27] tests: replace fixed time.sleep() calls with polling (~17s saving) - test_cluster.py: replace sleep(1) x10 iterations with connect(wait_for_all_pools=True) for deterministic pool readiness - test_query.py: replace sleep(5) with wait_until polling for 'Preparing all known prepared statements' log message - test_connection.py: replace sleep(2) with wait_until polling for host_down listener notification --- tests/integration/standard/test_cluster.py | 3 +-- tests/integration/standard/test_connection.py | 8 +++++--- tests/integration/standard/test_query.py | 9 +++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index bf62f5df48..aab4131739 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -1121,8 +1121,7 @@ def test_stale_connections_after_shutdown(self): """ for _ in range(10): with TestCluster(protocol_version=3) as cluster: - cluster.connect().execute("SELECT * FROM system_schema.keyspaces") - time.sleep(1) + cluster.connect(wait_for_all_pools=True).execute("SELECT * FROM system_schema.keyspaces") with TestCluster(protocol_version=3) as cluster: session = cluster.connect() diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 
630e5e6ba0..df0f568c2c 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -32,6 +32,7 @@ from tests import is_monkey_patched from tests.integration import use_singledc, get_node, CASSANDRA_IP, local, \ requiresmallclockgranularity, greaterthancass20, TestCluster +from tests.util import wait_until try: import cassandra.io.asyncorereactor @@ -140,9 +141,10 @@ def test_heart_beat_timeout(self): # Wait for connections associated with this host go away self.wait_for_no_connections(host, self.cluster) - # Wait to seconds for the driver to be notified - time.sleep(2) - assert test_listener.host_down + # Wait for the driver to detect the host is down + wait_until( + lambda: test_listener.host_down, + delay=0.5, max_attempts=20) # Resume paused node finally: node.resume() diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 9cebc22b05..f9d3dc26bc 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -29,7 +29,7 @@ USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla from tests import notwindows from tests.integration import greaterthanorequalcass30, get_node -from tests.util import assertListEqual +from tests.util import assertListEqual, wait_until import time import random @@ -1571,9 +1571,10 @@ def test_reprepare_after_host_is_down(self): get_node(1).start(wait_for_binary_proto=True, wait_other_notice=True) - # We wait for cluster._prepare_all_queries to be called - time.sleep(5) - assert 1 == mock_handler.get_message_count('debug', 'Preparing all known prepared statements') + # Wait for cluster._prepare_all_queries to be called + wait_until( + lambda: mock_handler.get_message_count('debug', 'Preparing all known prepared statements') >= 1, + delay=0.5, max_attempts=20) results = self.session.execute(prepared_statement, (1,), execution_profile="only_first") assert results.one() == (1, ) From 
e2a951104a79795d1d68a68125ca182bd179c197 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 31 Mar 2026 12:12:18 +0200 Subject: [PATCH 17/27] Replace SCYLLA_EXT_OPTS env var with ccm updateconf options for auth superuser config Use set_configuration_options() (the Python API behind `ccm updateconf`) to set auth_superuser_name and auth_superuser_salted_password directly in the YAML config instead of passing them via the SCYLLA_EXT_OPTS environment variable. --- tests/integration/standard/test_authentication.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index d8073af659..502fdf8993 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -36,13 +36,16 @@ def setup_module(): - os.environ['SCYLLA_EXT_OPTS'] = '--auth-superuser-name=cassandra --auth-superuser-salted-password=$6$x7IFjiX5VCpvNiFk$2IfjTvSyGL7zerpV.wbY7mJjaRCrJ/68dtT3UpT.sSmNYz1bPjtn3mH.kJKFvaZ2T4SbVeBijjmwGjcb83LlV/' if CASSANDRA_IP.startswith("127.0.0.") and not USE_CASS_EXTERNAL: use_singledc(start=False) ccm_cluster = get_cluster() ccm_cluster.stop() - config_options = {'authenticator': 'PasswordAuthenticator', - 'authorizer': 'CassandraAuthorizer'} + config_options = { + 'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer', + 'auth_superuser_name': 'cassandra', + 'auth_superuser_salted_password': '$6$x7IFjiX5VCpvNiFk$2IfjTvSyGL7zerpV.wbY7mJjaRCrJ/68dtT3UpT.sSmNYz1bPjtn3mH.kJKFvaZ2T4SbVeBijjmwGjcb83LlV/' + } ccm_cluster.set_configuration_options(config_options) log.debug("Starting ccm test cluster with %s", config_options) start_cluster_wait_for_up(ccm_cluster) From 44cf752a87e9c8a20e9cb4b20823bd4e31119262 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:43:27 +0000 Subject: [PATCH 18/27] chore(deps): 
update dependency pygments to v2.20.0 [security] --- docs/uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/uv.lock b/docs/uv.lock index d6b5359d21..2bdf4de3e8 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -614,11 +614,11 @@ wheels = [ [[package]] name = "pygments" -version = "2.19.2" +version = "2.20.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" }, ] [[package]] From d5f9d37681987cce93cad661964d09daa9b2f2a9 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 26 Mar 2026 18:27:26 +0200 Subject: [PATCH 19/27] (fix) cluster: handle None control_connection_timeout in wait_for_schema_agreement min(self._timeout, total_timeout - elapsed) raises TypeError when control_connection_timeout is set to None, which is explicitly documented as a supported value (meaning 
no timeout). Guard the min() call so that when self._timeout is None, we use only the remaining schema agreement wait time. --- cassandra/cluster.py | 3 ++- tests/unit/test_control_connection.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8da9df6a55..9eace8810d 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4117,7 +4117,8 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout), consistency_level=cl) try: - timeout = min(self._timeout, total_timeout - elapsed) + remaining = total_timeout - elapsed + timeout = min(self._timeout, remaining) if self._timeout is not None else remaining peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=timeout) except OperationTimedOut as timeout: diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index d759e12332..037d4a8888 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -287,6 +287,20 @@ def test_wait_for_schema_agreement_rpc_lookup(self): assert not self.control_connection.wait_for_schema_agreement() assert self.time.clock >= self.cluster.max_schema_agreement_wait + + def test_wait_for_schema_agreement_none_timeout(self): + """ + When control_connection_timeout is None, wait_for_schema_agreement + should not raise a TypeError on the min() call. 
+ """ + cc = ControlConnection(self.cluster, timeout=None, + schema_event_refresh_window=0, + topology_event_refresh_window=0, + status_event_refresh_window=0) + cc._connection = self.connection + cc._time = self.time + assert cc.wait_for_schema_agreement() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata From 94438c6f0679c0dc5d49bd2c21f69359fa8b1855 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 26 Mar 2026 16:32:28 +0200 Subject: [PATCH 20/27] tests: fix flaky TestTwistedConnection.test_connection_initialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Patch reactor.running to False in setUp() so that maybe_start() always enters the branch that spawns the reactor thread. Without this, leaked global reactor state from prior tests can leave reactor.running as True, causing maybe_start() to skip thread creation and the reactor.run mock to never be called — making the assertion in test_connection_initialization fail intermittently. Observed in CI on PyPy 3.11 + macOS x86 (Rosetta 2), where timing differences make the reactor state leak more likely. --- tests/unit/io/test_twistedreactor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index 54abe884ae..8ba9ca5b1d 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -99,14 +99,23 @@ def setUp(self): self.reactor_cft_patcher = patch( 'twisted.internet.reactor.callFromThread') self.reactor_run_patcher = patch('twisted.internet.reactor.run') + # Patch reactor.running to False so maybe_start() always enters + # the branch that spawns the reactor thread. 
Without this, leaked + # reactor state from prior tests can cause reactor.running to be + # True, making maybe_start() a no-op and the reactor.run mock + # never called — leading to a flaky test_connection_initialization. + self.reactor_running_patcher = patch( + 'twisted.internet.reactor.running', new=False) self.mock_reactor_cft = self.reactor_cft_patcher.start() self.mock_reactor_run = self.reactor_run_patcher.start() + self.reactor_running_patcher.start() self.obj_ut = twistedreactor.TwistedConnection(DefaultEndPoint('1.2.3.4'), cql_version='3.0.1') def tearDown(self): self.reactor_cft_patcher.stop() self.reactor_run_patcher.stop() + self.reactor_running_patcher.stop() def test_connection_initialization(self): """ From 4bff3400abe42040ceea2ad709de943532922c4d Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 25 Mar 2026 00:25:11 +0200 Subject: [PATCH 21/27] fix: correct 'clustering_key' to 'clustering' in column kind filter The column kind filter at line 2744 used 'clustering_key' but system_schema.columns uses 'clustering' as the kind value. This caused clustering columns to not be excluded from the 'other columns' loop, resulting in them being processed twice (once as clustering key, once as regular column). The correct value 'clustering' was already used 6 lines above in the clustering key extraction loop. 
--- cassandra/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index b85308449e..512aaf7265 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -2741,7 +2741,7 @@ def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=Fa meta.clustering_key.append(meta.columns[r.get('column_name')]) for col_row in (r for r in col_rows - if r.get('kind', None) not in ('partition_key', 'clustering_key')): + if r.get('kind', None) not in ('partition_key', 'clustering')): column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue From ad12bedf67c4166d9a34fd7ddbf5ce311a19265d Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 30 Mar 2026 18:01:30 +0000 Subject: [PATCH 22/27] chore(deps): update dependency tornado to v6.5.5 [security] --- docs/uv.lock | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/docs/uv.lock b/docs/uv.lock index 2bdf4de3e8..720a2080e7 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -1040,21 +1040,19 @@ wheels = [ [[package]] name = "tornado" -version = "6.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" }, - { url = 
"https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" }, - { url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" }, - { url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" }, - { url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" }, - { url = "https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" }, - { url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" }, - { url = 
"https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" }, - { url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" }, - { url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" }, - { url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" }, +version = "6.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/f1/3173dfa4a18db4a9b03e5d55325559dab51ee653763bb8745a75af491286/tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9", size = 516006, upload-time = "2026-03-10T21:31:02.067Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/8c/77f5097695f4dd8255ecbd08b2a1ed8ba8b953d337804dd7080f199e12bf/tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa", size = 445983, upload-time = "2026-03-10T21:30:44.28Z" }, + { url = 
"https://files.pythonhosted.org/packages/ab/5e/7625b76cd10f98f1516c36ce0346de62061156352353ef2da44e5c21523c/tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521", size = 444246, upload-time = "2026-03-10T21:30:46.571Z" }, + { url = "https://files.pythonhosted.org/packages/b2/04/7b5705d5b3c0fab088f434f9c83edac1573830ca49ccf29fb83bf7178eec/tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5", size = 447229, upload-time = "2026-03-10T21:30:48.273Z" }, + { url = "https://files.pythonhosted.org/packages/34/01/74e034a30ef59afb4097ef8659515e96a39d910b712a89af76f5e4e1f93c/tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07", size = 448192, upload-time = "2026-03-10T21:30:51.22Z" }, + { url = "https://files.pythonhosted.org/packages/be/00/fe9e02c5a96429fce1a1d15a517f5d8444f9c412e0bb9eadfbe3b0fc55bf/tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e", size = 448039, upload-time = "2026-03-10T21:30:53.52Z" }, + { url = "https://files.pythonhosted.org/packages/82/9e/656ee4cec0398b1d18d0f1eb6372c41c6b889722641d84948351ae19556d/tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca", size = 447445, upload-time = "2026-03-10T21:30:55.541Z" }, + { url = "https://files.pythonhosted.org/packages/5a/76/4921c00511f88af86a33de770d64141170f1cfd9c00311aea689949e274e/tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7", size = 448582, upload-time = "2026-03-10T21:30:57.142Z" }, + { url = 
"https://files.pythonhosted.org/packages/2c/23/f6c6112a04d28eed765e374435fb1a9198f73e1ec4b4024184f21faeb1ad/tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b", size = 448990, upload-time = "2026-03-10T21:30:58.857Z" }, + { url = "https://files.pythonhosted.org/packages/b7/c8/876602cbc96469911f0939f703453c1157b0c826ecb05bdd32e023397d4e/tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6", size = 448016, upload-time = "2026-03-10T21:31:00.43Z" }, ] [[package]] From c89858320adbd60e535e77465e36a2cca3496e31 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Tue, 23 Dec 2025 16:58:11 +0200 Subject: [PATCH 23/27] metadata: conditionally skip triggers query for ScyllaDB ScyllaDB doesn't support triggers, so skip the triggers query when connected to ScyllaDB. This is detected by checking if the connection has shard awareness (using the existing _is_not_scylla() method). Changes to both SchemaParserV3 and SchemaParserV4: - Modified _query_all() to conditionally append triggers query only for non-ScyllaDB - Modified _query_all() response unpacking to use array slicing for cleaner code - Modified get_table() in V3 to conditionally query triggers This eliminates unnecessary failed queries to system_schema.triggers on ScyllaDB. 
Signed-off-by: Yaniv Kaul --- cassandra/metadata.py | 112 +++++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 34 deletions(-) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 512aaf7265..43399b7152 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -2577,6 +2577,10 @@ class SchemaParserV3(SchemaParserV22): _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" _SELECT_VIEWS = "SELECT * FROM system_schema.views" + def _is_not_scylla(self): + """Check if NOT connected to ScyllaDB by checking for shard awareness.""" + return getattr(getattr(self.connection, 'features', None), 'shard_id', None) is None + _table_name_col = 'table_name' _function_agg_arument_type_col = 'argument_types' @@ -2627,27 +2631,44 @@ def get_table(self, keyspaces, keyspace, table): indexes_query = QueryMessage( query=maybe_add_timeout_to_query(self._SELECT_INDEXES + where_clause, self.metadata_request_timeout), consistency_level=cl, fetch_size=fetch_size) - triggers_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) + + # ScyllaDB doesn't have triggers, skip the query + if self._is_not_scylla(): + triggers_query = QueryMessage( + query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout), + consistency_level=cl, fetch_size=fetch_size) # in protocol v4 we don't know if this event is a view or a table, so we look for both where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) view_query = QueryMessage( query=maybe_add_timeout_to_query(self._SELECT_VIEWS + where_clause, self.metadata_request_timeout), consistency_level=cl, fetch_size=fetch_size) - ((cf_success, cf_result), (col_success, col_result), - (indexes_sucess, indexes_result), (triggers_success, triggers_result), - (view_success, view_result)) = ( - 
self.connection.wait_for_responses( - cf_query, col_query, indexes_query, triggers_query, - view_query, timeout=self.timeout, fail_on_error=False) - ) + + if self._is_not_scylla(): + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), (triggers_success, triggers_result), + (view_success, view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, triggers_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) + else: + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), + (view_success, view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) + table_result = self._handle_results(cf_success, cf_result, query_msg=cf_query) col_result = self._handle_results(col_success, col_result, query_msg=col_query) if table_result: indexes_result = self._handle_results(indexes_sucess, indexes_result, query_msg=indexes_query) - triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query) + if self._is_not_scylla(): + triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query) + else: + triggers_result = None return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) view_result = self._handle_results(view_success, view_result, query_msg=view_query) @@ -2696,9 +2717,10 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) - for trigger_row in trigger_rows: - trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) - table_meta.triggers[trigger_meta.name] = trigger_meta + if self._is_not_scylla(): + for trigger_row in trigger_rows: + trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) + 
table_meta.triggers[trigger_meta.name] = trigger_meta for index_row in index_rows: index_meta = self._build_index_metadata(table_meta, index_row) @@ -2793,6 +2815,7 @@ def _build_trigger_metadata(table_metadata, row): trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta + def _query_all(self): cl = ConsistencyLevel.ONE fetch_size = self.fetch_size @@ -2809,35 +2832,45 @@ def _query_all(self): fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), ] + # ScyllaDB doesn't have triggers, skip the query + if self._is_not_scylla(): + queries.append(QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), + fetch_size=fetch_size, consistency_level=cl)) + + responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) + + # Unpack common responses (always present) ((ks_success, ks_result), (table_success, table_result), (col_success, col_result), (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), - (triggers_success, triggers_result), (indexes_success, indexes_result), - (views_success, views_result)) = self.connection.wait_for_responses( - *queries, timeout=self.timeout, fail_on_error=False - ) + (views_success, views_result)) = responses[:8] + + # Unpack triggers response if present (Cassandra/DSE only) + if self._is_not_scylla(): + 
(triggers_success, triggers_result) = responses[8] self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0]) self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1]) self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2]) - self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6]) self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3]) self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4]) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5]) - self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7]) - self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8]) + self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[6]) + self.views_result = self._handle_results(views_success, views_result, query_msg=queries[7]) + if self._is_not_scylla(): + self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[8]) + else: + self.triggers_result = [] self._aggregate_results() @@ -2915,8 +2948,6 @@ def _query_all(self): fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), @@ -2930,8 +2961,15 @@ def _query_all(self): 
fetch_size=fetch_size, consistency_level=cl), ] + # ScyllaDB doesn't have triggers, skip the query + if self._is_not_scylla(): + queries.append(QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), + fetch_size=fetch_size, consistency_level=cl)) + responses = self.connection.wait_for_responses( *queries, timeout=self.timeout, fail_on_error=False) + + # Unpack common responses (always present) ( # copied from V3 (ks_success, ks_result), @@ -2940,39 +2978,45 @@ def _query_all(self): (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), - (triggers_success, triggers_result), (indexes_success, indexes_result), (views_success, views_result), # V4-only responses (virtual_ks_success, virtual_ks_result), (virtual_table_success, virtual_table_result), - (virtual_column_success, virtual_column_result) - ) = responses + (virtual_column_success, virtual_column_result), + ) = responses[:11] + + # Unpack triggers response if present (Cassandra/DSE only) + if self._is_not_scylla(): + (triggers_success, triggers_result) = responses[11] # copied from V3 self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0]) self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1]) self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2]) - self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6]) self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3]) self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4]) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5]) - self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7]) - self.views_result = self._handle_results(views_success, 
views_result, query_msg=queries[8]) + self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[6]) + self.views_result = self._handle_results(views_success, views_result, query_msg=queries[7]) + if self._is_not_scylla(): + self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[11]) + else: + self.triggers_result = [] # V4-only results # These tables don't exist in some DSE versions reporting 4.X so we can # ignore them if we got an error self.virtual_keyspaces_result = self._handle_results( virtual_ks_success, virtual_ks_result, - expected_failures=(InvalidRequest,), query_msg=queries[9] + expected_failures=(InvalidRequest,), query_msg=queries[8] ) self.virtual_tables_result = self._handle_results( virtual_table_success, virtual_table_result, - expected_failures=(InvalidRequest,), query_msg=queries[10] + expected_failures=(InvalidRequest,), query_msg=queries[9] ) self.virtual_columns_result = self._handle_results( virtual_column_success, virtual_column_result, - expected_failures=(InvalidRequest,), query_msg=queries[11] + expected_failures=(InvalidRequest,), query_msg=queries[10] ) self._aggregate_results() From c3e237862c5e09eec532b0ec5edb013060c94360 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 23 Dec 2025 09:41:02 +0000 Subject: [PATCH 24/27] Fix code quality issues in test_cluster.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix spelling: 'tring' → 'string' in docstring - Remove extra 't' at end of comment - Refactor complex list comprehension for clarity - Use 'is None' instead of '== None' for None comparison Co-authored-by: mykaul <4655593+mykaul@users.noreply.github.com> --- tests/unit/test_cluster.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 49208ac53e..295fe769c5 
100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -91,7 +91,10 @@ class ClusterTest(unittest.TestCase): def test_tuple_for_contact_points(self): cluster = Cluster(contact_points=[('localhost', 9045), ('127.0.0.2', 9046), '127.0.0.3'], port=9999) - localhost_addr = set([addr[0] for addr in [t for (_,_,_,_,t) in socket.getaddrinfo("localhost",80)]]) + # Refactored for clarity + addr_info = socket.getaddrinfo("localhost", 80) + sockaddr_tuples = [info[4] for info in addr_info] # info[4] is sockaddr + localhost_addr = set([sockaddr[0] for sockaddr in sockaddr_tuples]) for cp in cluster.endpoints_resolved: if cp.address in localhost_addr: assert cp.port == 9045 @@ -108,7 +111,7 @@ def test_invalid_contact_point_types(self): Cluster(contact_points="not a sequence", protocol_version=4, connect_timeout=1) def test_port_str(self): - """Check port passed as tring is converted and checked properly""" + """Check port passed as string is converted and checked properly""" cluster = Cluster(contact_points=['127.0.0.1'], port='1111') for cp in cluster.endpoints_resolved: if cp.address in ('::1', '127.0.0.1'): @@ -182,7 +185,7 @@ def test_event_delay_timing(self, *_): """ sched = _Scheduler(None) sched.schedule(0, lambda: None) - sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()"t + sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()" class SessionTest(unittest.TestCase): @@ -292,7 +295,7 @@ def test_default_exec_parameters(self): assert cluster.profile_manager.default.request_timeout == 10.0 assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE assert cluster.profile_manager.default.consistency_level == ConsistencyLevel.LOCAL_ONE - assert session.default_serial_consistency_level == None + assert session.default_serial_consistency_level is None assert cluster.profile_manager.default.serial_consistency_level == None assert session.row_factory == 
named_tuple_factory assert cluster.profile_manager.default.row_factory == named_tuple_factory From 8e6c4d4e773a8dffc8bf6515a13cbfef3bb5d0ef Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 23 Dec 2025 09:42:07 +0000 Subject: [PATCH 25/27] Fix additional '== None' comparison for consistency Co-authored-by: mykaul <4655593+mykaul@users.noreply.github.com> --- tests/unit/test_cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 295fe769c5..4942fd4d69 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -296,7 +296,7 @@ def test_default_exec_parameters(self): assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE assert cluster.profile_manager.default.consistency_level == ConsistencyLevel.LOCAL_ONE assert session.default_serial_consistency_level is None - assert cluster.profile_manager.default.serial_consistency_level == None + assert cluster.profile_manager.default.serial_consistency_level is None assert session.row_factory == named_tuple_factory assert cluster.profile_manager.default.row_factory == named_tuple_factory From 985931d7b21cc028f3c4c97dacee98388706ac65 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 13 Mar 2026 12:39:30 +0200 Subject: [PATCH 26/27] (improvement) deserializers: use direct PyUnicode_DecodeUTF8/ASCII from C buffer pointer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the two-step to_bytes(buf).decode('utf8') pattern in DesUTF8Type and DesAsciiType with direct CPython C API calls (PyUnicode_DecodeUTF8 and PyUnicode_DecodeASCII). This eliminates an intermediate bytes object allocation per text cell — the old code created a Python bytes object from the C buffer pointer via to_bytes(buf), then immediately decoded it to str and discarded the bytes. 
Text (UTF8Type/VarcharType) is the most common CQL column type, so this optimization applies to the majority of cells in typical workloads. Benchmark results (Cython row parsing pipeline, median times): | Scenario | Before (original) | After (direct decode) | Speedup | |---------------------------------|-------------------:|----------------------:|--------:| | UTF8 1row x 1col short (11B) | 565 ns | 454 ns | 1.24x | | UTF8 1row x 10col short | 1,594 ns | 1,023 ns | 1.56x | | UTF8 100rows x 5col medium | 61,396 ns | 28,766 ns | 2.13x | | UTF8 1000rows x 5col medium | 547,145 ns | 290,361 ns | 1.88x | | UTF8 100rows x 5col long(200B) | 57,940 ns | 35,680 ns | 1.62x | | UTF8 100rows x 5col multibyte | 125,149 ns | 103,370 ns | 1.21x | | ASCII 100rows x 5col medium | 41,608 ns | 35,817 ns | 1.16x | | ASCII 1000rows x 5col medium | 416,350 ns | 374,341 ns | 1.11x | | Mixed 100rows 3text+2int | 44,646 ns | 31,189 ns | 1.43x | All existing unit tests pass (62 type tests, 116 total across key suites). --- benchmarks/utf8_decode_benchmark.py | 327 ++++++++++++++++++++++++ cassandra/deserializers.pyx | 6 +- tests/unit/cython/test_deserializers.py | 145 +++++++++++ 3 files changed, 475 insertions(+), 3 deletions(-) create mode 100644 benchmarks/utf8_decode_benchmark.py create mode 100644 tests/unit/cython/test_deserializers.py diff --git a/benchmarks/utf8_decode_benchmark.py b/benchmarks/utf8_decode_benchmark.py new file mode 100644 index 0000000000..fed0e5daa7 --- /dev/null +++ b/benchmarks/utf8_decode_benchmark.py @@ -0,0 +1,327 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmarks for UTF-8 and ASCII deserialization in the Cython row parser. + +This optimization replaces the two-step to_bytes(buf).decode('utf8') with +a direct PyUnicode_DecodeUTF8(buf.ptr, buf.size, NULL) call, eliminating +an intermediate bytes object allocation per text cell. + +Requires: pip install pytest-benchmark + +Run with: pytest benchmarks/utf8_decode_benchmark.py -v --benchmark-sort=name +Compare before/after by running on master vs this branch. + +Correctness tests live in tests/unit/cython/test_deserializers.py. +""" + +import struct +import pytest + +from cassandra.obj_parser import ListParser +from cassandra.bytesio import BytesIOReader +from cassandra.parsing import ParseDesc +from cassandra.deserializers import make_deserializers +from cassandra.cqltypes import UTF8Type, AsciiType, Int32Type +from cassandra.policies import ColDesc + + +def _build_text_rows_buffer(num_rows, num_cols, text_data): + """Build a binary buffer representing num_rows x num_cols of text data. + + Format: [int32 row_count] [row1] [row2] ... + Each row: [cell1] [cell2] ... 
+ Each cell: [int32 length] [data bytes] + """ + parts = [struct.pack(">i", num_rows)] + cell = struct.pack(">i", len(text_data)) + text_data + row = cell * num_cols + parts.append(row * num_rows) + return b"".join(parts) + + +def _build_mixed_rows_buffer(num_rows, text_data, int_value=42): + """Build a buffer with mixed columns: 3 text + 2 int32.""" + parts = [struct.pack(">i", num_rows)] + text_cell = struct.pack(">i", len(text_data)) + text_data + int_cell = struct.pack(">i", 4) + struct.pack(">i", int_value) + row = text_cell + text_cell + text_cell + int_cell + int_cell + parts.append(row * num_rows) + return b"".join(parts) + + +def _make_text_desc(num_cols, protocol_version=4): + """Create a ParseDesc for num_cols text columns.""" + coltypes = [UTF8Type] * num_cols + colnames = [f"col{i}" for i in range(num_cols)] + coldescs = [ColDesc("ks", "tbl", f"col{i}") for i in range(num_cols)] + desers = make_deserializers(coltypes) + return ParseDesc(colnames, coltypes, None, coldescs, desers, protocol_version) + + +def _make_ascii_desc(num_cols, protocol_version=4): + """Create a ParseDesc for num_cols ASCII columns.""" + coltypes = [AsciiType] * num_cols + colnames = [f"col{i}" for i in range(num_cols)] + coldescs = [ColDesc("ks", "tbl", f"col{i}") for i in range(num_cols)] + desers = make_deserializers(coltypes) + return ParseDesc(colnames, coltypes, None, coldescs, desers, protocol_version) + + +def _make_mixed_desc(protocol_version=4): + """Create a ParseDesc for 3 text + 2 int32 columns.""" + coltypes = [UTF8Type, UTF8Type, UTF8Type, Int32Type, Int32Type] + colnames = ["text0", "text1", "text2", "int0", "int1"] + coldescs = [ColDesc("ks", "tbl", n) for n in colnames] + desers = make_deserializers(coltypes) + return ParseDesc(colnames, coltypes, None, coldescs, desers, protocol_version) + + +# --------------------------------------------------------------------------- +# Cython pipeline benchmarks — UTF-8 +# 
--------------------------------------------------------------------------- + + +class TestUTF8CythonPipeline: + """Benchmark the full Cython row parsing pipeline with UTF-8 text columns. + + These benchmarks measure the end-to-end cost of parsing result sets + through the optimized Cython path. The optimization replaces + to_bytes(buf).decode('utf8') with PyUnicode_DecodeUTF8(buf.ptr, buf.size, NULL), + eliminating one intermediate bytes allocation per text cell. + """ + + def test_bench_utf8_1row_1col_short(self, benchmark): + """1 row x 1 col, short string (11 bytes) — isolates per-call overhead.""" + text = b"hello world" + buf = _build_text_rows_buffer(1, 1, text) + desc = _make_text_desc(1) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 1 + assert result[0][0] == "hello world" + + def test_bench_utf8_1row_10col_short(self, benchmark): + """1 row x 10 cols, short strings — measures per-column overhead.""" + text = b"hello world" + buf = _build_text_rows_buffer(1, 10, text) + desc = _make_text_desc(10) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 1 + assert len(result[0]) == 10 + + def test_bench_utf8_100rows_5col_medium(self, benchmark): + """100 rows x 5 cols, medium string (46 bytes) — typical workload.""" + text = b"Hello, this is a test string for benchmarking!" 
+ buf = _build_text_rows_buffer(100, 5, text) + desc = _make_text_desc(5) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 100 + assert result[0][0] == text.decode("utf8") + + def test_bench_utf8_1000rows_5col_medium(self, benchmark): + """1000 rows x 5 cols, medium string — high-throughput scenario.""" + text = b"Hello, this is a test string for benchmarking!" + buf = _build_text_rows_buffer(1000, 5, text) + desc = _make_text_desc(5) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 1000 + + def test_bench_utf8_100rows_5col_long(self, benchmark): + """100 rows x 5 cols, long string (200 bytes) — larger values.""" + text = b"A" * 200 + buf = _build_text_rows_buffer(100, 5, text) + desc = _make_text_desc(5) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 100 + assert result[0][0] == "A" * 200 + + def test_bench_utf8_100rows_5col_multibyte(self, benchmark): + """100 rows x 5 cols, multibyte UTF-8 string — tests non-ASCII.""" + text = "Héllo wörld! 
こんにちは 🌍".encode("utf-8") + buf = _build_text_rows_buffer(100, 5, text) + desc = _make_text_desc(5) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 100 + assert result[0][0] == text.decode("utf-8") + + +# --------------------------------------------------------------------------- +# Cython pipeline benchmarks — ASCII +# --------------------------------------------------------------------------- + + +class TestASCIICythonPipeline: + """Benchmark the Cython row parsing pipeline with ASCII text columns.""" + + def test_bench_ascii_100rows_5col_medium(self, benchmark): + """100 rows x 5 cols, medium ASCII string.""" + text = b"Hello, this is a test ASCII string for benchmarking!" + buf = _build_text_rows_buffer(100, 5, text) + desc = _make_ascii_desc(5) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 100 + assert result[0][0] == text.decode("ascii") + + def test_bench_ascii_1000rows_5col_medium(self, benchmark): + """1000 rows x 5 cols, medium ASCII string.""" + text = b"Hello, this is a test ASCII string for benchmarking!" + buf = _build_text_rows_buffer(1000, 5, text) + desc = _make_ascii_desc(5) + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 1000 + + +# --------------------------------------------------------------------------- +# Mixed columns benchmark +# --------------------------------------------------------------------------- + + +class TestMixedColumnsPipeline: + """Benchmark with mixed column types (text + int) for realism.""" + + def test_bench_mixed_100rows_3text_2int(self, benchmark): + """100 rows x (3 text + 2 int) — realistic mixed schema.""" + text = b"Hello, this is a test string for benchmarking!" 
+ buf = _build_mixed_rows_buffer(100, text) + desc = _make_mixed_desc() + parser = ListParser() + + def parse(): + reader = BytesIOReader(buf) + return parser.parse_rows(reader, desc) + + result = benchmark(parse) + assert len(result) == 100 + assert result[0][0] == text.decode("utf8") + assert result[0][3] == 42 + + +# --------------------------------------------------------------------------- +# Python-level reference (bytes.decode) for comparison +# --------------------------------------------------------------------------- + + +class TestPythonDecodeReference: + """Python-level microbenchmark showing the overhead of creating + intermediate bytes objects before decode, which is what the + original Cython code did (to_bytes(buf).decode('utf8')). + + These benchmarks isolate the bytes-creation overhead that the + PyUnicode_DecodeUTF8 optimization eliminates. + """ + + def test_bench_python_bytes_decode_short(self, benchmark): + """Python reference: bytes.decode('utf8') for 500 short strings.""" + data = b"hello world" + + def decode_loop(): + result = None + for _ in range(500): + result = data.decode("utf8") + return result + + result = benchmark(decode_loop) + assert result == "hello world" + + def test_bench_python_copy_then_decode_short(self, benchmark): + """Python reference: bytes(data).decode('utf8') for 500 short strings. + This simulates the old to_bytes(buf).decode() pattern, where + to_bytes() creates a new bytes object from the C buffer.""" + data = b"hello world" + mv = memoryview(data) + + def decode_loop(): + result = None + for _ in range(500): + copied = bytes(mv) # simulates to_bytes(buf) + result = copied.decode("utf8") + return result + + result = benchmark(decode_loop) + assert result == "hello world" + + def test_bench_python_bytes_decode_medium(self, benchmark): + """Python reference: bytes.decode('utf8') for 500 medium strings.""" + data = b"Hello, this is a test string for benchmarking!" 
+ + def decode_loop(): + result = None + for _ in range(500): + result = data.decode("utf8") + return result + + result = benchmark(decode_loop) + + def test_bench_python_copy_then_decode_medium(self, benchmark): + """Python reference: bytes(memoryview).decode('utf8') for 500 medium strings.""" + data = b"Hello, this is a test string for benchmarking!" + mv = memoryview(data) + + def decode_loop(): + result = None + for _ in range(500): + copied = bytes(mv) # simulates to_bytes(buf) + result = copied.decode("utf8") + return result + + result = benchmark(decode_loop) diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 98e8676bbc..2ccc4ef093 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -14,6 +14,7 @@ from libc.stdint cimport int32_t, uint16_t +from cpython.unicode cimport PyUnicode_DecodeASCII, PyUnicode_DecodeUTF8 include 'cython_marshal.pyx' from cassandra.buffer cimport Buffer, to_bytes, slice_buffer @@ -88,7 +89,7 @@ cdef class DesAsciiType(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): if buf.size == 0: return "" - return to_bytes(buf).decode('ascii') + return PyUnicode_DecodeASCII(buf.ptr, buf.size, NULL) cdef class DesFloatType(Deserializer): @@ -173,8 +174,7 @@ cdef class DesUTF8Type(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): if buf.size == 0: return "" - cdef val = to_bytes(buf) - return val.decode('utf8') + return PyUnicode_DecodeUTF8(buf.ptr, buf.size, NULL) cdef class DesVarcharType(DesUTF8Type): diff --git a/tests/unit/cython/test_deserializers.py b/tests/unit/cython/test_deserializers.py new file mode 100644 index 0000000000..cd91e0ea90 --- /dev/null +++ b/tests/unit/cython/test_deserializers.py @@ -0,0 +1,145 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Correctness tests for the Cython UTF-8 and ASCII deserializers. + +These verify that the optimized PyUnicode_DecodeUTF8/DecodeASCII code path +in cassandra/deserializers.pyx produces correct results for edge cases. +""" + +import struct +import unittest + +try: + from cassandra.obj_parser import ListParser + from cassandra.bytesio import BytesIOReader + from cassandra.parsing import ParseDesc + from cassandra.deserializers import make_deserializers + from cassandra.cqltypes import UTF8Type, AsciiType + from cassandra.policies import ColDesc + + HAS_CYTHON = True +except ImportError: + HAS_CYTHON = False + + +def _build_text_rows_buffer(num_rows, num_cols, text_data): + """Build a binary buffer representing num_rows x num_cols of text data. + + Format: [int32 row_count] [row1] [row2] ... + Each row: [cell1] [cell2] ... 
+ Each cell: [int32 length] [data bytes] + """ + parts = [struct.pack(">i", num_rows)] + cell = struct.pack(">i", len(text_data)) + text_data + row = cell * num_cols + parts.append(row * num_rows) + return b"".join(parts) + + +def _make_text_desc(num_cols, protocol_version=4): + """Create a ParseDesc for num_cols text columns.""" + coltypes = [UTF8Type] * num_cols + colnames = [f"col{i}" for i in range(num_cols)] + coldescs = [ColDesc("ks", "tbl", f"col{i}") for i in range(num_cols)] + desers = make_deserializers(coltypes) + return ParseDesc(colnames, coltypes, None, coldescs, desers, protocol_version) + + +def _make_ascii_desc(num_cols, protocol_version=4): + """Create a ParseDesc for num_cols ASCII columns.""" + coltypes = [AsciiType] * num_cols + colnames = [f"col{i}" for i in range(num_cols)] + coldescs = [ColDesc("ks", "tbl", f"col{i}") for i in range(num_cols)] + desers = make_deserializers(coltypes) + return ParseDesc(colnames, coltypes, None, coldescs, desers, protocol_version) + + +@unittest.skipUnless(HAS_CYTHON, "Cython extensions not available") +class TestCythonDeserializerCorrectness(unittest.TestCase): + """Verify that the optimized Cython decode produces correct results.""" + + def test_utf8_empty_string(self): + """Empty string should return empty string.""" + buf = _build_text_rows_buffer(1, 1, b"") + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertEqual(rows[0][0], "") + + def test_utf8_ascii_only(self): + """Pure ASCII content.""" + text = b"Hello, World! 12345" + buf = _build_text_rows_buffer(1, 1, text) + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertEqual(rows[0][0], "Hello, World! 12345") + + def test_utf8_multibyte(self): + """Multibyte UTF-8 characters.""" + text = "Héllo wörld! 
こんにちは 🌍".encode("utf-8") + buf = _build_text_rows_buffer(1, 1, text) + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertEqual(rows[0][0], "Héllo wörld! こんにちは 🌍") + + def test_utf8_long_string(self): + """Long string (10KB).""" + text = ("x" * 10000).encode("utf-8") + buf = _build_text_rows_buffer(1, 1, text) + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertEqual(rows[0][0], "x" * 10000) + + def test_ascii_basic(self): + """Basic ASCII decode.""" + text = b"Simple ASCII text 12345 !@#" + buf = _build_text_rows_buffer(1, 1, text) + desc = _make_ascii_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertEqual(rows[0][0], "Simple ASCII text 12345 !@#") + + def test_utf8_null_value(self): + """NULL value (negative length) should return None.""" + # Build buffer: 1 row, 1 column with length = -1 (NULL) + buf = struct.pack(">i", 1) + struct.pack(">i", -1) + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertIsNone(rows[0][0]) + + def test_utf8_multiple_rows_columns(self): + """Multiple rows and columns.""" + texts = [b"alpha", b"beta", b"gamma"] + # Build buffer with 3 rows x 1 col, different values + parts = [struct.pack(">i", 3)] + for t in texts: + parts.append(struct.pack(">i", len(t)) + t) + buf = b"".join(parts) + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + rows = parser.parse_rows(reader, desc) + self.assertEqual([r[0] for r in rows], ["alpha", "beta", "gamma"]) From f0ce46cd72f1e8d2413abfafa378359bccc5a5cb Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 2 Apr 2026 17:29:44 +0300 Subject: [PATCH 27/27] Address review comments: use cythontest decorator, add importorskip guards, add invalid-input tests - 
Replace hand-rolled try/except ImportError with the project-standard cythontest decorator and HAVE_CYTHON conditional imports, so VERIFY_CYTHON=True CI mode fails loudly instead of silently skipping. - Add pytest.importorskip guards to the benchmark file so it skips gracefully when pytest-benchmark or Cython extensions are missing. - Add test_utf8_invalid_bytes and test_ascii_invalid_bytes to confirm error propagation through the DriverException wrapper. --- benchmarks/utf8_decode_benchmark.py | 3 ++ tests/unit/cython/test_deserializers.py | 41 ++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/benchmarks/utf8_decode_benchmark.py b/benchmarks/utf8_decode_benchmark.py index fed0e5daa7..2f4d31c5be 100644 --- a/benchmarks/utf8_decode_benchmark.py +++ b/benchmarks/utf8_decode_benchmark.py @@ -30,6 +30,9 @@ import struct import pytest +pytest.importorskip("pytest_benchmark") +pytest.importorskip("cassandra.obj_parser") + from cassandra.obj_parser import ListParser from cassandra.bytesio import BytesIOReader from cassandra.parsing import ParseDesc diff --git a/tests/unit/cython/test_deserializers.py b/tests/unit/cython/test_deserializers.py index cd91e0ea90..1ac9a20e5a 100644 --- a/tests/unit/cython/test_deserializers.py +++ b/tests/unit/cython/test_deserializers.py @@ -22,7 +22,11 @@ import struct import unittest -try: +from tests.unit.cython.utils import cythontest + +from cassandra.cython_deps import HAVE_CYTHON + +if HAVE_CYTHON: from cassandra.obj_parser import ListParser from cassandra.bytesio import BytesIOReader from cassandra.parsing import ParseDesc @@ -30,9 +34,7 @@ from cassandra.cqltypes import UTF8Type, AsciiType from cassandra.policies import ColDesc - HAS_CYTHON = True -except ImportError: - HAS_CYTHON = False +from cassandra import DriverException def _build_text_rows_buffer(num_rows, num_cols, text_data): @@ -67,10 +69,10 @@ def _make_ascii_desc(num_cols, protocol_version=4): return ParseDesc(colnames, coltypes, 
None, coldescs, desers, protocol_version) -@unittest.skipUnless(HAS_CYTHON, "Cython extensions not available") class TestCythonDeserializerCorrectness(unittest.TestCase): """Verify that the optimized Cython decode produces correct results.""" + @cythontest def test_utf8_empty_string(self): """Empty string should return empty string.""" buf = _build_text_rows_buffer(1, 1, b"") @@ -80,6 +82,7 @@ def test_utf8_empty_string(self): rows = parser.parse_rows(reader, desc) self.assertEqual(rows[0][0], "") + @cythontest def test_utf8_ascii_only(self): """Pure ASCII content.""" text = b"Hello, World! 12345" @@ -90,6 +93,7 @@ def test_utf8_ascii_only(self): rows = parser.parse_rows(reader, desc) self.assertEqual(rows[0][0], "Hello, World! 12345") + @cythontest def test_utf8_multibyte(self): """Multibyte UTF-8 characters.""" text = "Héllo wörld! こんにちは 🌍".encode("utf-8") @@ -100,6 +104,7 @@ def test_utf8_multibyte(self): rows = parser.parse_rows(reader, desc) self.assertEqual(rows[0][0], "Héllo wörld! 
こんにちは 🌍") + @cythontest def test_utf8_long_string(self): """Long string (10KB).""" text = ("x" * 10000).encode("utf-8") @@ -110,6 +115,7 @@ def test_utf8_long_string(self): rows = parser.parse_rows(reader, desc) self.assertEqual(rows[0][0], "x" * 10000) + @cythontest def test_ascii_basic(self): """Basic ASCII decode.""" text = b"Simple ASCII text 12345 !@#" @@ -120,6 +126,7 @@ def test_ascii_basic(self): rows = parser.parse_rows(reader, desc) self.assertEqual(rows[0][0], "Simple ASCII text 12345 !@#") + @cythontest def test_utf8_null_value(self): """NULL value (negative length) should return None.""" # Build buffer: 1 row, 1 column with length = -1 (NULL) @@ -130,6 +137,7 @@ def test_utf8_null_value(self): rows = parser.parse_rows(reader, desc) self.assertIsNone(rows[0][0]) + @cythontest def test_utf8_multiple_rows_columns(self): """Multiple rows and columns.""" texts = [b"alpha", b"beta", b"gamma"] @@ -143,3 +151,26 @@ def test_utf8_multiple_rows_columns(self): reader = BytesIOReader(buf) rows = parser.parse_rows(reader, desc) self.assertEqual([r[0] for r in rows], ["alpha", "beta", "gamma"]) + + @cythontest + def test_utf8_invalid_bytes(self): + """Invalid UTF-8 bytes should raise an error (DriverException wrapping UnicodeDecodeError).""" + # 0xFF 0xFE is not valid UTF-8 + buf = _build_text_rows_buffer(1, 1, b"\xff\xfe\x80\x81") + desc = _make_text_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + with self.assertRaises(DriverException) as ctx: + parser.parse_rows(reader, desc) + self.assertIn("utf-8", str(ctx.exception).lower()) + + @cythontest + def test_ascii_invalid_bytes(self): + """Non-ASCII bytes in an ASCII column should raise an error (DriverException wrapping UnicodeDecodeError).""" + buf = _build_text_rows_buffer(1, 1, b"\x80\x81\x82") + desc = _make_ascii_desc(1) + parser = ListParser() + reader = BytesIOReader(buf) + with self.assertRaises(DriverException) as ctx: + parser.parse_rows(reader, desc) + self.assertIn("ascii", 
str(ctx.exception).lower())